
source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 8730

Last change on this file since 8730 was 8730, checked in by gkronber, 12 years ago

#1962: disabled optimized button if the model contains non-differentiable functions. Added support for exact differentiation for additional function symbols (sin, cos, tan, square, norm, erf)

File size: 21.0 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using AutoDiff;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  [Item("Constant Optimization Evaluator", "Calculates the Pearson R² of a symbolic regression solution and optimizes the constants used.")]
  [StorableClass]
  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";

    private const string EvaluatedTreesResultName = "EvaluatedTrees";
    private const string EvaluatedTreeNodesResultName = "EvaluatedTreeNodes";

    public ILookupParameter<IntValue> EvaluatedTreesParameter {
      get { return (ILookupParameter<IntValue>)Parameters[EvaluatedTreesResultName]; }
    }
    public ILookupParameter<IntValue> EvaluatedTreeNodesParameter {
      get { return (ILookupParameter<IntValue>)Parameters[EvaluatedTreeNodesResultName]; }
    }

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }

    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }

    public override bool Maximization {
      get { return true; }
    }

    [StorableConstructor]
    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public SymbolicRegressionConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constants of a symbolic expression tree (0 means another or the default stopping criterion is used).", new IntValue(3), true));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 means another or the default stopping criterion is used).", new DoubleValue(0), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized.", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization.", new PercentValue(1), true));

      Parameters.Add(new LookupParameter<IntValue>(EvaluatedTreesResultName));
      Parameters.Add(new LookupParameter<IntValue>(EvaluatedTreeNodesResultName));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
    }

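    // With probability ConstantOptimizationProbability the constants of the tree are
    // optimized on ConstantOptimizationRowsPercentage of the rows. If that percentage
    // differs from the regular evaluation percentage, the quality is re-evaluated on
    // the regular evaluation rows afterwards; otherwise the plain Pearson R² evaluation
    // is used.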
    public override IOperation Apply() {
      AddResults();
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
          constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
          EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower,
          EvaluatedTreesParameter.ActualValue, EvaluatedTreeNodesParameter.ActualValue);
        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          var evaluationRows = GenerateRowsToEvaluate();
          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
        }
      } else {
        var evaluationRows = GenerateRowsToEvaluate();
        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);
      EvaluatedTreesParameter.ActualValue.Value += 1;
      EvaluatedTreeNodesParameter.ActualValue.Value += solution.Length;

      if (Successor != null)
        return ExecutionContext.CreateOperation(Successor);
      else
        return null;
    }

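    // Adds the EvaluatedTrees/EvaluatedTreeNodes counter variables to the root scope
    // if they do not exist yet, so that the increments in Apply() have a target.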
    private void AddResults() {
      if (EvaluatedTreesParameter.ActualValue == null) {
        var scope = ExecutionContext.Scope;
        while (scope.Parent != null)
          scope = scope.Parent;
        scope.Variables.Add(new Core.Variable(EvaluatedTreesResultName, new IntValue()));
      }
      if (EvaluatedTreeNodesParameter.ActualValue == null) {
        var scope = ExecutionContext.Scope;
        while (scope.Parent != null)
          scope = scope.Parent;
        scope.Variables.Add(new Core.Variable(EvaluatedTreeNodesResultName, new IntValue()));
      }
    }

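    // Plain evaluation without constant optimization (e.g. for external calls); the
    // parameter execution contexts are set for the duration of the calculation only.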
    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;

      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);

      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
      EstimationLimitsParameter.ExecutionContext = null;
      ApplyLinearScalingParameter.ExecutionContext = null;

      return r2;
    }

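    // AutoDiff provides no built-in terms for the following functions; UnaryFunc.Factory
    // builds a differentiable unary function from an evaluation lambda and its exact
    // derivative, which enables exact gradients for these symbols as well.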
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      x => Math.Atan(x),      // evaluate
      x => 1 / (1 + x * x));  // derivative of atan
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      x => Math.Sin(x),
      x => Math.Cos(x));
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      x => Math.Cos(x),
      x => -Math.Sin(x));
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      x => Math.Tan(x),
      x => 1 + Math.Tan(x) * Math.Tan(x));
    private static readonly Func<Term, UnaryFunc> square = UnaryFunc.Factory(
      x => x * x,
      x => 2 * x);
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      x => alglib.errorfunction(x),
      x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      x => alglib.normaldistribution(x),
      x => Math.Exp(-(x * x) / 2.0) / Math.Sqrt(2 * Math.PI)); // the derivative of the standard normal CDF is the density exp(-x²/2)/sqrt(2π)
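    // Least-squares tuning of all constants and variable weights:
    // (1) transform the tree into an AutoDiff term (incl. the linear scaling variables
    //     alpha and beta introduced at the StartSymbol),
    // (2) compile the term and extract the initial constants from the tree,
    // (3) fit the constants with alglib's Levenberg-Marquardt based lsfit routines
    //     using the exact gradients provided by AutoDiff,
    // (4) write the optimized constants back and return the Pearson R² of the tuned
    //     tree; 0.0 is returned for non-transformable trees or numeric failures.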
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, IntValue evaluatedTrees = null, IntValue evaluatedTreeNodes = null) {

      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
      List<string> variableNames = new List<string>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func)) return 0.0;
      if (variableNames.Count == 0) return 0.0;

      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());

      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
      double[] c = new double[variables.Count];

      {
        // c[0] and c[1] are the linear scaling terms beta (offset) and alpha (slope)
        // introduced for the StartSymbol in TryTransformToAutoDiff
        c[0] = 0.0;
        c[1] = 1.0;
        // extract initial constants
        int i = 2;
        foreach (var node in terminalNodes) {
          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
          VariableTreeNode variableTreeNode = node as VariableTreeNode;
          if (constantTreeNode != null)
            c[i++] = constantTreeNode.Value;
          else if (variableTreeNode != null && !variableTreeNode.Weight.IsAlmost(1.0))
            c[i++] = variableTreeNode.Weight;
        }
      }

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;

      // copy the dataset rows into the matrix x and the target values into y
      Dataset ds = problemData.Dataset;
      double[,] x = new double[rows.Count(), variableNames.Count];
      int row = 0;
      foreach (var r in rows) {
        for (int col = 0; col < variableNames.Count; col++) {
          x[row, col] = ds.GetDoubleValue(variableNames[col], r);
        }
        row++;
      }
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

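      // lsfitcreatefg sets up fitting with a user-supplied function and gradient;
      // lsfitsetcond(state, 0, 0, maxIterations) limits the number of iterations
      // (epsilon arguments of 0 select the default stopping criteria).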
      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        alglib.lsfitsetcond(state, 0, 0, maxIterations);
        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      }
      catch (ArithmeticException) {
        return 0.0;
      }
      catch (alglib.alglibexception) {
        return 0.0;
      }
      {
        // write the optimized constants back into the tree (only reached when no error occurred)
        int i = 2;
        foreach (var node in terminalNodes) {
          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
          VariableTreeNode variableTreeNode = node as VariableTreeNode;
          if (constantTreeNode != null)
            constantTreeNode.Value = c[i++];
          else if (variableTreeNode != null && !variableTreeNode.Weight.IsAlmost(1.0))
            variableTreeNode.Weight = c[i++];
        }
      }

      return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
    }

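    // The following adapters wrap the compiled AutoDiff term in the callback signatures
    // expected by alglib: a plain evaluation and a combined value/gradient evaluation
    // (both evaluate/differentiate with respect to the constant vector c).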
    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var tuple = compiledFunc.Differentiate(c, x);
        func = tuple.Item2;                         // function value
        Array.Copy(tuple.Item1, grad, grad.Length); // gradient w.r.t. the constants c
      };
    }

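    // Recursively converts a symbolic expression subtree into an AutoDiff term.
    // Constants and variable weights (unless the weight is exactly 1.0) become AutoDiff
    // variables that are tuned by the fitting, dataset variables become AutoDiff
    // parameters. Returns false for symbols without a differentiable transformation.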
    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        var var = new AutoDiff.Variable();
        variables.Add(var);
        term = var;
        return true;
      }
      if (node.Symbol is Variable) {
        // don't tune weights with a value of 1.0 because they were probably set by the simplifier
        var varNode = node as VariableTreeNode;
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);
        if (!varNode.Weight.IsAlmost(1.0)) {
          var w = new AutoDiff.Variable();
          variables.Add(w);
          term = AutoDiff.TermBuilder.Product(w, par);
        } else {
          term = par;
        }
        return true;
      }
      if (node.Symbol is Addition) {
        List<AutoDiff.Term> terms = new List<Term>();
        foreach (var subTree in node.Subtrees) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
            term = null;
            return false;
          }
          terms.Add(t);
        }
        term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
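      // Subtraction is accepted by CanOptimizeConstants below but had no transformation
      // here; a minimal sketch of the missing case: keep the first operand, negate the rest.
      if (node.Symbol is Subtraction) {
        List<AutoDiff.Term> terms = new List<Term>();
        for (int i = 0; i < node.SubtreeCount; i++) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
            term = null;
            return false;
          }
          if (i > 0) t = -t;
          terms.Add(t);
        }
        term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }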
      if (node.Symbol is Multiplication) {
        AutoDiff.Term a, b;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
            !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
          term = null;
          return false;
        } else {
          List<AutoDiff.Term> factors = new List<Term>();
          foreach (var subTree in node.Subtrees.Skip(2)) {
            AutoDiff.Term f;
            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
              term = null;
              return false;
            }
            factors.Add(f);
          }
          term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
          return true;
        }
      }
      if (node.Symbol is Division) {
        // the transformation works only for at least two subtrees; bail out otherwise
        if (node.SubtreeCount < 2) {
          term = null;
          return false;
        }
        AutoDiff.Term a, b;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
            !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
          term = null;
          return false;
        } else {
          List<AutoDiff.Term> factors = new List<Term>();
          foreach (var subTree in node.Subtrees.Skip(2)) {
            AutoDiff.Term f;
            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
              term = null;
              return false;
            }
            factors.Add(1.0 / f);
          }
          term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
          return true;
        }
      }
      if (node.Symbol is Logarithm) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Log(t);
          return true;
        }
      }
      if (node.Symbol is Exponential) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Exp(t);
          return true;
        }
      }
      if (node.Symbol is Sine) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = sin(t);
          return true;
        }
      }
      if (node.Symbol is Cosine) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = cos(t);
          return true;
        }
      }
      if (node.Symbol is Tangent) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = tan(t);
          return true;
        }
      }
      if (node.Symbol is Square) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = square(t);
          return true;
        }
      }
      if (node.Symbol is Erf) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = erf(t);
          return true;
        }
      }
      if (node.Symbol is Norm) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
          term = null;
          return false;
        } else {
          term = norm(t);
          return true;
        }
      }
      if (node.Symbol is StartSymbol) {
        // linear scaling: the result of the tree branch is scaled as branchTerm * alpha + beta
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        } else {
          term = null;
          return false;
        }
      }
      term = null; // unsupported symbol
      return false;
    }

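    // True iff the tree contains only symbols for which an exact derivative is known,
    // i.e. iff TryTransformToAutoDiff can handle every node; used to disable constant
    // optimization for models containing non-differentiable functions.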
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is Erf) &&
          !(n.Symbol is Norm) &&
          !(n.Symbol is StartSymbol)
        select n).Any();
      return !containsUnknownSymbol;
    }
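
    // Minimal usage sketch (interpreter, tree, problemData and rows are assumed to be
    // provided by the caller):
    //   if (CanOptimizeConstants(tree)) {
    //     double r2 = OptimizeConstants(interpreter, tree, problemData, rows,
    //                                   applyLinearScaling: true, maxIterations: 3);
    //   }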
  }
}