Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 8985

Last change on this file since 8985 was 8985, checked in by gkronber, 10 years ago

#1976 added comment about call to R² evaluator in const-opt evaluator

File size: 22.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Single-objective symbolic regression evaluator that reports the Pearson R² of a tree.
  /// With probability <see cref="ConstantOptimizationProbability"/> it first fits the tree's constants and
  /// variable weights with alglib's nonlinear least-squares fitter (lsfit), using gradients obtained via AutoDiff.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the ConstantOptimizationImprovement parameter is registered and exposed, but it is not read
  /// anywhere in this class. Only ConstantOptimizationIterations is passed to the fitter.
  /// </remarks>
  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
  [StorableClass]
  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";

    #region parameter properties
    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }
    #endregion

    #region properties
    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>Pearson R² is maximized.</summary>
    public override bool Maximization {
      get { return true; }
    }

    [StorableConstructor]
    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public SymbolicRegressionConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
      Parameters.Add(CreateUpdateConstantsInTreeParameter());
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
    }

    // Adds the UpdateConstantsInTree parameter when it is missing after deserialization
    // (it is added unconditionally by the default constructor).
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
        Parameters.Add(CreateUpdateConstantsInTreeParameter());
    }

    // Single factory for the UpdateConstantsInTree parameter, so the constructor and
    // AfterDeserialization always create it with the same name, description and default.
    private static FixedValueParameter<BoolValue> CreateUpdateConstantsInTreeParameter() {
      return new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true));
    }

    /// <summary>
    /// Evaluates the current solution and writes its quality (Pearson R²) to the Quality parameter.
    /// With probability <see cref="ConstantOptimizationProbability"/>, the constants are optimized first.
    /// </summary>
    public override IOperation Apply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
           EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree);

        // OptimizeConstants measured the quality on the constant-optimization rows. When those rows are
        // sampled with a different percentage than regular evaluation, re-evaluate on the regular rows.
        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value)
          quality = CalculateQualityOnEvaluationRows(solution);
      } else {
        quality = CalculateQualityOnEvaluationRows(solution);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      if (Successor != null)
        return ExecutionContext.CreateOperation(Successor);
      else
        return null;
    }

    // Pearson R² of the solution on the rows returned by GenerateRowsToEvaluate() (regular evaluation rows).
    private double CalculateQualityOnEvaluationRows(ISymbolicExpressionTree solution) {
      var evaluationRows = GenerateRowsToEvaluate();
      return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
    }

    /// <summary>
    /// Calculates the Pearson R² of <paramref name="tree"/> on <paramref name="rows"/> without optimizing constants.
    /// </summary>
    /// <remarks>
    /// The parameters read here are temporarily bound to <paramref name="context"/>. They are always
    /// unbound again (set to null) afterwards, even if the calculation throws.
    /// </remarks>
    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      try {
        // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
        // because Evaluate() is used to get the quality of evolved models on
        // different partitions of the dataset (e.g., best validation model)
        return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);
      } finally {
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
      }
    }

    #region derivations of functions
    // Create function factory for arctangent.
    // NOTE: not referenced by TryTransformToAutoDiff (no arctangent symbol is supported there yet).
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    private static readonly Func<Term, UnaryFunc> square = UnaryFunc.Factory(
      eval: x => x * x,
      diff: x => 2 * x);
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    // alglib.normaldistribution is the standard normal CDF Phi(x).
    // Its derivative is the standard normal density phi(x) = exp(-x^2 / 2) / sqrt(2 * pi).
    // Writing it as a single exp(-x^2/2) also avoids the 0 * infinity = NaN that
    // exp(-x^2) * sqrt(exp(x^2)) produced for large |x|.
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => Math.Exp(-(x * x) / 2.0) / Math.Sqrt(2 * Math.PI));
    #endregion

    /// <summary>
    /// Fits the constants and variable weights of <paramref name="tree"/> to the target variable of
    /// <paramref name="problemData"/> on <paramref name="rows"/>. Returns the Pearson R² of the fitted model.
    /// </summary>
    /// <param name="interpreter">Interpreter used to compute the Pearson R² before and after fitting.</param>
    /// <param name="tree">Tree to optimize. Its ConstantTreeNode values and VariableTreeNode weights may be overwritten.</param>
    /// <param name="problemData">Provides the dataset and the name of the target variable.</param>
    /// <param name="rows">Rows used both for fitting and for computing the quality. They are enumerated exactly once.</param>
    /// <param name="applyLinearScaling">Passed through to the Pearson R² calculation.</param>
    /// <param name="maxIterations">Iteration limit passed to alglib.lsfitsetcond.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit used for the quality calculation.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit used for the quality calculation.</param>
    /// <param name="updateConstantsInTree">When false, the original constants are written back after the quality is computed.</param>
    /// <returns>
    /// One of the following:
    /// <list type="bullet">
    /// <item>The Pearson R² with the optimized constants.</item>
    /// <item>The original quality (with the original constants restored) if fitting throws, or if the
    /// optimized quality is NaN or worse than the original by more than 0.001.</item>
    /// <item>0.0 if the tree contains no variables.</item>
    /// </list>
    /// </returns>
    /// <exception cref="NotSupportedException">The tree contains a symbol that cannot be transformed to an AutoDiff term.</exception>
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {

      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
      List<string> variableNames = new List<string>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
      if (variableNames.Count == 0) return 0.0;

      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());

      // Enumerate the rows only once. This guarantees that the input matrix, the target vector and
      // both quality calculations use exactly the same rows, and avoids re-running a lazy row query.
      int[] rowIndices = rows.ToArray();

      double[] c = ExtractInitialConstants(tree, variables.Count);
      double[] originalConstants = (double[])c.Clone();
      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;

      Dataset ds = problemData.Dataset;
      double[,] x = BuildInputMatrix(ds, variableNames, rowIndices);
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rowIndices).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc pfunc = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad pgrad = CreatePGrad(compiledFunc);

      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        // epsf = epsx = 0: only maxIterations is passed as a stopping condition
        // (see the alglib lsfitsetcond documentation for how zeros are interpreted).
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        // For debugging the AutoDiff gradients, enable: alglib.lsfitsetgradientcheck(state, 0.001);
        alglib.lsfitfit(state, pfunc, pgrad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      }
      catch (ArithmeticException) {
        return originalQuality;
      }
      catch (alglib.alglibexception) {
        return originalQuality;
      }

      // info == -7 => constant optimization failed due to wrong gradient; the tree keeps its constants.
      // Skip(2) drops beta and alpha, which are not stored in the tree.
      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
        UpdateConstants(tree, originalConstants.Skip(2).ToArray());
        return originalQuality;
      }
      return quality;
    }

    // Builds the initial parameter vector for lsfit.
    // c[0] = beta (offset, 0) and c[1] = alpha (scale, 1); these are the two AutoDiff variables that
    // TryTransformToAutoDiff registers first for the StartSymbol.
    // After them come the constant values and variable weights of the terminals, in prefix order.
    private static double[] ExtractInitialConstants(ISymbolicExpressionTree tree, int count) {
      double[] c = new double[count];
      c[0] = 0.0;
      c[1] = 1.0;
      int i = 2;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        if (constantTreeNode != null)
          c[i++] = constantTreeNode.Value;
        else if (variableTreeNode != null)
          c[i++] = variableTreeNode.Weight;
      }
      return c;
    }

    // Builds the lsfit input matrix: one row per entry of rowIndices and one column per variable
    // occurrence (variableNames may contain the same name several times).
    private static double[,] BuildInputMatrix(Dataset ds, List<string> variableNames, int[] rowIndices) {
      double[,] x = new double[rowIndices.Length, variableNames.Count];
      for (int row = 0; row < rowIndices.Length; row++) {
        for (int col = 0; col < variableNames.Count; col++) {
          x[row, col] = ds.GetDoubleValue(variableNames[col], rowIndices[row]);
        }
      }
      return x;
    }

    // Writes constants back into the tree's ConstantTreeNode values and VariableTreeNode weights, in prefix
    // order. This is the same order that ExtractInitialConstants and TryTransformToAutoDiff use.
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = constants[i++];
        else if (variableTreeNode != null)
          variableTreeNode.Weight = constants[i++];
      }
    }

    // Adapts the compiled AutoDiff term to alglib's function callback (c = parameters, x = one input row).
    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    // Adapts the compiled AutoDiff term to alglib's gradient callback. It returns the function value and
    // copies the gradient with respect to c into grad.
    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var tupel = compiledFunc.Differentiate(c, x);
        func = tupel.Item2;
        Array.Copy(tupel.Item1, grad, grad.Length);
      };
    }

    // Recursively translates a tree node into an AutoDiff term.
    // - Every constant and every variable weight becomes an entry in `variables` (the optimized coefficients),
    //   created in prefix order.
    // - Every variable occurrence becomes an entry in `parameters` (the inputs), with its name in `variableNames`.
    // Returns false if the subtree contains an unsupported symbol.
    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        var constant = new AutoDiff.Variable();
        variables.Add(constant);
        term = constant;
        return true;
      }
      if (node.Symbol is Variable) {
        var varNode = node as VariableTreeNode;
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);
        var w = new AutoDiff.Variable();
        variables.Add(w);
        term = AutoDiff.TermBuilder.Product(w, par);
        return true;
      }
      if (node.Symbol is Addition) {
        List<AutoDiff.Term> summands;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out summands)) {
          term = null;
          return false;
        }
        if (summands.Count == 1)
          term = summands[0];
        else
          term = AutoDiff.TermBuilder.Sum(summands);
        return true;
      }
      if (node.Symbol is Subtraction) {
        List<AutoDiff.Term> operands;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out operands)) {
          term = null;
          return false;
        }
        if (operands.Count == 1) {
          // A single-argument subtraction is a negation (-x). Previously this became +x.
          // NOTE(review): assumed to match the tree interpreter's unary-subtraction semantics.
          term = -operands[0];
        } else {
          for (int i = 1; i < operands.Count; i++) operands[i] = -operands[i];
          term = AutoDiff.TermBuilder.Sum(operands);
        }
        return true;
      }
      if (node.Symbol is Multiplication) {
        List<AutoDiff.Term> factors;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out factors)) {
          term = null;
          return false;
        }
        // A single-argument multiplication is the identity. Previously GetSubtree(1) threw here.
        if (factors.Count == 1)
          term = factors[0];
        else
          term = AutoDiff.TermBuilder.Product(factors[0], factors[1], factors.Skip(2).ToArray());
        return true;
      }
      if (node.Symbol is Division) {
        List<AutoDiff.Term> operands;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out operands)) {
          term = null;
          return false;
        }
        if (operands.Count == 1) {
          // A single-argument division is the reciprocal (1/x). Previously GetSubtree(1) threw here.
          // NOTE(review): assumed to match the tree interpreter's unary-division semantics.
          term = 1.0 / operands[0];
        } else {
          // a / b / c / ... == a * (1/b) * (1/c) * ...
          AutoDiff.Term[] remainingFactors = operands.Skip(2).Select(divisor => 1.0 / divisor).ToArray();
          term = AutoDiff.TermBuilder.Product(operands[0], 1.0 / operands[1], remainingFactors);
        }
        return true;
      }
      Func<AutoDiff.Term, AutoDiff.Term> unaryFunction = GetUnaryFunction(node);
      if (unaryFunction != null) {
        AutoDiff.Term argument;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out argument)) {
          term = null;
          return false;
        }
        term = unaryFunction(argument);
        return true;
      }
      if (node.Symbol is StartSymbol) {
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        // Order matters: beta (offset) goes first and alpha (scale) second. OptimizeConstants initializes
        // c[0] = 0 and c[1] = 1 and drops both entries before writing constants back.
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        }
        term = null;
        return false;
      }
      term = null;
      return false;
    }

    // Transforms all subtrees of node from left to right, so AutoDiff variables are created in prefix order.
    // Stops and returns false at the first subtree that cannot be transformed.
    private static bool TryTransformSubtrees(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out List<AutoDiff.Term> terms) {
      terms = new List<AutoDiff.Term>();
      foreach (var subtree in node.Subtrees) {
        AutoDiff.Term subtreeTerm;
        if (!TryTransformToAutoDiff(subtree, variables, parameters, variableNames, out subtreeTerm)) {
          terms = null;
          return false;
        }
        terms.Add(subtreeTerm);
      }
      return true;
    }

    // Maps the supported single-argument symbols to their AutoDiff counterpart.
    // Returns null for every other symbol. Only the first subtree of such a node is used.
    private static Func<AutoDiff.Term, AutoDiff.Term> GetUnaryFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is Logarithm) return arg => AutoDiff.TermBuilder.Log(arg);
      if (node.Symbol is Exponential) return arg => AutoDiff.TermBuilder.Exp(arg);
      if (node.Symbol is Sine) return arg => sin(arg);
      if (node.Symbol is Cosine) return arg => cos(arg);
      if (node.Symbol is Tangent) return arg => tan(arg);
      if (node.Symbol is Square) return arg => square(arg);
      if (node.Symbol is Erf) return arg => erf(arg);
      if (node.Symbol is Norm) return arg => norm(arg);
      return null;
    }

    /// <summary>
    /// Returns true if every node below the tree's root uses a symbol that TryTransformToAutoDiff can translate.
    /// </summary>
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
         !(n.Symbol is Variable) &&
         !(n.Symbol is Constant) &&
         !(n.Symbol is Addition) &&
         !(n.Symbol is Subtraction) &&
         !(n.Symbol is Multiplication) &&
         !(n.Symbol is Division) &&
         !(n.Symbol is Logarithm) &&
         !(n.Symbol is Exponential) &&
         !(n.Symbol is Sine) &&
         !(n.Symbol is Cosine) &&
         !(n.Symbol is Tangent) &&
         !(n.Symbol is Square) &&
         !(n.Symbol is Erf) &&
         !(n.Symbol is Norm) &&
         !(n.Symbol is StartSymbol)
        select n).
      Any();
      return !containsUnknownSymbol;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.