Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/RSquaredEvaluator.cs @ 16157

Last change on this file since 16157 was 16157, checked in by bburlacu, 6 years ago

#2886: Update IGrammarEnumerationEvaluator interface (add Evaluate method accepting an ISymbolicExpressionTree for the case when the constants have already been optimized in the tree, add boolean OptimizeConstants flag), small refactor in GrammarEnumeration/GrammarEnumerationAlgorithm.cs, add unit tests

File size: 7.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
using System;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;
33
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Grammar-enumeration evaluator that scores a sentence (or an already parsed tree) by the Pearson R²
  /// of the tree's output on the training partition of the problem data. It can optionally run
  /// constant optimization (with random restarts) and apply linear scaling.
  /// </summary>
  [Item("RSquaredEvaluator", "")]
  [StorableClass]
  public class RSquaredEvaluator : ParameterizedNamedItem, IGrammarEnumerationEvaluator {
    private const string OptimizeConstantsParameterName = "Optimize Constants";
    private const string ApplyLinearScalingParameterName = "Apply Linear Scaling";
    private const string ConstantOptimizationIterationsParameterName = "Constant Optimization Iterations";
    private const string RestartsParameterName = "Restarts";
    private const string SeedParameterName = "Seed"; // seed for the random number generator

    // Not storable and not cloned: every instance (including clones and deserialized copies) creates its own
    // generator. It is reseeded from the Seed parameter at the start of every instance Evaluate call, so its
    // state never carries over between evaluations.
    private readonly MersenneTwister random = new MersenneTwister();

    #region parameter properties
    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }

    public IFixedValueParameter<BoolValue> ApplyLinearScalingParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
    }

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }

    private IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }

    private IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }

    private int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }

    private int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }

    public bool OptimizeConstants {
      get { return OptimizeConstantsParameter.Value.Value; }
      set { OptimizeConstantsParameter.Value.Value = value; }
    }

    public bool ApplyLinearScaling {
      get { return ApplyLinearScalingParameter.Value.Value; }
      set { ApplyLinearScalingParameter.Value.Value = value; }
    }

    public int ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value.Value; }
      set { ConstantOptimizationIterationsParameter.Value.Value = value; }
    }
    #endregion

    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    public RSquaredEvaluator() {
      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<BoolValue>(ApplyLinearScalingParameterName, "Apply linear scaling on the tree model during evaluation.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Number of gradient descent iterations.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "Number of restarts for gradient descent.", new IntValue(10)));
      // No ValueChanged handler is attached to the seed: the instance Evaluate reseeds the generator on every call.
      // A handler would also be lost on cloning/deserialization, which only run the other constructors.
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "Seed value for random restarts.", new IntValue(0)));
    }

    [StorableConstructor]
    protected RSquaredEvaluator(bool deserializing) : base(deserializing) { }

    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>
    /// Parses <paramref name="sentence"/> into a symbolic expression tree using <paramref name="grammar"/>
    /// and evaluates it with the settings of this evaluator.
    /// </summary>
    /// <returns>The (possibly constant-optimized) R² on the training partition; 0.0 if it is NaN or infinite.</returns>
    public double Evaluate(IRegressionProblemData problemData, Grammar grammar, SymbolList sentence) {
      var tree = grammar.ParseSymbolicExpressionTree(sentence);
      return Evaluate(problemData, tree);
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> using this evaluator's parameter values
    /// (OptimizeConstants, ConstantOptimizationIterations, ApplyLinearScaling, Restarts, Seed).
    /// </summary>
    /// <remarks>
    /// The instance random generator is reseeded from Seed on every call, so for a fixed seed the random
    /// restarts are reproducible. Because that generator is a shared field, concurrent calls on the same
    /// instance would race on it.
    /// </remarks>
    public double Evaluate(IRegressionProblemData problemData, ISymbolicExpressionTree tree) {
      random.Seed((uint)Seed); // not the ideal solution for ensuring result consistency
      return Evaluate(problemData, tree, random, OptimizeConstants, ConstantOptimizationIterations, ApplyLinearScaling, Restarts);
    }

    /// <summary>
    /// Computes the Pearson R² of <paramref name="tree"/> on the training rows of <paramref name="problemData"/>.
    /// Optionally, it first tries to improve that value by constant optimization with random restarts.
    /// </summary>
    /// <param name="problemData">Regression problem; only its training indices are used.</param>
    /// <param name="tree">Tree to evaluate. Its constants are modified in place if constant optimization improves the quality.</param>
    /// <param name="random">Generator used to re-initialize constants before each restart; required when <paramref name="optimizeConstants"/> is true.</param>
    /// <param name="optimizeConstants">Whether to run constant optimization.</param>
    /// <param name="maxIterations">Iteration limit passed to the constant optimizer.</param>
    /// <param name="applyLinearScaling">Whether linear scaling is applied during evaluation and optimization.</param>
    /// <param name="restarts">Maximum number of optimization runs; at least one run happens when optimizing.</param>
    /// <returns>The R² value, or 0.0 if it is NaN or infinite.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="problemData"/> or <paramref name="tree"/> is null, or
    /// <paramref name="random"/> is null while <paramref name="optimizeConstants"/> is true.</exception>
    public static double Evaluate(IRegressionProblemData problemData, ISymbolicExpressionTree tree, IRandom random, bool optimizeConstants = true, int maxIterations = 10, bool applyLinearScaling = false, int restarts = 1) {
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (tree == null) throw new ArgumentNullException("tree");
      if (optimizeConstants && random == null) throw new ArgumentNullException("random");

      // We begin with an evaluation without constant optimization (relatively small speed penalty compared to const opt).
      // This value is the baseline that decides whether an improvement was achieved via const opt.
      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(expressionTreeLinearInterpreter,
          tree,
          double.MinValue,
          double.MaxValue,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: applyLinearScaling);

      // An undefined baseline (NaN or infinity) would make every comparison below false. Optimization could then
      // never be accepted, and the restart loop would stop after the first run. Such a value is scored as 0.0
      // anyway (see the return statement), so use 0.0 as the baseline.
      if (!IsValidQuality(r2)) r2 = 0.0;

      // restart const opt and try to obtain an improved r2 value
      if (optimizeConstants) {
        // Snapshot the constants so the tree can be restored when no restart improves on the baseline.
        var constantNodes = tree.IterateNodesPrefix().OfType<ConstantTreeNode>().ToArray();
        var originalValues = constantNodes.Select(node => node.Value).ToArray();

        int count = 0;
        double optimized;
        do {
          foreach (var constantNode in constantNodes) {
            constantNode.ResetLocalParameters(random);
          }

          optimized = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(
            expressionTreeLinearInterpreter,
            tree,
            problemData,
            problemData.TrainingIndices,
            applyLinearScaling,
            maxIterations,
            false,
            double.MinValue,
            double.MaxValue,
            true);
          // "!(optimized > r2)" rather than "optimized <= r2": a NaN result must count as "not improved"
          // and trigger another restart instead of ending the loop.
        } while (!(optimized > r2) && ++count < restarts);

        if (optimized > r2) {
          r2 = optimized;

          // Snap constants that are numerically indistinguishable from zero (per IsAlmost) to exactly 0.0.
          // The reported r2 is not recomputed after this tiny change.
          foreach (var constantNode in tree.IterateNodesPostfix().OfType<ConstantTreeNode>()) {
            if (constantNode.Value.IsAlmost(0.0)) {
              constantNode.Value = 0.0;
            }
          }
        } else {
          // No improvement: undo the random re-initializations so that the tree's constants again
          // correspond to the returned (baseline) quality.
          for (int i = 0; i < constantNodes.Length; i++) {
            constantNodes[i].Value = originalValues[i];
          }
        }
      }
      return IsValidQuality(r2) ? r2 : 0.0;
    }

    // True when the quality value is a finite number (neither NaN nor +/- infinity).
    private static bool IsValidQuality(double quality) {
      return !double.IsNaN(quality) && !double.IsInfinity(quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.