
source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 12973

Last change on this file since 12973 was 12509, checked in by mkommend, 9 years ago

#2276: Reintegrated branch for dataset refactoring.

File size: 22.2 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using AutoDiff;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constants used.")]
  [StorableClass]
  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }

    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }

    public override bool Maximization {
      get { return true; }
    }

    [StorableConstructor]
    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public SymbolicRegressionConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be performed while optimizing the constants of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized.", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization.", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
    }

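    // Evaluation entry point: with probability ConstantOptimizationProbability the
    // constants are fitted on a (possibly reduced) sample of rows. If that sample
    // differs from the regular evaluation sample, the quality is recomputed on
    // freshly drawn evaluation rows so that reported qualities stay comparable.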
    public override IOperation InstrumentedApply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
           EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree);

        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          var evaluationRows = GenerateRowsToEvaluate();
          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
        }
      } else {
        var evaluationRows = GenerateRowsToEvaluate();
        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      return base.InstrumentedApply();
    }

    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;

      // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
      // because Evaluate() is used to get the quality of evolved models on
      // different partitions of the dataset (e.g., best validation model)
      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);

      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
      EstimationLimitsParameter.ExecutionContext = null;
      ApplyLinearScalingParameter.ExecutionContext = null;

      return r2;
    }

    #region derivatives of functions
    // function factories for symbols that AutoDiff does not support natively;
    // each factory pairs the evaluation function with its analytic derivative
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      // the derivative of the standard normal CDF is the standard normal density
      diff: x => Math.Exp(-(x * x) / 2.0) / Math.Sqrt(2 * Math.PI));
    #endregion
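    // Fits the numeric constants of the tree to the data. The tree is transformed
    // into an AutoDiff term in which constants and variable weights become
    // optimization variables and dataset columns become parameters; alglib's
    // Levenberg-Marquardt least-squares solver then adjusts the coefficient
    // vector c. c[0] and c[1] are reserved for the linear scaling terms
    // (additive offset and multiplicative factor) added at the start symbol.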
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {

      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
      List<string> variableNames = new List<string>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
        throw new NotSupportedException("Could not optimize constants of symbolic expression tree because the tree contains unsupported symbols.");
      if (variableNames.Count == 0) return 0.0;

      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());

      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
      double[] c = new double[variables.Count];

      {
        c[0] = 0.0;
        c[1] = 1.0;
        // extract initial constants
        int i = 2;
        foreach (var node in terminalNodes) {
          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
          VariableTreeNode variableTreeNode = node as VariableTreeNode;
          if (constantTreeNode != null)
            c[i++] = constantTreeNode.Value;
          else if (variableTreeNode != null)
            c[i++] = variableTreeNode.Weight;
        }
      }
      double[] originalConstants = (double[])c.Clone();
      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;

      IDataset ds = problemData.Dataset;
      double[,] x = new double[rows.Count(), variableNames.Count];
      int row = 0;
      foreach (var r in rows) {
        for (int col = 0; col < variableNames.Count; col++) {
          x[row, col] = ds.GetDoubleValue(variableNames[col], r);
        }
        row++;
      }
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        //alglib.lsfitsetgradientcheck(state, 0.001);
        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      }
      catch (ArithmeticException) {
        return originalQuality;
      }
      catch (alglib.alglibexception) {
        return originalQuality;
      }

      // info == -7 => constant optimization failed due to wrong gradient
      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);

      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
        UpdateConstants(tree, originalConstants.Skip(2).ToArray());
        return originalQuality;
      }
      return quality;
    }

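    // Writes optimized values back into the tree. The prefix iteration order
    // matches the extraction order in OptimizeConstants, so constants and
    // variable weights line up with the coefficient vector (minus the two
    // scaling entries).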
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = constants[i++];
        else if (variableTreeNode != null)
          variableTreeNode.Weight = constants[i++];
      }
    }

    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var tuple = compiledFunc.Differentiate(c, x);
        func = tuple.Item2;
        Array.Copy(tuple.Item1, grad, grad.Length);
      };
    }

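    // Recursively converts a symbolic expression tree into an AutoDiff term.
    // Constants and variable weights are mapped to AutoDiff variables (the
    // quantities to be optimized), dataset variables to AutoDiff parameters.
    // Returns false as soon as an unsupported symbol is encountered.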
    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        var var = new AutoDiff.Variable();
        variables.Add(var);
        term = var;
        return true;
      }
      if (node.Symbol is Variable) {
        var varNode = node as VariableTreeNode;
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);
        var w = new AutoDiff.Variable();
        variables.Add(w);
        term = AutoDiff.TermBuilder.Product(w, par);
        return true;
      }
      if (node.Symbol is Addition) {
        List<AutoDiff.Term> terms = new List<Term>();
        foreach (var subTree in node.Subtrees) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
            term = null;
            return false;
          }
          terms.Add(t);
        }
        term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Subtraction) {
        List<AutoDiff.Term> terms = new List<Term>();
        for (int i = 0; i < node.SubtreeCount; i++) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
            term = null;
            return false;
          }
          if (i > 0) t = -t;
          terms.Add(t);
        }
        term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Multiplication) {
        AutoDiff.Term a, b;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
          term = null;
          return false;
        } else {
          List<AutoDiff.Term> factors = new List<Term>();
          foreach (var subTree in node.Subtrees.Skip(2)) {
            AutoDiff.Term f;
            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
              term = null;
              return false;
            }
            factors.Add(f);
          }
          term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
          return true;
        }
      }
      if (node.Symbol is Division) {
        // only works for at least two subtrees
        AutoDiff.Term a, b;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
          term = null;
          return false;
        } else {
          List<AutoDiff.Term> factors = new List<Term>();
          foreach (var subTree in node.Subtrees.Skip(2)) {
            AutoDiff.Term f;
            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
              term = null;
              return false;
            }
            factors.Add(1.0 / f);
          }
          term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
          return true;
        }
      }
      if (node.Symbol is Logarithm) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Log(t);
          return true;
        }
      }
      if (node.Symbol is Exponential) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Exp(t);
          return true;
        }
      }
      if (node.Symbol is Square) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Power(t, 2.0);
          return true;
        }
      }
      if (node.Symbol is SquareRoot) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Power(t, 0.5);
          return true;
        }
      }
      if (node.Symbol is Sine) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = sin(t);
          return true;
        }
      }
      if (node.Symbol is Cosine) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = cos(t);
          return true;
        }
      }
      if (node.Symbol is Tangent) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = tan(t);
          return true;
        }
      }
      if (node.Symbol is Erf) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = erf(t);
          return true;
        }
      }
      if (node.Symbol is Norm) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = norm(t);
          return true;
        }
      }
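      // The start symbol wraps the evaluated branch in a linear scaling term
      // alpha * f(x) + beta; beta and alpha are added first, so they occupy
      // c[0] and c[1] in OptimizeConstants.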
      if (node.Symbol is StartSymbol) {
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        } else {
          term = null;
          return false;
        }
      }
      term = null;
      return false;
    }

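    // Static pre-check used by callers: constant optimization is only applicable
    // if the tree consists solely of symbols handled by TryTransformToAutoDiff.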
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
         !(n.Symbol is Variable) &&
         !(n.Symbol is Constant) &&
         !(n.Symbol is Addition) &&
         !(n.Symbol is Subtraction) &&
         !(n.Symbol is Multiplication) &&
         !(n.Symbol is Division) &&
         !(n.Symbol is Logarithm) &&
         !(n.Symbol is Exponential) &&
         !(n.Symbol is SquareRoot) &&
         !(n.Symbol is Square) &&
         !(n.Symbol is Sine) &&
         !(n.Symbol is Cosine) &&
         !(n.Symbol is Tangent) &&
         !(n.Symbol is Erf) &&
         !(n.Symbol is Norm) &&
         !(n.Symbol is StartSymbol)
        select n).Any();
      return !containsUnknownSymbol;
    }
  }
}