Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs @ 15915

Last change on this file since 15915 was 15915, checked in by lkammere, 6 years ago

#2886: Add separate data structure for storing phrases in the queue.

File size: 5.8 KB
RevLine 
[15824]1using System;
[15859]2using System.Diagnostics;
[15824]3using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Optimization;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Problems.DataAnalysis.Symbolic;
11using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
12
13namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
/// <summary>
/// Grammar-enumeration analyzer that scores every distinct sentence produced by a
/// <see cref="GrammarEnumerationAlgorithm"/> by R² on the training partition and keeps the
/// best-scoring model (ties broken in favour of lower grammar complexity) in the algorithm's results.
/// </summary>
public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
  public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
  public static readonly string BestTrainingModelResultName = "Best model (Training)";
  public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";

  // Shared interpreter used both for constant optimization and for plain model evaluation.
  private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

  /// <summary>
  /// When true, the constants and variable weights of each parsed tree are fitted to the training
  /// data (short run, at most 10 iterations) before its R² is taken; otherwise the tree is
  /// evaluated as parsed.
  /// </summary>
  public bool OptimizeConstants { get; set; }

  public RSquaredEvaluator() { }

  protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
    this.OptimizeConstants = original.OptimizeConstants;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new RSquaredEvaluator(this, cloner);
  }

  /// <summary>
  /// Subscribes this analyzer to the start/stop and distinct-sentence events of <paramref name="algorithm"/>.
  /// </summary>
  public void Register(GrammarEnumerationAlgorithm algorithm) {
    algorithm.Started += OnStarted;
    algorithm.Stopped += OnStopped;

    algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
  }

  /// <summary>
  /// Removes every subscription added by <see cref="Register"/>.
  /// </summary>
  public void Deregister(GrammarEnumerationAlgorithm algorithm) {
    algorithm.Started -= OnStarted;
    algorithm.Stopped -= OnStopped;

    algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
  }

  private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
  }

  // Forget the best sentence from any previous run so the tie-breaking below starts fresh.
  private void OnStarted(object sender, EventArgs eventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;

    algorithm.BestTrainingSentence = null;
  }

  // On stop, wrap the best model found (if any) into a regression solution on the problem data.
  private void OnStopped(object sender, EventArgs eventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    if (algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
      SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
      IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

      algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
    }
  }

  /// <summary>
  /// Parses <paramref name="symbolString"/> into a tree, scores it with <see cref="Evaluate"/> and,
  /// if it beats the stored best R² (or ties it with lower grammar complexity), stores its R²,
  /// model and sentence as the new best.
  /// </summary>
  private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
    var problemData = algorithm.Problem.ProblemData;

    SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
    // The constant-optimizer precondition only matters when the optimizer is actually invoked.
    Debug.Assert(!OptimizeConstants || SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

    double r2 = Evaluate(problemData, tree, OptimizeConstants);

    // Without a stored result, a sentence must exceed R² = 0 to be recorded.
    double bestR2 = 0.0;
    if (algorithm.Results.ContainsKey(BestTrainingQualityResultName))
      bestR2 = ((DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value).Value;
    bool better = r2 > bestR2;
    bool equallyGood = r2.IsAlmost(bestR2);
    bool shorter = false;

    // Tie-break: an (almost) equally good sentence wins only if it is less complex.
    if (!better && equallyGood) {
      shorter = algorithm.BestTrainingSentence != null &&
        algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);
    }
    if (better || (equallyGood && shorter)) {
      algorithm.Results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(r2));

      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);

      algorithm.BestTrainingSentence = symbolString;
    }
  }

  /// <summary>
  /// Computes the training R² of <paramref name="tree"/>.
  /// </summary>
  /// <param name="problemData">Supplies the dataset, target variable and training indices.</param>
  /// <param name="tree">
  /// Tree to score. When <paramref name="optimizeConstants"/> is true its constants are overwritten
  /// in place with the fitted values, and fitted constants almost equal to zero are snapped to 0.0.
  /// </param>
  /// <param name="optimizeConstants">Fit constants/variable weights before scoring.</param>
  /// <returns>
  /// R² on the training partition. In the non-optimizing path a calculator error yields 0.0.
  /// </returns>
  public static double Evaluate(IRegressionProblemData problemData, SymbolicExpressionTree tree, bool optimizeConstants = true) {
    double r2;

    // TODO: Initialize constant values randomly
    // TODO: Restarts
    if (optimizeConstants) {
      r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
        tree,
        problemData,
        problemData.TrainingIndices,
        applyLinearScaling: false,
        maxIterations: 10,
        updateVariableWeights: true,
        updateConstantsInTree: true);

      // Snap numerically-zero constants to exactly 0.0.
      foreach (var symbolicExpressionTreeNode in tree.IterateNodesPostfix()) {
        ConstantTreeNode constTreeNode = symbolicExpressionTreeNode as ConstantTreeNode;
        if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
          constTreeNode.Value = 0.0;
        }
      }

    } else {
      var target = problemData.TargetVariableTrainingValues;

      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
      OnlineCalculatorError error;
      double r = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);
      // The calculator yields Pearson's r, not R². Square it so this path reports the same metric
      // as the optimizing path; otherwise negatively correlated models could never beat R² = 0.
      r2 = error == OnlineCalculatorError.None ? r * r : 0.0;
    }

    return r2;
  }
}
140}
Note: See TracBrowser for help on using the repository browser.