source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/LeafModels/DampenedModel.cs @ 15967

Last change on this file since 15967 was 15967, checked in by bwerth, 12 months ago

#2847 added logistic dampening and some minor changes

File size: 4.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
27using HeuristicLab.Problems.DataAnalysis;
28
namespace HeuristicLab.Algorithms.DataAnalysis {
  // Multidimensional extension of the logistic dampening described in
  // http://www2.stat.duke.edu/~tjl13/s101/slides/unit6lec3H.pdf

  /// <summary>
  /// Wraps a regression model and passes every prediction it makes through a logistic (sigmoid)
  /// curve that is fitted to the range [Min, Max] of the training target values.
  /// </summary>
  /// <remarks>
  /// Each prediction is transformed in three steps:
  /// 1. It is mapped linearly from [Min, Max] onto [-Dampening, Dampening].
  /// 2. It is squashed with the sigmoid 1 / (1 + e^-x).
  /// 3. It is mapped linearly from [Sigmoid(-Dampening), Sigmoid(Dampening)] back onto [Min, Max].
  /// For a non-zero Dampening, Min and Max are mapped onto themselves (up to rounding).
  /// Predictions far outside the training range are pulled toward finite limits.
  /// A Dampening of 0 collapses every prediction onto the midpoint of [Min, Max].
  /// </remarks>
  [StorableClass]
  public class DampenedModel : RegressionModel {
    // The wrapped model whose raw predictions get dampened.
    [Storable]
    protected IRegressionModel Model;
    // Smallest target value seen in the training partition.
    [Storable]
    private double Min;
    // Largest target value seen in the training partition.
    [Storable]
    private double Max;
    // Half-width of the interval [-Dampening, Dampening] on which the sigmoid is evaluated.
    [Storable]
    private double Dampening;

    [StorableConstructor]
    protected DampenedModel(bool deserializing) : base(deserializing) { }

    protected DampenedModel(DampenedModel original, Cloner cloner) : base(original, cloner) {
      // Copy the plain scalar state first, then deep-clone the wrapped model through the cloner.
      Dampening = original.Dampening;
      Min = original.Min;
      Max = original.Max;
      Model = cloner.Clone(original.Model);
    }

    /// <summary>
    /// Creates a dampened wrapper around <paramref name="model"/>.
    /// </summary>
    /// <param name="model">Model whose predictions are dampened; its target variable is reused.</param>
    /// <param name="pd">Problem data whose training target values define [Min, Max].</param>
    /// <param name="dampening">Half-width of the sigmoid input interval.</param>
    protected DampenedModel(IRegressionModel model, IRegressionProblemData pd, double dampening) : base(model.TargetVariable) {
      Dampening = dampening;
      Min = pd.TargetVariableTrainingValues.Min();
      Max = pd.TargetVariableTrainingValues.Max();
      Model = model;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new DampenedModel(this, cloner);
    }

    /// <summary>
    /// Wraps a model that provides variance estimates; the result also provides them.
    /// </summary>
    public static IConfidenceRegressionModel DampenModel(IConfidenceRegressionModel model, IRegressionProblemData pd, double dampening) {
      return new ConfidenceDampenedModel(model, pd, dampening);
    }

    /// <summary>
    /// Wraps an arbitrary regression model.
    /// If the model turns out to implement <see cref="IConfidenceRegressionModel"/>,
    /// the returned wrapper implements that interface as well.
    /// </summary>
    public static IRegressionModel DampenModel(IRegressionModel model, IRegressionProblemData pd, double dampening) {
      var confidenceModel = model as IConfidenceRegressionModel;
      if (confidenceModel == null) {
        return new DampenedModel(model, pd, dampening);
      }
      return new ConfidenceDampenedModel(confidenceModel, pd, dampening);
    }

    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      // Sigmoid values at both ends of [-Dampening, Dampening]; they form the source interval
      // for mapping the squashed value back onto [Min, Max].
      var sigmoidAtLower = Sigmoid(-Dampening);
      var sigmoidAtUpper = Sigmoid(Dampening);
      // Still inside the iterator block, so the wrapped model is only queried on enumeration.
      var rawEstimates = Model.GetEstimatedValues(dataset, rows);
      foreach (var raw in rawEstimates) {
        var squashed = Sigmoid(Rescale(raw, Min, Max, -Dampening, Dampening));
        yield return Rescale(squashed, sigmoidAtLower, sigmoidAtUpper, Min, Max);
      }
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, problemData);
    }

    /// <summary>
    /// Maps <paramref name="x"/> linearly from [oMin, oMax] onto [nMin, nMax].
    /// If the source interval is (almost) empty, the target midpoint is returned instead.
    /// </summary>
    private static double Rescale(double x, double oMin, double oMax, double nMin, double nMax) {
      var sourceSpan = oMax - oMin;
      var targetSpan = nMax - nMin;
      if (sourceSpan.IsAlmost(0)) {
        // Degenerate source interval: collapse onto the midpoint of the target interval.
        // The "(x - oMin) * 0.0" term is kept on purpose so that NaN or infinite inputs
        // still produce NaN, exactly as the general formula would.
        return (x - oMin) * 0.0 + (nMin + targetSpan / 2);
      }
      return ((x - oMin) / sourceSpan) * targetSpan + nMin;
    }

    /// <summary>Standard logistic function 1 / (1 + e^-x).</summary>
    private static double Sigmoid(double x) {
      var expNegX = Math.Exp(-x);
      return 1 / (1 + expNegX);
    }


    /// <summary>
    /// Variant of <see cref="DampenedModel"/> used when the wrapped model also provides variance estimates.
    /// </summary>
    [StorableClass]
    private sealed class ConfidenceDampenedModel : DampenedModel, IConfidenceRegressionModel {
      #region HLConstructors
      [StorableConstructor]
      private ConfidenceDampenedModel(bool deserializing) : base(deserializing) { }
      private ConfidenceDampenedModel(ConfidenceDampenedModel original, Cloner cloner) : base(original, cloner) { }
      public ConfidenceDampenedModel(IConfidenceRegressionModel model, IRegressionProblemData pd, double dampening) : base(model, pd, dampening) { }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new ConfidenceDampenedModel(this, cloner);
      }
      #endregion

      // NOTE(review): the wrapped model's variances are passed through unchanged; they are not
      // transformed by the dampening applied in GetEstimatedValues. Confirm this is intended.
      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
        var confidenceModel = (IConfidenceRegressionModel)Model;
        return confidenceModel.GetEstimatedVariances(dataset, rows);
      }

      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
        return new ConfidenceRegressionSolution(this, problemData);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.