Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs @ 15982

Last change on this file since 15982 was 15982, checked in by bburlacu, 6 years ago

#2886: Add storable constructors for analyzers

File size: 5.8 KB
Line 
1using System;
2using System.Diagnostics;
3using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Problems.DataAnalysis.Symbolic;
11using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
12
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for the grammar enumeration algorithm that evaluates every distinct
  /// sentence on the training partition and keeps track of the best R² found so far,
  /// using the grammar's complexity measure as a tie-breaker.
  /// </summary>
  [Item("RSquaredEvaluator", "")]
  [StorableClass]
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    // Names under which the results are stored in the algorithm's result collection.
    public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
    public static readonly string BestTestQualityResultName = "Best R² (Test)";
    public static readonly string BestTrainingModelResultName = "Best model (Training)";
    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
    public static readonly string BestComplexityResultName = "Best solution complexity";

    // Single interpreter instance shared by all evaluations performed through this class.
    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    /// <summary>
    /// Whether constants of a generated tree are optimized before its R² is taken.
    /// </summary>
    /// <remarks>
    /// Marked as storable so that the setting is persisted together with the analyzer.
    /// The cloning constructor already copies it; without the attribute the value would
    /// not take part in serialization and would fall back to <c>false</c> after loading.
    /// </remarks>
    [Storable]
    public bool OptimizeConstants { get; set; }

    public RSquaredEvaluator() { }

    [StorableConstructor]
    protected RSquaredEvaluator(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; copies the <see cref="OptimizeConstants"/> setting.</summary>
    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
      this.OptimizeConstants = original.OptimizeConstants;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>Subscribes this analyzer to the algorithm's events.</summary>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;
      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>Removes the subscriptions made by <see cref="Register"/>.</summary>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;
      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    // Evaluates each newly generated distinct sentence; the sender is the algorithm itself.
    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
    }

    // Clears the best sentence when the algorithm starts.
    private void OnStarted(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;

      algorithm.BestTrainingSentence = null;
    }

    // Intentionally empty: nothing to do when the algorithm stops.
    private void OnStopped(object sender, EventArgs eventArgs) { }

    /// <summary>
    /// Unwraps the value of a <see cref="ValueTypeValue{T}"/> item.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// Thrown when <paramref name="value"/> is not a <see cref="ValueTypeValue{T}"/> (this includes null).
    /// </exception>
    private T GetValue<T>(IItem value) where T : struct {
      var v = value as ValueTypeValue<T>;
      if (v == null)
        throw new ArgumentException(string.Format("Item is not of type {0}", typeof(ValueTypeValue<T>)));
      return v.Value;
    }

    /// <summary>
    /// Evaluates <paramref name="symbolString"/> and, if it is better than the best
    /// sentence recorded so far, updates the quality and complexity results and the
    /// algorithm's best training sentence.
    /// </summary>
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
      var results = algorithm.Results;
      var grammar = algorithm.Grammar;
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = grammar.ParseSymbolicExpressionTree(symbolString);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      double r2 = Evaluate(problemData, tree, OptimizeConstants);
      // Without a stored result the best quality is taken to be 0.0.
      double bestR2 = results.ContainsKey(BestTrainingQualityResultName) ? GetValue<double>(results[BestTrainingQualityResultName].Value) : 0.0;
      // Strictly worse sentences are rejected before any complexity is computed.
      if (r2 < bestR2)
        return;

      // Complexity of the current best: taken from the results if present, otherwise
      // derived from the best sentence (and stored), otherwise "infinitely" complex.
      var bestComplexity = int.MaxValue;
      if (results.ContainsKey(BestComplexityResultName)) {
        bestComplexity = GetValue<int>(results[BestComplexityResultName].Value);
      } else if (algorithm.BestTrainingSentence != null) {
        bestComplexity = grammar.GetComplexity(algorithm.BestTrainingSentence);
        results.AddOrUpdateResult(BestComplexityResultName, new IntValue(bestComplexity));
      }
      var complexity = grammar.GetComplexity(symbolString);

      // Accept on strictly better quality, or on (almost) equal quality with lower complexity.
      if (r2 > bestR2 || (r2.IsAlmost(bestR2) && complexity < bestComplexity)) {
        results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(r2));
        results.AddOrUpdateResult(BestComplexityResultName, new IntValue(complexity));
        algorithm.BestTrainingSentence = symbolString;
      }
    }

    /// <summary>
    /// Computes the quality of <paramref name="tree"/> on the training indices of
    /// <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Regression problem data supplying the training indices.</param>
    /// <param name="tree">Tree to evaluate; its constants are modified when <paramref name="optimizeConstants"/> is true.</param>
    /// <param name="optimizeConstants">
    /// If true, the quality is the value returned by the constant optimization (without
    /// linear scaling); otherwise it is the Pearson R² calculated with linear scaling.
    /// </param>
    /// <returns>The computed quality, or 0.0 if the computation yielded NaN.</returns>
    public static double Evaluate(IRegressionProblemData problemData, SymbolicExpressionTree tree, bool optimizeConstants = true) {
      double r2;

      // TODO: Initialize constant values randomly
      // TODO: Restarts
      if (optimizeConstants) {
        r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
          tree,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: false,
          maxIterations: 10,
          updateVariableWeights: false,
          updateConstantsInTree: true);

        // Snap constants that are almost zero to exactly zero.
        foreach (var symbolicExpressionTreeNode in tree.IterateNodesPostfix()) {
          ConstantTreeNode constTreeNode = symbolicExpressionTreeNode as ConstantTreeNode;
          if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
            constTreeNode.Value = 0.0;
          }
        }
      } else {
        // double.MinValue / double.MaxValue are passed as the two bound arguments
        // (presumably the estimation limits, i.e. effectively unbounded - confirm).
        r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(expressionTreeLinearInterpreter,
          tree,
          double.MinValue,
          double.MaxValue,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: true);
      }
      return double.IsNaN(r2) ? 0.0 : r2;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.