Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ProblemRefactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 13331

Last change on this file since 13331 was 13300, checked in by gkronber, 9 years ago

#2175 made some changes while reviewing the code

File size: 22.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
33namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
34  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
35  [StorableClass]
36  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
37    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
38    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
39    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
40    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
41    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
42
43    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
44      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
45    }
46    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
47      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
48    }
49    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
50      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
51    }
52    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
53      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
54    }
55    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
56      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
57    }
58
59    public IntValue ConstantOptimizationIterations {
60      get { return ConstantOptimizationIterationsParameter.Value; }
61    }
62    public DoubleValue ConstantOptimizationImprovement {
63      get { return ConstantOptimizationImprovementParameter.Value; }
64    }
65    public PercentValue ConstantOptimizationProbability {
66      get { return ConstantOptimizationProbabilityParameter.Value; }
67    }
68    public PercentValue ConstantOptimizationRowsPercentage {
69      get { return ConstantOptimizationRowsPercentageParameter.Value; }
70    }
71    public bool UpdateConstantsInTree {
72      get { return UpdateConstantsInTreeParameter.Value.Value; }
73      set { UpdateConstantsInTreeParameter.Value.Value = value; }
74    }
75
76    public override bool Maximization {
77      get { return true; }
78    }
79
80    [StorableConstructor]
81    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
82    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
83      : base(original, cloner) {
84    }
85    public SymbolicRegressionConstantOptimizationEvaluator()
86      : base() {
87      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
88      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
89      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
90      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
91      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
92    }
93
94    public override IDeepCloneable Clone(Cloner cloner) {
95      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
96    }
97
98    [StorableHook(HookType.AfterDeserialization)]
99    private void AfterDeserialization() {
100      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
101        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
102    }
103
104    public override IOperation InstrumentedApply() {
105      var solution = SymbolicExpressionTreeParameter.ActualValue;
106      double quality;
107      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
108        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
109        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
110           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
111           EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree);
112
113        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
114          var evaluationRows = GenerateRowsToEvaluate();
115          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
116        }
117      } else {
118        var evaluationRows = GenerateRowsToEvaluate();
119        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
120      }
121      QualityParameter.ActualValue = new DoubleValue(quality);
122
123      return base.InstrumentedApply();
124    }
125
126    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
127      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
128      EstimationLimitsParameter.ExecutionContext = context;
129      ApplyLinearScalingParameter.ExecutionContext = context;
130
131      // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
132      // because Evaluate() is used to get the quality of evolved models on
133      // different partitions of the dataset (e.g., best validation model)
134      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);
135
136      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
137      EstimationLimitsParameter.ExecutionContext = null;
138      ApplyLinearScalingParameter.ExecutionContext = null;
139
140      return r2;
141    }
142
143    #region derivations of functions
144    // create function factory for arctangent
145    private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
146      eval: Math.Atan,
147      diff: x => 1 / (1 + x * x));
148    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
149      eval: Math.Sin,
150      diff: Math.Cos);
151    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
152       eval: Math.Cos,
153       diff: x => -Math.Sin(x));
154    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
155      eval: Math.Tan,
156      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
157    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
158      eval: alglib.errorfunction,
159      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
160    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
161      eval: alglib.normaldistribution,
162      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
163    #endregion
164
165
166    // TODO: swap positions of lowerEstimationLimit and upperEstimationLimit parameters
167    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
168      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {
169
170      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
171      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
172      List<string> variableNames = new List<string>();
173
174      AutoDiff.Term func;
175      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
176        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
177      if (variableNames.Count == 0) return 0.0;
178
179      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());
180
181      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
182      double[] c = new double[variables.Count];
183
184      {
185        c[0] = 0.0;
186        c[1] = 1.0;
187        //extract inital constants
188        int i = 2;
189        foreach (var node in terminalNodes) {
190          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
191          VariableTreeNode variableTreeNode = node as VariableTreeNode;
192          if (constantTreeNode != null)
193            c[i++] = constantTreeNode.Value;
194          else if (variableTreeNode != null)
195            c[i++] = variableTreeNode.Weight;
196        }
197      }
198      double[] originalConstants = (double[])c.Clone();
199      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
200
201      alglib.lsfitstate state;
202      alglib.lsfitreport rep;
203      int info;
204
205      IDataset ds = problemData.Dataset;
206      double[,] x = new double[rows.Count(), variableNames.Count];
207      int row = 0;
208      foreach (var r in rows) {
209        for (int col = 0; col < variableNames.Count; col++) {
210          x[row, col] = ds.GetDoubleValue(variableNames[col], r);
211        }
212        row++;
213      }
214      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
215      int n = x.GetLength(0);
216      int m = x.GetLength(1);
217      int k = c.Length;
218
219      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
220      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
221
222      try {
223        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
224        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
225        //alglib.lsfitsetgradientcheck(state, 0.001);
226        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
227        alglib.lsfitresults(state, out info, out c, out rep);
228      }
229      catch (ArithmeticException) {
230        return originalQuality;
231      }
232      catch (alglib.alglibexception) {
233        return originalQuality;
234      }
235
236      //info == -7  => constant optimization failed due to wrong gradient
237      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
238      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
239
240      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
241      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
242        UpdateConstants(tree, originalConstants.Skip(2).ToArray());
243        return originalQuality;
244      }
245      return quality;
246    }
247
248    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
249      int i = 0;
250      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
251        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
252        VariableTreeNode variableTreeNode = node as VariableTreeNode;
253        if (constantTreeNode != null)
254          constantTreeNode.Value = constants[i++];
255        else if (variableTreeNode != null)
256          variableTreeNode.Weight = constants[i++];
257      }
258    }
259
260    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
261      return (double[] c, double[] x, ref double func, object o) => {
262        func = compiledFunc.Evaluate(c, x);
263      };
264    }
265
266    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
267      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
268        var tupel = compiledFunc.Differentiate(c, x);
269        func = tupel.Item2;
270        Array.Copy(tupel.Item1, grad, grad.Length);
271      };
272    }
273
274    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
275      if (node.Symbol is Constant) {
276        var var = new AutoDiff.Variable();
277        variables.Add(var);
278        term = var;
279        return true;
280      }
281      if (node.Symbol is Variable) {
282        var varNode = node as VariableTreeNode;
283        var par = new AutoDiff.Variable();
284        parameters.Add(par);
285        variableNames.Add(varNode.VariableName);
286        var w = new AutoDiff.Variable();
287        variables.Add(w);
288        term = AutoDiff.TermBuilder.Product(w, par);
289        return true;
290      }
291      if (node.Symbol is Addition) {
292        List<AutoDiff.Term> terms = new List<Term>();
293        foreach (var subTree in node.Subtrees) {
294          AutoDiff.Term t;
295          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
296            term = null;
297            return false;
298          }
299          terms.Add(t);
300        }
301        term = AutoDiff.TermBuilder.Sum(terms);
302        return true;
303      }
304      if (node.Symbol is Subtraction) {
305        List<AutoDiff.Term> terms = new List<Term>();
306        for (int i = 0; i < node.SubtreeCount; i++) {
307          AutoDiff.Term t;
308          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
309            term = null;
310            return false;
311          }
312          if (i > 0) t = -t;
313          terms.Add(t);
314        }
315        term = AutoDiff.TermBuilder.Sum(terms);
316        return true;
317      }
318      if (node.Symbol is Multiplication) {
319        AutoDiff.Term a, b;
320        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
321          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
322          term = null;
323          return false;
324        } else {
325          List<AutoDiff.Term> factors = new List<Term>();
326          foreach (var subTree in node.Subtrees.Skip(2)) {
327            AutoDiff.Term f;
328            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
329              term = null;
330              return false;
331            }
332            factors.Add(f);
333          }
334          term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
335          return true;
336        }
337      }
338      if (node.Symbol is Division) {
339        // only works for at least two subtrees
340        AutoDiff.Term a, b;
341        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
342          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
343          term = null;
344          return false;
345        } else {
346          List<AutoDiff.Term> factors = new List<Term>();
347          foreach (var subTree in node.Subtrees.Skip(2)) {
348            AutoDiff.Term f;
349            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
350              term = null;
351              return false;
352            }
353            factors.Add(1.0 / f);
354          }
355          term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
356          return true;
357        }
358      }
359      if (node.Symbol is Logarithm) {
360        AutoDiff.Term t;
361        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
362          term = null;
363          return false;
364        } else {
365          term = AutoDiff.TermBuilder.Log(t);
366          return true;
367        }
368      }
369      if (node.Symbol is Exponential) {
370        AutoDiff.Term t;
371        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
372          term = null;
373          return false;
374        } else {
375          term = AutoDiff.TermBuilder.Exp(t);
376          return true;
377        }
378      }
379      if (node.Symbol is Square) {
380        AutoDiff.Term t;
381        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
382          term = null;
383          return false;
384        } else {
385          term = AutoDiff.TermBuilder.Power(t, 2.0);
386          return true;
387        }
388      } if (node.Symbol is SquareRoot) {
389        AutoDiff.Term t;
390        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
391          term = null;
392          return false;
393        } else {
394          term = AutoDiff.TermBuilder.Power(t, 0.5);
395          return true;
396        }
397      } if (node.Symbol is Sine) {
398        AutoDiff.Term t;
399        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
400          term = null;
401          return false;
402        } else {
403          term = sin(t);
404          return true;
405        }
406      } if (node.Symbol is Cosine) {
407        AutoDiff.Term t;
408        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
409          term = null;
410          return false;
411        } else {
412          term = cos(t);
413          return true;
414        }
415      } if (node.Symbol is Tangent) {
416        AutoDiff.Term t;
417        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
418          term = null;
419          return false;
420        } else {
421          term = tan(t);
422          return true;
423        }
424      } if (node.Symbol is Erf) {
425        AutoDiff.Term t;
426        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
427          term = null;
428          return false;
429        } else {
430          term = erf(t);
431          return true;
432        }
433      } if (node.Symbol is Norm) {
434        AutoDiff.Term t;
435        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
436          term = null;
437          return false;
438        } else {
439          term = norm(t);
440          return true;
441        }
442      }
443      if (node.Symbol is StartSymbol) {
444        var alpha = new AutoDiff.Variable();
445        var beta = new AutoDiff.Variable();
446        variables.Add(beta);
447        variables.Add(alpha);
448        AutoDiff.Term branchTerm;
449        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
450          term = branchTerm * alpha + beta;
451          return true;
452        } else {
453          term = null;
454          return false;
455        }
456      }
457      term = null;
458      return false;
459    }
460
461    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
462      var containsUnknownSymbol = (
463        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
464        where
465         !(n.Symbol is Variable) &&
466         !(n.Symbol is Constant) &&
467         !(n.Symbol is Addition) &&
468         !(n.Symbol is Subtraction) &&
469         !(n.Symbol is Multiplication) &&
470         !(n.Symbol is Division) &&
471         !(n.Symbol is Logarithm) &&
472         !(n.Symbol is Exponential) &&
473         !(n.Symbol is SquareRoot) &&
474         !(n.Symbol is Square) &&
475         !(n.Symbol is Sine) &&
476         !(n.Symbol is Cosine) &&
477         !(n.Symbol is Tangent) &&
478         !(n.Symbol is Erf) &&
479         !(n.Symbol is Norm) &&
480         !(n.Symbol is StartSymbol)
481        select n).
482      Any();
483      return !containsUnknownSymbol;
484    }
485  }
486}
Note: See TracBrowser for help on using the repository browser.