Free cookie consent management tool by TermsFeed Policy Generator

source: branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisModel.cs @ 15580

Last change on this file since 15580 was 14029, checked in by gkronber, 8 years ago

#2434: merged trunk changes r12934:14026 from trunk to branch

File size: 7.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  /// <summary>
33  /// Abstract base class for symbolic data analysis models
34  /// </summary>
35  [StorableClass]
36  public abstract class SymbolicDataAnalysisModel : NamedItem, ISymbolicDataAnalysisModel {
37    public static new Image StaticItemImage {
38      get { return HeuristicLab.Common.Resources.VSImageLibrary.Function; }
39    }
40
41    #region properties
42    [Storable]
43    private double lowerEstimationLimit;
44    public double LowerEstimationLimit { get { return lowerEstimationLimit; } }
45    [Storable]
46    private double upperEstimationLimit;
47    public double UpperEstimationLimit { get { return upperEstimationLimit; } }
48
49    [Storable]
50    private ISymbolicExpressionTree symbolicExpressionTree;
51    public ISymbolicExpressionTree SymbolicExpressionTree {
52      get { return symbolicExpressionTree; }
53    }
54
55    [Storable]
56    private ISymbolicDataAnalysisExpressionTreeInterpreter interpreter;
57    public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter {
58      get { return interpreter; }
59    }
60
61    public IEnumerable<string> VariablesUsedForPrediction {
62      get {
63        var variables =
64          SymbolicExpressionTree.IterateNodesPrefix()
65            .OfType<VariableTreeNode>()
66            .Select(x => x.VariableName)
67            .Distinct();
68        var variableConditions = SymbolicExpressionTree.IterateNodesPrefix()
69          .OfType<VariableConditionTreeNode>().Select(x => x.VariableName).Distinct();
70
71        return variables.Union(variableConditions).OrderBy(x => x);
72      }
73    }
74
75    #endregion
76
77    [StorableConstructor]
78    protected SymbolicDataAnalysisModel(bool deserializing) : base(deserializing) { }
79    protected SymbolicDataAnalysisModel(SymbolicDataAnalysisModel original, Cloner cloner)
80      : base(original, cloner) {
81      this.symbolicExpressionTree = cloner.Clone(original.symbolicExpressionTree);
82      this.interpreter = cloner.Clone(original.interpreter);
83      this.lowerEstimationLimit = original.lowerEstimationLimit;
84      this.upperEstimationLimit = original.upperEstimationLimit;
85    }
86    protected SymbolicDataAnalysisModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
87       double lowerEstimationLimit, double upperEstimationLimit)
88      : base() {
89      this.name = ItemName;
90      this.description = ItemDescription;
91      this.symbolicExpressionTree = tree;
92      this.interpreter = interpreter;
93      this.lowerEstimationLimit = lowerEstimationLimit;
94      this.upperEstimationLimit = upperEstimationLimit;
95    }
96
97    #region Scaling
98    protected void Scale(IDataAnalysisProblemData problemData, string targetVariable) {
99      var dataset = problemData.Dataset;
100      var rows = problemData.TrainingIndices;
101      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows);
102      var targetValues = dataset.GetDoubleValues(targetVariable, rows);
103
104      var linearScalingCalculator = new OnlineLinearScalingParameterCalculator();
105      var targetValuesEnumerator = targetValues.GetEnumerator();
106      var estimatedValuesEnumerator = estimatedValues.GetEnumerator();
107      while (targetValuesEnumerator.MoveNext() & estimatedValuesEnumerator.MoveNext()) {
108        double target = targetValuesEnumerator.Current;
109        double estimated = estimatedValuesEnumerator.Current;
110        if (!double.IsNaN(estimated) && !double.IsInfinity(estimated))
111          linearScalingCalculator.Add(estimated, target);
112      }
113      if (linearScalingCalculator.ErrorState == OnlineCalculatorError.None && (targetValuesEnumerator.MoveNext() || estimatedValuesEnumerator.MoveNext()))
114        throw new ArgumentException("Number of elements in target and estimated values enumeration do not match.");
115
116      double alpha = linearScalingCalculator.Alpha;
117      double beta = linearScalingCalculator.Beta;
118      if (linearScalingCalculator.ErrorState != OnlineCalculatorError.None) return;
119
120      ConstantTreeNode alphaTreeNode = null;
121      ConstantTreeNode betaTreeNode = null;
122      // check if model has been scaled previously by analyzing the structure of the tree
123      var startNode = SymbolicExpressionTree.Root.GetSubtree(0);
124      if (startNode.GetSubtree(0).Symbol is Addition) {
125        var addNode = startNode.GetSubtree(0);
126        if (addNode.SubtreeCount == 2 && addNode.GetSubtree(0).Symbol is Multiplication && addNode.GetSubtree(1).Symbol is Constant) {
127          alphaTreeNode = addNode.GetSubtree(1) as ConstantTreeNode;
128          var mulNode = addNode.GetSubtree(0);
129          if (mulNode.SubtreeCount == 2 && mulNode.GetSubtree(1).Symbol is Constant) {
130            betaTreeNode = mulNode.GetSubtree(1) as ConstantTreeNode;
131          }
132        }
133      }
134      // if tree structure matches the structure necessary for linear scaling then reuse the existing tree nodes
135      if (alphaTreeNode != null && betaTreeNode != null) {
136        betaTreeNode.Value *= beta;
137        alphaTreeNode.Value *= beta;
138        alphaTreeNode.Value += alpha;
139      } else {
140        var mainBranch = startNode.GetSubtree(0);
141        startNode.RemoveSubtree(0);
142        var scaledMainBranch = MakeSum(MakeProduct(mainBranch, beta), alpha);
143        startNode.AddSubtree(scaledMainBranch);
144      }
145    }
146
147    private static ISymbolicExpressionTreeNode MakeSum(ISymbolicExpressionTreeNode treeNode, double alpha) {
148      if (alpha.IsAlmost(0.0)) {
149        return treeNode;
150      } else {
151        var addition = new Addition();
152        var node = addition.CreateTreeNode();
153        var alphaConst = MakeConstant(alpha);
154        node.AddSubtree(treeNode);
155        node.AddSubtree(alphaConst);
156        return node;
157      }
158    }
159
160    private static ISymbolicExpressionTreeNode MakeProduct(ISymbolicExpressionTreeNode treeNode, double beta) {
161      if (beta.IsAlmost(1.0)) {
162        return treeNode;
163      } else {
164        var multipliciation = new Multiplication();
165        var node = multipliciation.CreateTreeNode();
166        var betaConst = MakeConstant(beta);
167        node.AddSubtree(treeNode);
168        node.AddSubtree(betaConst);
169        return node;
170      }
171    }
172
173    private static ISymbolicExpressionTreeNode MakeConstant(double c) {
174      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
175      node.Value = c;
176      return node;
177    }
178    #endregion
179
180  }
181}
Note: See TracBrowser for help on using the repository browser.