Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/RSquaredEvaluator.cs @ 16073

Last change on this file since 16073 was 16073, checked in by bburlacu, 6 years ago

#2886: Implement restarts for constant optimization in the RSquaredEvaluator

File size: 7.5 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System.Linq;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Parameters;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Evaluates sentences / symbolic expression trees by their Pearson R² on the training
  /// partition, optionally running constant optimization with random restarts.
  /// </summary>
  [Item("RSquaredEvaluator", "")]
  [StorableClass]
  public class RSquaredEvaluator : ParameterizedNamedItem, IGrammarEnumerationEvaluator {
    private const string OptimizeConstantsParameterName = "Optimize Constants";
    private const string ApplyLinearScalingParameterName = "Apply Linear Scaling";
    private const string ConstantOptimizationIterationsParameterName = "Constant Optimization Iterations";
    private const string RestartsParameterName = "Restarts";
    private const string SeedParameterName = "Seed"; // seed for the random number generator

    // Not persisted and not cloned: every instance owns its own generator, which is
    // (re)seeded from the "Seed" parameter by ReseedRandom().
    private readonly MersenneTwister random = new MersenneTwister();

    #region parameter properties
    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }

    public IFixedValueParameter<BoolValue> ApplyLinearScalingParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
    }

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }

    private IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }

    private IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }

    private int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }

    public bool OptimizeConstants {
      get { return OptimizeConstantsParameter.Value.Value; }
      set { OptimizeConstantsParameter.Value.Value = value; }
    }

    public bool ApplyLinearScaling {
      get { return ApplyLinearScalingParameter.Value.Value; }
      set { ApplyLinearScalingParameter.Value.Value = value; }
    }

    public int ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value.Value; }
      set { ConstantOptimizationIterationsParameter.Value.Value = value; }
    }
    #endregion

    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    /// <summary>
    /// Creates an evaluator with constant optimization and linear scaling disabled,
    /// 10 optimization iterations, 10 restarts and seed 0.
    /// </summary>
    public RSquaredEvaluator() {
      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<BoolValue>(ApplyLinearScalingParameterName, "Apply linear scaling on the tree model during evaluation.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Number of gradient descent iterations.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "Number of restarts for gradient descent.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "Seed value for random restarts.", new IntValue(0)));

      RegisterEventHandlers();
      ReseedRandom();
    }

    [StorableConstructor]
    protected RSquaredEvaluator(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Restores the state that is not persisted: the seed-changed handler and the seeding of
    /// the random number generator.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // NOTE(review): guards against instances persisted without these parameters
      // (presumably files written before restarts were introduced) - confirm this is needed.
      if (!Parameters.ContainsKey(RestartsParameterName)) {
        Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "Number of restarts for gradient descent.", new IntValue(10)));
      }
      if (!Parameters.ContainsKey(SeedParameterName)) {
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "Seed value for random restarts.", new IntValue(0)));
      }

      RegisterEventHandlers();
      ReseedRandom();
    }

    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
      // the handler is attached to the clone's own seed parameter, so it has to be
      // registered again here; the clone's generator is seeded from that parameter
      RegisterEventHandlers();
      ReseedRandom();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    // Keeps the random number generator in sync with the "Seed" parameter.
    private void RegisterEventHandlers() {
      SeedParameter.Value.ValueChanged += (sender, args) => ReseedRandom();
    }

    // Seeds the generator with the current "Seed" parameter value; the cast is explicitly
    // unchecked so that a negative seed wraps around instead of throwing in a checked build.
    private void ReseedRandom() {
      random.Seed(unchecked((uint)SeedParameter.Value.Value));
    }

    /// <summary>
    /// Parses <paramref name="sentence"/> into a symbolic expression tree using
    /// <paramref name="grammar"/> and evaluates that tree with this evaluator's settings.
    /// </summary>
    public double Evaluate(IRegressionProblemData problemData, Grammar grammar, SymbolList sentence) {
      var tree = grammar.ParseSymbolicExpressionTree(sentence);
      return Evaluate(problemData, tree);
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> using this evaluator's parameter values and its
    /// own random number generator.
    /// </summary>
    public double Evaluate(IRegressionProblemData problemData, ISymbolicExpressionTree tree) {
      return Evaluate(problemData, tree, random, OptimizeConstants, ConstantOptimizationIterations, ApplyLinearScaling, Restarts);
    }

    /// <summary>
    /// Calculates the R² of <paramref name="tree"/> on the training indices of
    /// <paramref name="problemData"/>. When <paramref name="optimizeConstants"/> is set, the
    /// constants of the tree are re-initialized with <paramref name="random"/> and optimized,
    /// repeating until a run beats the baseline R² or <paramref name="restarts"/> runs were
    /// made (at least one run is always made).
    /// </summary>
    /// <param name="problemData">Regression problem supplying the training rows.</param>
    /// <param name="tree">Tree to evaluate; its constants are modified only when an improvement is found.</param>
    /// <param name="random">Random number generator used to re-initialize the constants.</param>
    /// <param name="optimizeConstants">Whether constant optimization is attempted.</param>
    /// <param name="maxIterations">Iteration limit passed to the constant optimizer.</param>
    /// <param name="applyLinearScaling">Whether linear scaling is applied during evaluation.</param>
    /// <param name="restarts">Maximum number of optimization runs.</param>
    /// <returns>The best R² found, or 0.0 if that value is NaN or infinite.</returns>
    public static double Evaluate(IRegressionProblemData problemData, ISymbolicExpressionTree tree, IRandom random, bool optimizeConstants = true, int maxIterations = 10, bool applyLinearScaling = false, int restarts = 1) {
      // we begin with an evaluation without constant optimization (relatively small speed penalty compared to const opt)
      // this value will be used as a baseline to decide if an improvement was achieved via const opt
      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(expressionTreeLinearInterpreter,
          tree,
          double.MinValue,
          double.MaxValue,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: applyLinearScaling);

      // restart const opt and try to obtain an improved r2 value
      if (optimizeConstants) {
        // snapshot the constants: every restart overwrites them, and they must be put back
        // if no restart improves on the baseline (otherwise the returned r2 would not belong
        // to the tree that is handed back to the caller)
        var constantNodes = tree.IterateNodesPrefix().OfType<ConstantTreeNode>().ToList();
        var originalValues = constantNodes.Select(node => node.Value).ToArray();

        int count = 0;
        double optimized;
        do {
          foreach (var constantNode in constantNodes) {
            constantNode.ResetLocalParameters(random);
          }

          optimized = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(
            expressionTreeLinearInterpreter,
            tree,
            problemData,
            problemData.TrainingIndices,
            applyLinearScaling,
            maxIterations,
            false,
            double.MinValue,
            double.MaxValue,
            true);
        } while (optimized <= r2 && ++count < restarts);

        if (optimized > r2) {
          r2 = optimized;

          // snap constants that are numerically zero to exactly zero
          foreach (var constantNode in constantNodes) {
            if (constantNode.Value.IsAlmost(0.0)) {
              constantNode.Value = 0.0;
            }
          }
        } else {
          // do not update constants if quality is not improved
          for (int i = 0; i < constantNodes.Count; i++) {
            constantNodes[i].Value = originalValues[i];
          }
        }
      }
      return double.IsNaN(r2) || double.IsInfinity(r2) ? 0.0 : r2;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.