Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs @ 15994

Last change on this file since 15994 was 15994, checked in by bburlacu, 6 years ago

#2886: Add symbolic regression solution to results during algorithm run and scale model.

File size: 6.9 KB
Line 
1using System;
2using System.Diagnostics;
3using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
4using HeuristicLab.Analysis;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
10using HeuristicLab.Problems.DataAnalysis;
11using HeuristicLab.Problems.DataAnalysis.Symbolic;
12using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
13
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that evaluates every distinct sentence
  /// reported by the algorithm and keeps the best one (together with its model, solution,
  /// complexity and a history table) in the algorithm's result collection.
  /// </summary>
  [Item("RSquaredEvaluator", "")]
  [StorableClass]
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    // Names of the entries this analyzer writes into the algorithm's results.
    public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
    public static readonly string BestTestQualityResultName = "Best R² (Test)";
    public static readonly string BestTrainingModelResultName = "Best model (Training)";
    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
    public static readonly string BestComplexityResultName = "Best solution complexity";
    public static readonly string BestSolutions = "Best solutions";

    // Single interpreter instance shared by all evaluations and by the models created here.
    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    public RSquaredEvaluator() { }

    [StorableConstructor]
    protected RSquaredEvaluator(bool deserializing) : base(deserializing) { }

    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>
    /// Subscribes this analyzer to the Started, Stopped and DistinctSentenceGenerated events
    /// of <paramref name="algorithm"/>.
    /// </summary>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;
      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>
    /// Removes the subscriptions added by <see cref="Register"/>.
    /// </summary>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;
      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    // Evaluates the newly generated phrase, using the sender's OptimizeConstants setting.
    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      var source = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(source, phraseAddedEventArgs.NewPhrase, source.OptimizeConstants);
    }

    // Clears the best sentence remembered on the algorithm when a run starts.
    private void OnStarted(object sender, EventArgs eventArgs) {
      var source = (GrammarEnumerationAlgorithm)sender;
      source.BestTrainingSentence = null;
    }

    // Nothing to do when the algorithm stops.
    private void OnStopped(object sender, EventArgs eventArgs) { }

    /// <summary>
    /// Unwraps the value of an item that is expected to be a <c>ValueTypeValue&lt;T&gt;</c>.
    /// </summary>
    /// <exception cref="ArgumentException">The item is not a <c>ValueTypeValue&lt;T&gt;</c>.</exception>
    private T GetValue<T>(IItem value) where T : struct {
      var typed = value as ValueTypeValue<T>;
      if (typed != null) {
        return typed.Value;
      }
      throw new ArgumentException(string.Format("Item is not of type {0}", typeof(ValueTypeValue<T>)));
    }

    /// <summary>
    /// Parses <paramref name="symbolString"/> into a tree, evaluates it and, if it beats the
    /// current best (higher quality, or almost equal quality with lower complexity), stores it
    /// as the new best sentence and updates all result entries and the history table.
    /// </summary>
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString, bool optimizeConstants) {
      var results = algorithm.Results;
      var grammar = algorithm.Grammar;
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = grammar.ParseSymbolicExpressionTree(symbolString);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      double quality = Evaluate(problemData, tree, optimizeConstants);

      // Without a stored best quality the incumbent counts as 0.0.
      double incumbentQuality = 0.0;
      if (results.ContainsKey(BestTrainingQualityResultName)) {
        incumbentQuality = GetValue<double>(results[BestTrainingQualityResultName].Value);
      }

      // Cheap rejection before the complexity is computed.
      // NOTE(review): because of this guard, the IsAlmost tie-break further down can only fire
      // when quality >= incumbentQuality; confirm that a marginally lower quality with lower
      // complexity is meant to be rejected.
      if (quality < incumbentQuality) {
        return;
      }

      // Without a stored complexity any candidate complexity counts as lower.
      int incumbentComplexity = int.MaxValue;
      if (results.ContainsKey(BestComplexityResultName)) {
        incumbentComplexity = GetValue<int>(results[BestComplexityResultName].Value);
      }
      var complexity = grammar.GetComplexity(symbolString);

      bool isNewBest = algorithm.BestTrainingSentence == null
                       || quality > incumbentQuality
                       || (quality.IsAlmost(incumbentQuality) && complexity < incumbentComplexity);
      if (!isNewBest) {
        return;
      }

      algorithm.BestTrainingSentence = symbolString;

      // Build a scaled model and a regression solution for the new best sentence.
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, expressionTreeLinearInterpreter);
      model.Scale(problemData);
      var bestSolution = model.CreateRegressionSolution(problemData);

      // The stored qualities come from the created solution, not from the value computed above.
      results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(bestSolution.TrainingRSquared));
      results.AddOrUpdateResult(BestTestQualityResultName, new DoubleValue(bestSolution.TestRSquared));
      results.AddOrUpdateResult(BestTrainingModelResultName, bestSolution.Model);
      results.AddOrUpdateResult(BestTrainingSolutionResultName, bestSolution);
      results.AddOrUpdateResult(BestComplexityResultName, new IntValue(complexity));

      // record best sentence quality & length
      if (!results.ContainsKey(BestSolutions)) {
        results.AddOrUpdateResult(BestSolutions, CreateBestSolutionsTable());
      }
      var history = (DataTable)results[BestSolutions].Value;
      history.Rows["Quality"].Values.Add(quality);
      history.Rows["Relative Length"].Values.Add((double)symbolString.Count() / algorithm.MaxSentenceLength);
      history.Rows["Complexity"].Values.Add(complexity);
      history.Rows["Timestamp"].Values.Add(algorithm.ExecutionTime.TotalMilliseconds / 1000d);
    }

    // Creates the empty history table with one row per tracked quantity.
    private static DataTable CreateBestSolutionsTable() {
      var table = new DataTable();
      foreach (var rowName in new[] { "Quality", "Relative Length", "Complexity", "Timestamp" }) {
        var row = new DataRow(rowName);
        row.VisualProperties.StartIndexZero = true;
        table.Rows.Add(row);
      }
      return table;
    }

    /// <summary>
    /// Computes the quality of <paramref name="tree"/> on the training indices of
    /// <paramref name="problemData"/> with linear scaling applied.
    /// </summary>
    /// <param name="problemData">Regression problem data supplying the training indices.</param>
    /// <param name="tree">Tree to evaluate; its constants are modified when
    /// <paramref name="optimizeConstants"/> is true.</param>
    /// <param name="optimizeConstants">If true, constants are optimized (10 iterations, written
    /// back into the tree) and constants that are almost zero are set to exactly 0.0; otherwise
    /// the Pearson R² evaluator is used directly.</param>
    /// <returns>The computed quality, or 0.0 if the computation yields NaN.</returns>
    public static double Evaluate(IRegressionProblemData problemData, SymbolicExpressionTree tree, bool optimizeConstants = true) {
      double quality;

      // TODO: Initialize constant values randomly
      // TODO: Restarts
      if (!optimizeConstants) {
        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(expressionTreeLinearInterpreter,
          tree,
          double.MinValue,
          double.MaxValue,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: true);
      } else {
        quality = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
          tree,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: true,
          maxIterations: 10,
          updateVariableWeights: false,
          updateConstantsInTree: true);

        // Snap constants that are almost zero to exactly zero.
        foreach (var node in tree.IterateNodesPostfix()) {
          var constantNode = node as ConstantTreeNode;
          if (constantNode == null) {
            continue;
          }
          if (constantNode.Value.IsAlmost(0.0)) {
            constantNode.Value = 0.0;
          }
        }
      }

      if (double.IsNaN(quality)) {
        return 0.0;
      }
      return quality;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.