source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs @ 15861

Last change on this file since 15861 was 15861, checked in by lkammere, 3 years ago

#2886: Make constant optimization toggleable in algorithm.

File size: 5.3 KB
Line 
1using System;
2using System.Diagnostics;
3using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Optimization;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Problems.DataAnalysis.Symbolic;
11using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
12
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that scores every distinct
  /// sentence on the training partition and keeps the best model found so far in the
  /// algorithm's result collection.
  /// </summary>
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    private const string BestTrainingQualityResultName = "Best R² (Training)";
    private const string BestTrainingModelResultName = "Best model (Training)";
    private const string BestTrainingSolutionResultName = "Best solution (Training)";

    private readonly ISymbolicDataAnalysisExpressionTreeInterpreter interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    /// <summary>
    /// When true, constants of each sentence are optimized before its quality is taken;
    /// when false, the sentence is evaluated as parsed.
    /// </summary>
    public bool OptimizeConstants { get; set; }

    public RSquaredEvaluator() { }

    /// <summary>
    /// Cloning constructor. Chains to the base cloning constructor so that the base
    /// classes take part in the cloning of <paramref name="original"/> through
    /// <paramref name="cloner"/>, then copies the evaluator's own settings.
    /// </summary>
    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner)
      : base(original, cloner) {
      this.OptimizeConstants = original.OptimizeConstants;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>Subscribes this evaluator to the events of <paramref name="algorithm"/>.</summary>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;

      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>Removes the subscriptions made by <see cref="Register"/>.</summary>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;

      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    // Evaluates each newly generated distinct sentence.
    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
    }

    // Creates the best-quality result with a start value of -1.0 and clears the best sentence.
    private void OnStarted(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      algorithm.Results.Add(new Result(BestTrainingQualityResultName, new DoubleValue(-1.0)));

      algorithm.BestTrainingSentence = null;
    }

    // Wraps the best model (if any was recorded) into a regression solution result.
    private void OnStopped(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      if (algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
        SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
        IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

        algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
      }
    }

    /// <summary>
    /// Parses <paramref name="symbolString"/> into an expression tree, scores it on the
    /// training partition and, if it beats the current best (or ties it with a lower
    /// grammar complexity), records it as the new best model and sentence.
    /// </summary>
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      // TODO: Initialize constant values randomly
      // TODO: Restarts

      SymbolicRegressionModel model = new SymbolicRegressionModel(
          problemData.TargetVariable,
          tree,
          interpreter);

      double r2 = CalculateQuality(problemData, tree, model);

      var bestR2Result = (DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value;
      bool better = r2 > bestR2Result.Value;
      bool equallyGood = r2.IsAlmost(bestR2Result.Value);
      bool shorter = false;

      // Ties are broken in favour of the sentence with the lower grammar complexity.
      if (!better && equallyGood) {
        shorter = algorithm.BestTrainingSentence != null &&
          algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);
      }
      if (better || (equallyGood && shorter)) {
        bestR2Result.Value = r2;
        algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);

        algorithm.BestTrainingSentence = symbolString;
      }
    }

    // Returns the training quality of the tree, either from constant optimization
    // (which also writes the tuned constants back into the tree) or from a plain evaluation.
    private double CalculateQuality(IRegressionProblemData problemData, SymbolicExpressionTree tree, SymbolicRegressionModel model) {
      if (OptimizeConstants) {
        double optimizedQuality = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter,
          tree,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: false,
          maxIterations: 50,
          updateVariableWeights: true,
          updateConstantsInTree: true);

        SnapNearZeroConstants(tree);
        return optimizedQuality;
      }

      var target = problemData.TargetVariableTrainingValues;
      var estimated = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
      OnlineCalculatorError error;
      double r = OnlinePearsonsRCalculator.Calculate(target, estimated, out error);
      if (error != OnlineCalculatorError.None) return 0.0;

      // The calculator yields the correlation coefficient r; the value is published as R²
      // and compared against earlier qualities, so it has to be squared.
      return r * r;
    }

    // Replaces constants that are almost zero by exactly 0.0.
    private static void SnapNearZeroConstants(SymbolicExpressionTree tree) {
      foreach (var node in tree.IterateNodesPostfix()) {
        ConstantTreeNode constantNode = node as ConstantTreeNode;
        if (constantNode != null && constantNode.Value.IsAlmost(0.0)) {
          constantNode.Value = 0.0;
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.