source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs @ 15883

Last change on this file since 15883 was 15883, checked in by lkammere, 18 months ago

#2886: Prioritize phrases whose (fully expanded) terms result in high R².

File size: 5.7 KB
Line 
1using System;
2using System.Diagnostics;
3using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Optimization;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Problems.DataAnalysis.Symbolic;
11using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
12
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for a <see cref="GrammarEnumerationAlgorithm"/>. It evaluates each distinct sentence the
  /// algorithm generates on the training partition and tracks three things: the best R² found, the model
  /// of the best sentence, and (once the algorithm stops) a regression solution built from that model.
  /// </summary>
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
    public static readonly string BestTrainingModelResultName = "Best model (Training)";
    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";

    // One interpreter instance shared by constant optimization, evaluation and every reported model.
    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    /// <summary>
    /// When true, the constants and variable weights of each parsed tree are fitted to the training data
    /// before its R² is taken. When false, the tree is evaluated exactly as parsed.
    /// </summary>
    public bool OptimizeConstants { get; set; }

    public RSquaredEvaluator() { }

    // Cloning constructor. Chaining to the base cloning constructor lets the cloner register this clone,
    // so that several references to one evaluator map to a single clone.
    // NOTE(review): relies on the HeuristicLab Item/DeepCloneable cloning convention.
    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
      OptimizeConstants = original.OptimizeConstants;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>Subscribes this analyzer to the algorithm's Started, Stopped and DistinctSentenceGenerated events.</summary>
    /// <param name="algorithm">The algorithm to observe.</param>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;

      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>Removes the event subscriptions made by <see cref="Register"/>.</summary>
    /// <param name="algorithm">The algorithm that was being observed.</param>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;

      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    // Evaluates every distinct sentence the algorithm reports.
    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
    }

    // Resets the best-quality tracking at the start of a run.
    private void OnStarted(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      // -1.0 is lower than any R² (a squared correlation is never negative), so the first properly
      // evaluated sentence always becomes the best one.
      // AddOrUpdateResult rather than Results.Add, because Add throws for a duplicate result name.
      // That happens if Started is raised again for the same result collection.
      // NOTE(review): e.g. on resume after a pause; confirm against the algorithm's lifecycle.
      algorithm.Results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(-1.0));

      algorithm.BestTrainingSentence = null;
    }

    // Wraps the best model found (if any) into a regression solution on the problem data.
    private void OnStopped(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      if (!algorithm.Results.ContainsKey(BestTrainingModelResultName)) return;

      SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
      IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

      algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
    }

    // Parses the sentence into a tree and computes its training R². The sentence becomes the new best if
    // it is clearly better, or if it is equally good (within IsAlmost) and has lower grammar complexity.
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      double r2 = Evaluate(problemData, tree, OptimizeConstants);

      var bestR2Result = (DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value;
      // Values within IsAlmost's tolerance count as ties in both directions. Otherwise a marginally
      // higher (numerical-noise) R² of a more complex sentence would replace a simpler equally good one.
      bool equallyGood = r2.IsAlmost(bestR2Result.Value);
      bool better = r2 > bestR2Result.Value && !equallyGood;
      bool shorter = equallyGood
        && algorithm.BestTrainingSentence != null
        && algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);

      if (!better && !shorter) return;

      bestR2Result.Value = r2;

      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);

      algorithm.BestTrainingSentence = symbolString;
    }

    /// <summary>
    /// Computes the R² of <paramref name="tree"/> on the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Supplies the dataset, target variable and training indices.</param>
    /// <param name="tree">
    /// The tree to evaluate. When <paramref name="optimizeConstants"/> is true, its constants are updated in
    /// place, and any constant that ends up almost zero is set to exactly 0.0.
    /// </param>
    /// <param name="optimizeConstants">
    /// When true, constants and variable weights are fitted first (at most 50 iterations, no linear
    /// scaling) and the quality reported by the optimizer is returned.
    /// </param>
    /// <returns>
    /// The training R². On the non-optimizing path, returns 0.0 if the correlation cannot be computed.
    /// </returns>
    public static double Evaluate(IRegressionProblemData problemData, SymbolicExpressionTree tree, bool optimizeConstants = true) {
      if (optimizeConstants) {
        // TODO: Initialize constant values randomly
        // TODO: Restarts
        // NOTE(review): assumes OptimizeConstants returns the (Pearson) R² of the fitted tree.
        double optimizedR2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
          tree,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: false,
          maxIterations: 50,
          updateVariableWeights: true,
          updateConstantsInTree: true);

        // Snap constants that are numerically zero to exactly 0.0.
        foreach (var node in tree.IterateNodesPostfix()) {
          var constTreeNode = node as ConstantTreeNode;
          if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
            constTreeNode.Value = 0.0;
          }
        }

        return optimizedR2;
      }

      var target = problemData.TargetVariableTrainingValues;

      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
      OnlineCalculatorError error;
      double r = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);
      // The calculator yields Pearson's R, which can be negative. Square it so this path returns R²,
      // as the result name and the caller's comparisons expect.
      return error == OnlineCalculatorError.None ? r * r : 0.0;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.