Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/ParameterOptimizationEvaluator.cs @ 18011

Last change on this file since 18011 was 18011, checked in by bburlacu, 3 years ago

#3087: Implement ceres-based parameter optimizer in new evaluator. Revert constant optimization evaluator to old behavior.

File size: 15.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25
26using HEAL.Attic;
27
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.NativeInterpreter;
33using HeuristicLab.Optimization;
34using HeuristicLab.Parameters;
35
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Single-objective symbolic regression evaluator that, with a configurable probability,
  /// runs <c>ParameterOptimizer.OptimizeTree</c> on the numeric parameters of a tree
  /// (constants and, optionally, variable weights) and reports the mean squared error as quality.
  /// </summary>
  [Item("Parameter Optimization Evaluator", "Optimizes model parameters using nonlinear least squares and returns the mean squared error.")]
  [StorableType("D6443358-1FA3-4F4C-89DB-DCC3D81050B2")]
  public class ParameterOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    // Keys under which the parameters are registered in the Parameters collection.
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";
    private const string FunctionEvaluationsResultParameterName = "Constants Optimization Function Evaluations";
    private const string GradientEvaluationsResultParameterName = "Constants Optimization Gradient Evaluations";
    private const string CountEvaluationsParameterName = "Count Function and Gradient Evaluations";

    #region parameter properties
    /// <summary>Parameter holding the iteration limit handed to the optimizer.</summary>
    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    /// <summary>Parameter holding the relative improvement threshold. It is registered but not read anywhere in this class.</summary>
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    /// <summary>Parameter holding the probability with which the optimization is applied to a solution.</summary>
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    /// <summary>Parameter holding the percentage of rows used for the optimization.</summary>
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    /// <summary>Parameter deciding whether optimized values are written back into the tree.</summary>
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }
    /// <summary>Parameter deciding whether variable weights are optimized in addition to constants.</summary>
    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
    }

    /// <summary>Result parameter accumulating the number of function evaluations.</summary>
    public IResultParameter<IntValue> FunctionEvaluationsResultParameter {
      get { return (IResultParameter<IntValue>)Parameters[FunctionEvaluationsResultParameterName]; }
    }
    /// <summary>Result parameter accumulating the number of gradient evaluations.</summary>
    public IResultParameter<IntValue> GradientEvaluationsResultParameter {
      get { return (IResultParameter<IntValue>)Parameters[GradientEvaluationsResultParameterName]; }
    }
    /// <summary>Parameter deciding whether the evaluation counts are accumulated into the result parameters.</summary>
    public IFixedValueParameter<BoolValue> CountEvaluationsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CountEvaluationsParameterName]; }
    }
    #endregion

    #region value shortcuts
    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }

    public bool UpdateVariableWeights {
      get { return UpdateVariableWeightsParameter.Value.Value; }
      set { UpdateVariableWeightsParameter.Value.Value = value; }
    }

    public bool CountEvaluations {
      get { return CountEvaluationsParameter.Value.Value; }
      set { CountEvaluationsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>Always <c>false</c>: the reported quality is an error measure that is minimized.</summary>
    public override bool Maximization {
      get { return false; }
    }

    [StorableConstructor]
    protected ParameterOptimizationEvaluator(StorableConstructorFlag _) : base(_) { }
    protected ParameterOptimizationEvaluator(ParameterOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Creates the evaluator and registers all of its parameters with their default values.</summary>
    public ParameterOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0)) { Hidden = true });
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1)));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });

      Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));
      Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
      Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ParameterOptimizationEvaluator(this, cloner);
    }

    /// <summary>
    /// Adds any of the listed parameters that are absent after deserialization,
    /// using the same names, descriptions and default values as the default constructor.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)));

      if (!Parameters.ContainsKey(CountEvaluationsParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));

      if (!Parameters.ContainsKey(FunctionEvaluationsResultParameterName))
        Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
      if (!Parameters.ContainsKey(GradientEvaluationsResultParameterName))
        Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
    }

    // Static, hence shared by every instance of this evaluator; guards the accumulation
    // into the two result counters below.
    private static readonly object locker = new object();

    /// <summary>
    /// Evaluates the current solution. With probability <see cref="ConstantOptimizationProbability"/>
    /// the parameters of the tree are optimized first; otherwise only the mean squared error is calculated.
    /// The resulting quality is written to <c>QualityParameter</c>.
    /// </summary>
    public override IOperation InstrumentedApply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      // All four values are needed on both branches, so they are looked up once.
      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;

      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        var counter = new EvaluationsCounter();
        quality = OptimizeConstants(interpreter, solution, problemData, constantOptimizationRows, applyLinearScaling,
          ConstantOptimizationIterations.Value,
          updateVariableWeights: UpdateVariableWeights,
          lowerEstimationLimit: estimationLimits.Lower,
          upperEstimationLimit: estimationLimits.Upper,
          updateConstantsInTree: UpdateConstantsInTree,
          counter: counter);

        // The optimization used a different share of rows than the regular evaluation does,
        // so the reported quality is recalculated on the regular evaluation rows.
        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          var evaluationRows = GenerateRowsToEvaluate();
          quality = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, solution, estimationLimits.Lower, estimationLimits.Upper, problemData, evaluationRows, applyLinearScaling);
        }

        if (CountEvaluations) {
          lock (locker) {
            FunctionEvaluationsResultParameter.ActualValue.Value += counter.FunctionEvaluations;
            GradientEvaluationsResultParameter.ActualValue.Value += counter.GradientEvaluations;
          }
        }

      } else {
        var evaluationRows = GenerateRowsToEvaluate();
        quality = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, solution, estimationLimits.Lower, estimationLimits.Upper, problemData, evaluationRows, applyLinearScaling);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      return base.InstrumentedApply();
    }

    /// <summary>
    /// Calculates the mean squared error of <paramref name="tree"/> on the given rows without optimizing any parameters.
    /// </summary>
    /// <param name="context">Execution context used to resolve the interpreter, estimation limits and linear scaling flag.</param>
    /// <param name="tree">The tree to evaluate.</param>
    /// <param name="problemData">The problem data the tree is evaluated on.</param>
    /// <param name="rows">The rows of the dataset to evaluate.</param>
    /// <returns>The mean squared error.</returns>
    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      FunctionEvaluationsResultParameter.ExecutionContext = context;
      GradientEvaluationsResultParameter.ExecutionContext = context;

      try {
        // Mean Squared Error evaluator is used on purpose instead of the const-opt evaluator,
        // because Evaluate() is used to get the quality of evolved models on
        // different partitions of the dataset (e.g., best validation model)
        return SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);
      } finally {
        // Reset the contexts even if the calculation throws, so that no parameter
        // keeps referring to the caller's context.
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
        FunctionEvaluationsResultParameter.ExecutionContext = null;
        GradientEvaluationsResultParameter.ExecutionContext = null;
      }
    }

    /// <summary>Accumulates the evaluation counts reported by the optimizer.</summary>
    public class EvaluationsCounter {
      public int FunctionEvaluations = 0;
      public int GradientEvaluations = 0;
    }

    /// <summary>
    /// Optimizes the numeric parameters of <paramref name="tree"/> with <c>ParameterOptimizer.OptimizeTree</c>
    /// and returns the mean squared error of the tree on <paramref name="rows"/>.
    /// </summary>
    /// <param name="interpreter">Interpreter used for the final error calculation.</param>
    /// <param name="tree">The tree whose terminal nodes are optimized.</param>
    /// <param name="problemData">Supplies the dataset and the target variable.</param>
    /// <param name="rows">Rows used for the optimization and for the returned error. Enumerated exactly once.</param>
    /// <param name="applyLinearScaling">Forwarded to the error calculation.</param>
    /// <param name="maxIterations">Assigned to <c>SolverOptions.Iterations</c>.</param>
    /// <param name="updateVariableWeights">If <c>false</c>, variable nodes are excluded from the optimization.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit forwarded to the error calculation.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit forwarded to the error calculation.</param>
    /// <param name="updateConstantsInTree">
    /// If <c>true</c> and the optimizer reports a lower final than initial cost, the optimized values are written
    /// into the tree before the error is calculated. If <c>false</c>, the optimized values are discarded.
    /// NOTE(review): this assumes OptimizeTree does not itself modify the tree - confirm.
    /// </param>
    /// <param name="iterationCallback">Not used by this implementation; retained for signature compatibility.</param>
    /// <param name="counter">Optional; when supplied, the optimizer's residual and Jacobian evaluation counts are added to it.</param>
    /// <returns>The mean squared error of the (possibly updated) tree on <paramref name="rows"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/>, <paramref name="problemData"/> or <paramref name="rows"/> is <c>null</c>.</exception>
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
      int maxIterations, bool updateVariableWeights = true,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool updateConstantsInTree = true, Action<double[], double, object> iterationCallback = null, EvaluationsCounter counter = null) {
      if (tree == null) throw new ArgumentNullException(nameof(tree));
      if (problemData == null) throw new ArgumentNullException(nameof(problemData));
      if (rows == null) throw new ArgumentNullException(nameof(rows));

      // rows is consumed by the optimizer and by the error calculation; materialize it once
      // so a lazily generated sequence is not enumerated repeatedly.
      IEnumerable<int> optimizationRows = rows.ToList();

      var nodesToOptimize = new HashSet<ISymbolicExpressionTreeNode>();
      foreach (var node in tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        if (node is VariableTreeNode && !updateVariableWeights) {
          continue;
        }
        // Parent?. keeps a parentless terminal from raising a NullReferenceException.
        if (node is ConstantTreeNode && node.Parent?.Symbol is Power && node.Parent.GetSubtree(1) == node) {
          // do not optimize exponents
          continue;
        }
        nodesToOptimize.Add(node);
      }

      var options = new SolverOptions {
        Iterations = maxIterations
      };
      var summary = new OptimizationSummary();
      // The caller-selected rows are used here (instead of all training indices) so that the
      // row selection made by the caller actually governs the optimization.
      var optimizedNodeValues = ParameterOptimizer.OptimizeTree(tree, problemData.Dataset, optimizationRows, problemData.TargetVariable, nodesToOptimize, options, ref summary);

      // counter is optional (defaults to null), so only report when one was supplied.
      if (counter != null) {
        counter.FunctionEvaluations += summary.ResidualEvaluations;
        counter.GradientEvaluations += summary.JacobianEvaluations;
      }

      if (summary.FinalCost < summary.InitialCost && updateConstantsInTree) {
        UpdateNodeValues(optimizedNodeValues);
      }
      return SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, optimizationRows, applyLinearScaling);
    }

    /// <summary>
    /// Writes each value to its node: to <c>Value</c> for constant nodes and to <c>Weight</c> for variable nodes.
    /// Nodes of any other type are left untouched.
    /// </summary>
    private static void UpdateNodeValues(IDictionary<ISymbolicExpressionTreeNode, double> values) {
      foreach (var item in values) {
        var node = item.Key;
        if (node is ConstantTreeNode constant) {
          constant.Value = item.Value;
        } else if (node is VariableTreeNode variable) {
          variable.Weight = item.Value;
        }
      }
    }

    /// <summary>Returns the result of <c>TreeToAutoDiffTermConverter.IsCompatible</c> for <paramref name="tree"/>.</summary>
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      return TreeToAutoDiffTermConverter.IsCompatible(tree);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.