Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 8984

Last change on this file since 8984 was 8984, checked in by mkommend, 10 years ago

#1976: Returned original quality in case of an exception in the SymbolicRegressionConstantOptimizationEvaluator.

File size: 21.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Single-objective symbolic regression evaluator that, with a configurable probability, fits the
  /// constants and variable weights of a tree using alglib's nonlinear least-squares fitting (lsfit*)
  /// and reports the Pearson R² of the tree as its quality (maximized).
  /// </summary>
  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
  [StorableClass]
  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    // NOTE(review): this parameter is registered but not currently read by Apply/OptimizeConstants.
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }

    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    /// <summary>Whether optimized constants are written back into the evaluated tree.</summary>
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }

    /// <summary>Pearson R² is maximized.</summary>
    public override bool Maximization {
      get { return true; }
    }

    [StorableConstructor]
    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public SymbolicRegressionConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Instances persisted before this parameter existed do not contain it; add it with its default.
      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
    }

    /// <summary>
    /// Evaluates the current tree. With probability <see cref="ConstantOptimizationProbability"/>, its
    /// constants are first optimized on a sample of <see cref="ConstantOptimizationRowsPercentage"/> rows.
    /// The resulting Pearson R² is written to the quality parameter.
    /// </summary>
    public override IOperation Apply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      bool applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;

      double quality;
      // The random draws below happen in the same order as before (probability check, then row samples),
      // so results stay reproducible for a given random stream.
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(interpreter, solution, problemData, constantOptimizationRows, applyLinearScaling,
          ConstantOptimizationIterations.Value, estimationLimits.Upper, estimationLimits.Lower, UpdateConstantsInTree);

        // If constants were fitted on a different fraction of rows than the regular evaluation uses,
        // re-evaluate the (possibly updated) tree on the regular evaluation rows.
        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          var evaluationRows = GenerateRowsToEvaluate();
          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, solution, estimationLimits.Lower, estimationLimits.Upper, problemData, evaluationRows, applyLinearScaling);
        }
      } else {
        var evaluationRows = GenerateRowsToEvaluate();
        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, solution, estimationLimits.Lower, estimationLimits.Upper, problemData, evaluationRows, applyLinearScaling);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      if (Successor != null)
        return ExecutionContext.CreateOperation(Successor);
      else
        return null;
    }

    /// <summary>
    /// Calculates the Pearson R² of <paramref name="tree"/> on <paramref name="rows"/> without optimizing constants.
    /// The parameter execution contexts are temporarily bound to <paramref name="context"/> and always reset afterwards.
    /// </summary>
    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      try {
        return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);
      } finally {
        // Reset all three contexts, even if the calculation throws. Previously, ApplyLinearScaling
        // was mistakenly left bound to the context.
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
      }
    }

    #region derivations of functions
    // Arctangent factory; currently not referenced by TryTransformToAutoDiff.
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    private static readonly Func<Term, UnaryFunc> square = UnaryFunc.Factory(
      eval: x => x * x,
      diff: x => 2 * x);
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    // alglib.normaldistribution is the standard normal CDF, so its derivative is the standard normal
    // density exp(-x²/2) / sqrt(2π). The previous formula returned -x times that density, which is
    // the derivative of the density, not of the CDF.
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));
    #endregion


    /// <summary>
    /// Fits the constants and variable weights of <paramref name="tree"/> to the target variable on
    /// <paramref name="rows"/> using alglib lsfit with AutoDiff-generated gradients.
    /// </summary>
    /// <param name="interpreter">Interpreter used to compute the Pearson R² before and after fitting.</param>
    /// <param name="tree">Tree whose constants are optimized. It may be modified (see <paramref name="updateConstantsInTree"/>).</param>
    /// <param name="problemData">Provides the dataset and the target variable.</param>
    /// <param name="rows">Row indices used for fitting and for computing the returned quality.</param>
    /// <param name="applyLinearScaling">Passed through to the R² calculation.</param>
    /// <param name="maxIterations">Maximum number of lsfit iterations (0 = alglib's default stopping criterion).</param>
    /// <param name="upperEstimationLimit">Upper clamp for estimated values in the R² calculation.</param>
    /// <param name="lowerEstimationLimit">Lower clamp for estimated values in the R² calculation.</param>
    /// <param name="updateConstantsInTree">If false, the original constants are restored in the tree before returning.</param>
    /// <returns>
    /// Returns one of the following:
    /// <list type="bullet">
    /// <item>The Pearson R² after optimization.</item>
    /// <item>The original R² if fitting threw an arithmetic or alglib exception, produced NaN, or
    /// lowered the quality by more than 0.001. In these cases the original constants are restored.</item>
    /// <item>0.0 if the tree contains no variables.</item>
    /// </list>
    /// </returns>
    /// <exception cref="NotSupportedException">The tree contains symbols that cannot be converted to AutoDiff terms.</exception>
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {

      // "variables" are the coefficients fitted by lsfit (vector c).
      // "parameters" are the per-row input values (vector x).
      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
      List<string> variableNames = new List<string>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
      if (variableNames.Count == 0) return 0.0;

      // rows is enumerated several times below; materialize it once so every step sees the same rows.
      int[] rowIndices = rows.ToArray();

      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());

      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
      double[] c = new double[variables.Count];

      // c[0] and c[1] are the offset (beta) and scale (alpha) introduced for the start symbol.
      // Start them at the identity transformation.
      c[0] = 0.0;
      c[1] = 1.0;
      // Extract the initial constants/weights in prefix order. This is the same order in which
      // TryTransformToAutoDiff created the coefficient variables.
      int constantIndex = 2;
      foreach (var node in terminalNodes) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        if (constantTreeNode != null)
          c[constantIndex++] = constantTreeNode.Value;
        else if (variableTreeNode != null)
          c[constantIndex++] = variableTreeNode.Weight;
      }
      double[] originalConstants = (double[])c.Clone();
      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;

      Dataset ds = problemData.Dataset;
      double[,] x = new double[rowIndices.Length, variableNames.Count];
      for (int row = 0; row < rowIndices.Length; row++) {
        for (int col = 0; col < variableNames.Count; col++) {
          x[row, col] = ds.GetDoubleValue(variableNames[col], rowIndices[row]);
        }
      }
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rowIndices).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc pfunc = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad pgrad = CreatePGrad(compiledFunc);

      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        //alglib.lsfitsetgradientcheck(state, 0.001);
        alglib.lsfitfit(state, pfunc, pgrad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      }
      catch (ArithmeticException) {
        // The tree has not been modified yet, so the original quality is still valid.
        return originalQuality;
      }
      catch (alglib.alglibexception) {
        return originalQuality;
      }

      //info == -7  => constant optimization failed due to wrong gradient
      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
      // Never report (or keep) a clearly worse or invalid result.
      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
        UpdateConstants(tree, originalConstants.Skip(2).ToArray());
        return originalQuality;
      }
      return quality;
    }

    /// <summary>
    /// Writes <paramref name="constants"/> into the constant values and variable weights of
    /// <paramref name="tree"/>, visiting terminal nodes in prefix order.
    /// </summary>
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = constants[i++];
        else if (variableTreeNode != null)
          variableTreeNode.Weight = constants[i++];
      }
    }

    /// <summary>Adapts the compiled term to alglib's function-value callback (c = coefficients, x = one input row).</summary>
    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    /// <summary>Adapts the compiled term to alglib's value-and-gradient callback (gradient w.r.t. the coefficients c).</summary>
    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var tupel = compiledFunc.Differentiate(c, x);
        func = tupel.Item2;
        Array.Copy(tupel.Item1, grad, grad.Length);
      };
    }

    /// <summary>
    /// Recursively converts a symbolic expression tree node into an AutoDiff term.
    /// </summary>
    /// <remarks>
    /// Constants and variable weights become coefficient variables, appended to
    /// <paramref name="variables"/> in prefix order. Variable values become input parameters,
    /// appended to <paramref name="parameters"/>, and their names to <paramref name="variableNames"/>.
    /// </remarks>
    /// <returns>false (and a null term) if the subtree contains an unsupported symbol.</returns>
    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        var coefficient = new AutoDiff.Variable();
        variables.Add(coefficient);
        term = coefficient;
        return true;
      }
      if (node.Symbol is Variable) {
        // weight * value: the value is supplied per row, the weight is fitted.
        var varNode = node as VariableTreeNode;
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);
        var w = new AutoDiff.Variable();
        variables.Add(w);
        term = AutoDiff.TermBuilder.Product(w, par);
        return true;
      }
      if (node.Symbol is Addition) {
        List<AutoDiff.Term> summands;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out summands)) {
          term = null;
          return false;
        }
        term = summands.Count == 1 ? summands[0] : AutoDiff.TermBuilder.Sum(summands);
        return true;
      }
      if (node.Symbol is Subtraction) {
        List<AutoDiff.Term> operands;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out operands)) {
          term = null;
          return false;
        }
        if (operands.Count == 1) {
          // A single-argument subtraction is a negation (-x). Previously it was converted to +x.
          term = -operands[0];
        } else {
          for (int i = 1; i < operands.Count; i++) operands[i] = -operands[i];
          term = AutoDiff.TermBuilder.Sum(operands);
        }
        return true;
      }
      if (node.Symbol is Multiplication) {
        List<AutoDiff.Term> factors;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out factors)) {
          term = null;
          return false;
        }
        // A single factor is returned unchanged. Previously it crashed on GetSubtree(1).
        term = factors.Count == 1
          ? factors[0]
          : AutoDiff.TermBuilder.Product(factors[0], factors[1], factors.Skip(2).ToArray());
        return true;
      }
      if (node.Symbol is Division) {
        List<AutoDiff.Term> operands;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, out operands)) {
          term = null;
          return false;
        }
        if (operands.Count == 1) {
          // A single-argument division is the reciprocal (1/x). Previously it crashed on GetSubtree(1).
          term = 1.0 / operands[0];
        } else {
          // a / b / c / ... == a * (1/b) * (1/c) * ...
          List<AutoDiff.Term> reciprocals = new List<Term>();
          foreach (var divisor in operands.Skip(2)) reciprocals.Add(1.0 / divisor);
          term = AutoDiff.TermBuilder.Product(operands[0], 1.0 / operands[1], reciprocals.ToArray());
        }
        return true;
      }
      if (node.Symbol is Logarithm)
        return TryTransformUnary(node, variables, parameters, variableNames, t => AutoDiff.TermBuilder.Log(t), out term);
      if (node.Symbol is Exponential)
        return TryTransformUnary(node, variables, parameters, variableNames, t => AutoDiff.TermBuilder.Exp(t), out term);
      if (node.Symbol is Sine)
        return TryTransformUnary(node, variables, parameters, variableNames, t => sin(t), out term);
      if (node.Symbol is Cosine)
        return TryTransformUnary(node, variables, parameters, variableNames, t => cos(t), out term);
      if (node.Symbol is Tangent)
        return TryTransformUnary(node, variables, parameters, variableNames, t => tan(t), out term);
      if (node.Symbol is Square)
        return TryTransformUnary(node, variables, parameters, variableNames, t => square(t), out term);
      if (node.Symbol is Erf)
        return TryTransformUnary(node, variables, parameters, variableNames, t => erf(t), out term);
      if (node.Symbol is Norm)
        return TryTransformUnary(node, variables, parameters, variableNames, t => norm(t), out term);
      if (node.Symbol is StartSymbol) {
        // Wrap the model in an affine transformation alpha * f + beta. beta is added first, so it
        // becomes c[0] (offset) and alpha becomes c[1] (scale) in OptimizeConstants.
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        } else {
          term = null;
          return false;
        }
      }
      term = null;
      return false;
    }

    // Converts all subtrees of node in order; stops and returns false at the first unsupported subtree.
    private static bool TryTransformSubtrees(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out List<AutoDiff.Term> terms) {
      terms = new List<Term>();
      foreach (var subTree in node.Subtrees) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
          terms = null;
          return false;
        }
        terms.Add(t);
      }
      return true;
    }

    // Converts the first subtree of node and applies the unary function f to the result.
    private static bool TryTransformUnary(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, Func<AutoDiff.Term, AutoDiff.Term> f, out AutoDiff.Term term) {
      AutoDiff.Term argument;
      if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out argument)) {
        term = null;
        return false;
      }
      term = f(argument);
      return true;
    }

    /// <summary>
    /// Returns true if every node below the root uses only symbols that TryTransformToAutoDiff can
    /// convert, i.e. OptimizeConstants will not throw <see cref="NotSupportedException"/> for this tree.
    /// </summary>
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
         !(n.Symbol is Variable) &&
         !(n.Symbol is Constant) &&
         !(n.Symbol is Addition) &&
         !(n.Symbol is Subtraction) &&
         !(n.Symbol is Multiplication) &&
         !(n.Symbol is Division) &&
         !(n.Symbol is Logarithm) &&
         !(n.Symbol is Exponential) &&
         !(n.Symbol is Sine) &&
         !(n.Symbol is Cosine) &&
         !(n.Symbol is Tangent) &&
         !(n.Symbol is Square) &&
         !(n.Symbol is Erf) &&
         !(n.Symbol is Norm) &&
         !(n.Symbol is StartSymbol)
        select n).
      Any();
      return !containsUnknownSymbol;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.