Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs @ 15993

Last change to this file: revision 15993, checked in by bburlacu 6 years ago.

#2886: refactor code

File size: 6.4 KB
Line 
1using System;
2using System.Diagnostics;
3using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
4using HeuristicLab.Analysis;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
10using HeuristicLab.Problems.DataAnalysis;
11using HeuristicLab.Problems.DataAnalysis.Symbolic;
12using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
13
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that evaluates every distinct sentence the
  /// algorithm generates by its training R² (with linear scaling) and keeps track of the best sentence.
  /// A sentence whose quality is almost equal to the current best replaces it only if it has a lower
  /// grammar complexity.
  /// </summary>
  [Item("RSquaredEvaluator", "")]
  [StorableClass]
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
    public static readonly string BestTestQualityResultName = "Best R² (Test)";
    public static readonly string BestTrainingModelResultName = "Best model (Training)";
    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
    public static readonly string BestComplexityResultName = "Best solution complexity";
    public static readonly string BestSolutions = "Best solutions";

    // One interpreter instance shared (static) by all evaluator instances and all evaluations.
    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    public RSquaredEvaluator() { }

    [StorableConstructor]
    protected RSquaredEvaluator(bool deserializing) : base(deserializing) { }

    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>Subscribes this analyzer to the start, stop and sentence-generated events of <paramref name="algorithm"/>.</summary>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;
      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>Removes the event subscriptions made by <see cref="Register"/>.</summary>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;
      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase, algorithm.OptimizeConstants);
    }

    // Forget the best sentence of a previous run. Stored results are left untouched here.
    private void OnStarted(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;

      algorithm.BestTrainingSentence = null;
    }

    private void OnStopped(object sender, EventArgs eventArgs) { }

    /// <summary>Unwraps the primitive value stored in a <see cref="ValueTypeValue{T}"/> item.</summary>
    /// <exception cref="ArgumentException">If <paramref name="value"/> is not a <see cref="ValueTypeValue{T}"/>.</exception>
    private T GetValue<T>(IItem value) where T : struct {
      var v = value as ValueTypeValue<T>;
      if (v == null)
        throw new ArgumentException(string.Format("Item is not of type {0}", typeof(ValueTypeValue<T>)));
      return v.Value;
    }

    /// <summary>
    /// Evaluates <paramref name="symbolString"/> and, if it beats the current best sentence, stores its
    /// quality and complexity in the algorithm's results and appends a row to the "Best solutions" table.
    /// </summary>
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString, bool optimizeConstants) {
      var results = algorithm.Results;
      var grammar = algorithm.Grammar;
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = grammar.ParseSymbolicExpressionTree(symbolString);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      double r2 = Evaluate(problemData, tree, optimizeConstants);
      double bestR2 = results.ContainsKey(BestTrainingQualityResultName) ? GetValue<double>(results[BestTrainingQualityResultName].Value) : 0.0;

      // Qualities within IsAlmost tolerance are considered equal and are decided by complexity below.
      // The early exit therefore only discards clearly worse sentences; a plain "r2 < bestR2" would drop
      // equally good but simpler sentences whenever they land a rounding error below the current best.
      bool almostEqual = r2.IsAlmost(bestR2);
      if (r2 < bestR2 && !almostEqual)
        return;

      var bestComplexity = results.ContainsKey(BestComplexityResultName) ? GetValue<int>(results[BestComplexityResultName].Value) : int.MaxValue;
      var complexity = grammar.GetComplexity(symbolString);

      bool isNewBest = algorithm.BestTrainingSentence == null
                       || r2 > bestR2
                       || (almostEqual && complexity < bestComplexity);
      if (!isNewBest)
        return;

      results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(r2));
      results.AddOrUpdateResult(BestComplexityResultName, new IntValue(complexity));
      algorithm.BestTrainingSentence = symbolString;

      RecordBestSolution(algorithm, symbolString, r2, complexity);
    }

    // Appends quality, relative sentence length, complexity and elapsed execution time (in seconds)
    // of a new best sentence to the "Best solutions" data table, creating the table on first use.
    private static void RecordBestSolution(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString, double r2, int complexity) {
      var results = algorithm.Results;

      DataTable dt;
      if (results.ContainsKey(BestSolutions)) {
        dt = (DataTable)results[BestSolutions].Value;
      } else {
        dt = new DataTable();
        foreach (var name in new[] { "Quality", "Relative Length", "Complexity", "Timestamp" }) {
          dt.Rows.Add(new DataRow(name) { VisualProperties = { StartIndexZero = true } });
        }
        results.AddOrUpdateResult(BestSolutions, dt);
      }

      dt.Rows["Quality"].Values.Add(r2);
      dt.Rows["Relative Length"].Values.Add((double)symbolString.Count() / algorithm.MaxSentenceLength);
      dt.Rows["Complexity"].Values.Add(complexity);
      dt.Rows["Timestamp"].Values.Add(algorithm.ExecutionTime.TotalMilliseconds / 1000d);
    }

    /// <summary>
    /// Computes the R² of <paramref name="tree"/> on the training partition of <paramref name="problemData"/>,
    /// using linear scaling.
    /// </summary>
    /// <param name="problemData">Regression problem data; only <c>TrainingIndices</c> are used.</param>
    /// <param name="tree">
    /// Tree to evaluate. When <paramref name="optimizeConstants"/> is true, its constants are overwritten in
    /// place with the optimized values, and constants that are almost zero are set to exactly 0.0.
    /// </param>
    /// <param name="optimizeConstants">
    /// If true, constants are fitted (at most 10 iterations) and the quality reported by the optimizer is
    /// returned. Otherwise the tree is evaluated as is with the Pearson R² evaluator.
    /// </param>
    /// <returns>The R² value, or 0.0 if the evaluation produced NaN.</returns>
    public static double Evaluate(IRegressionProblemData problemData, SymbolicExpressionTree tree, bool optimizeConstants = true) {
      double r2;

      // TODO: Initialize constant values randomly
      // TODO: Restarts
      if (optimizeConstants) {
        r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
          tree,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: true,
          maxIterations: 10,
          updateVariableWeights: false,
          updateConstantsInTree: true);

        // Snap constants that the optimizer drove to (almost) zero to an exact zero.
        // Note: this happens after r2 was computed, so r2 refers to the unsnapped constants.
        foreach (var symbolicExpressionTreeNode in tree.IterateNodesPostfix()) {
          ConstantTreeNode constTreeNode = symbolicExpressionTreeNode as ConstantTreeNode;
          if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
            constTreeNode.Value = 0.0;
          }
        }
      } else {
        r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(expressionTreeLinearInterpreter,
          tree,
          double.MinValue,
          double.MaxValue,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: true);
      }
      return double.IsNaN(r2) ? 0.0 : r2;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.