source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs @ 15910

Last change on this file since 15910 was 15910, checked in by lkammere, 17 months ago

#2886: Fix length parameter when prioritizing phrases and add weighting parameter to control exploration/exploitation during search, fix copy constructors in Analyzers

File size: 5.8 KB
Line 
1using System;
2using System.Diagnostics;
3using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Optimization;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Problems.DataAnalysis.Symbolic;
11using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
12
13namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
14  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
15    public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
16    public static readonly string BestTrainingModelResultName = "Best model (Training)";
17    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
18
19    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
20
21    public bool OptimizeConstants { get; set; }
22
23    public RSquaredEvaluator() { }
24
25    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
26      this.OptimizeConstants = original.OptimizeConstants;
27    }
28
29    public override IDeepCloneable Clone(Cloner cloner) {
30      return new RSquaredEvaluator(this, cloner);
31    }
32
33    public void Register(GrammarEnumerationAlgorithm algorithm) {
34      algorithm.Started += OnStarted;
35      algorithm.Stopped += OnStopped;
36
37      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
38    }
39
40    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
41      algorithm.Started -= OnStarted;
42      algorithm.Stopped -= OnStopped;
43
44      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
45    }
46
47    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
48      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
49      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
50    }
51
52    private void OnStarted(object sender, EventArgs eventArgs) {
53      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
54
55      algorithm.BestTrainingSentence = null;
56    }
57
58    private void OnStopped(object sender, EventArgs eventArgs) {
59      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
60      if (algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
61        SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
62        IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);
63
64        algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
65      }
66    }
67
68    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
69      var problemData = algorithm.Problem.ProblemData;
70
71      SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
72      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));
73
74      double r2 = Evaluate(problemData, tree, OptimizeConstants);
75
76      double bestR2 = 0.0;
77      if (algorithm.Results.ContainsKey(BestTrainingQualityResultName))
78        bestR2 = ((DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value).Value;
79      bool better = r2 > bestR2;
80      bool equallyGood = r2.IsAlmost(bestR2);
81      bool shorter = false;
82
83      if (!better && equallyGood) {
84        shorter = algorithm.BestTrainingSentence != null &&
85          algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);
86      }
87      if (better || (equallyGood && shorter)) {
88        algorithm.Results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(r2));
89
90        SymbolicRegressionModel model = new SymbolicRegressionModel(
91          problemData.TargetVariable,
92          tree,
93          expressionTreeLinearInterpreter);
94
95        algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);
96
97        algorithm.BestTrainingSentence = symbolString;
98      }
99    }
100
101    public static double Evaluate(IRegressionProblemData problemData, SymbolicExpressionTree tree, bool optimizeConstants = true) {
102      double r2;
103
104      // TODO: Initialize constant values randomly
105      // TODO: Restarts
106      if (optimizeConstants) {
107        r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
108          tree,
109          problemData,
110          problemData.TrainingIndices,
111          applyLinearScaling: false,
112          maxIterations: 50,
113          updateVariableWeights: true,
114          updateConstantsInTree: true);
115
116        foreach (var symbolicExpressionTreeNode in tree.IterateNodesPostfix()) {
117          ConstantTreeNode constTreeNode = symbolicExpressionTreeNode as ConstantTreeNode;
118          if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
119            constTreeNode.Value = 0.0;
120          }
121        }
122
123      } else {
124        var target = problemData.TargetVariableTrainingValues;
125
126        SymbolicRegressionModel model = new SymbolicRegressionModel(
127          problemData.TargetVariable,
128          tree,
129          expressionTreeLinearInterpreter);
130
131        var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
132        OnlineCalculatorError error;
133        r2 = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);
134        if (error != OnlineCalculatorError.None) r2 = 0.0;
135      }
136
137      return r2;
138    }
139  }
140}
Note: See TracBrowser for help on using the repository browser.