1 | using System;
|
---|
2 | using System.Diagnostics;
|
---|
3 | using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
|
---|
4 | using HeuristicLab.Common;
|
---|
5 | using HeuristicLab.Core;
|
---|
6 | using HeuristicLab.Data;
|
---|
7 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
8 | using HeuristicLab.Optimization;
|
---|
9 | using HeuristicLab.Problems.DataAnalysis;
|
---|
10 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
11 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
12 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Grammar-enumeration analyzer that scores every distinct generated sentence by its
  /// R² (squared Pearson correlation) on the training partition. It keeps the best
  /// sentence and its model in the algorithm's results and, when the run stops, publishes
  /// a regression solution built from that model.
  /// </summary>
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    private const string BestTrainingQualityResultName = "Best R² (Training)";
    private const string BestTrainingModelResultName = "Best model (Training)";
    private const string BestTrainingSolutionResultName = "Best solution (Training)";

    private readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    /// <summary>
    /// If true, the numeric constants and variable weights of each parsed tree are fitted
    /// on the training rows (at most 50 iterations, no linear scaling) before scoring.
    /// Otherwise the tree is scored exactly as parsed.
    /// </summary>
    public bool OptimizeConstants { get; set; }

    public RSquaredEvaluator() { }

    /// <summary>Cloning constructor; copies <see cref="OptimizeConstants"/>.</summary>
    /// <param name="original">The evaluator being cloned.</param>
    /// <param name="cloner">The cloner driving the deep copy.</param>
    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner)
      : base(original, cloner) { // chain to Item's cloning ctor so the cloner can track original -> clone
      this.OptimizeConstants = original.OptimizeConstants;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>
    /// Subscribes this analyzer to the algorithm's Started, Stopped and
    /// DistinctSentenceGenerated events.
    /// </summary>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;

      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>Removes every subscription made by <see cref="Register"/>.</summary>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;

      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    // Evaluates each newly generated distinct sentence; the sender is the algorithm itself.
    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
    }

    // Seeds the best-quality result with -1.0, which is below any valid R², so the first
    // evaluated sentence is accepted. Also clears the remembered best sentence.
    private void OnStarted(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      algorithm.Results.Add(new Result(BestTrainingQualityResultName, new DoubleValue(-1.0)));

      algorithm.BestTrainingSentence = null;
    }

    // If a best model was recorded, wraps it in a RegressionSolution over the problem's
    // data and adds it to the results (or updates the existing entry).
    private void OnStopped(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      if (algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
        SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
        IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

        algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
      }
    }

    /// <summary>
    /// Parses <paramref name="symbolString"/> into a tree and scores it by training R².
    /// The sentence and its model become the new best if the score is strictly better
    /// than the current best, or ties it (within IsAlmost tolerance) with a lower grammar
    /// complexity.
    /// </summary>
    /// <param name="algorithm">Algorithm whose problem data, grammar and results are used.</param>
    /// <param name="symbolString">The sentence to evaluate.</param>
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      // TODO: Initialize constant values randomly
      // TODO: Restarts

      double r2;

      // NOTE: the model is created before constant optimization. This assumes the model
      // references `tree` itself (not a copy), so the in-place updates below reach the
      // stored model — confirm.
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      if (OptimizeConstants) {
        // Fits constants and variable weights directly in the tree (updateConstantsInTree)
        // and returns the resulting training quality.
        r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
          tree,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: false,
          maxIterations: 50,
          updateVariableWeights: true,
          updateConstantsInTree: true);

        // Snap constants that the optimizer drove to (almost) zero to exactly 0.0.
        // r2 is not recomputed after this adjustment.
        foreach (var node in tree.IterateNodesPostfix()) {
          ConstantTreeNode constTreeNode = node as ConstantTreeNode;
          if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
            constTreeNode.Value = 0.0;
          }
        }
      } else {
        var target = problemData.TargetVariableTrainingValues;
        var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);
        // The calculator returns Pearson's R. Square it so this branch yields R², which is
        // what the result name advertises and what the constant optimizer (squared Pearson
        // R) reports in the other branch. A calculation error scores 0.0.
        r2 = error == OnlineCalculatorError.None ? r * r : 0.0;
      }

      var bestR2Result = (DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value;
      bool better = r2 > bestR2Result.Value;
      bool equallyGood = r2.IsAlmost(bestR2Result.Value);
      bool shorter = false;

      // On a tie, prefer the sentence with the lower grammar complexity.
      if (!better && equallyGood) {
        shorter = algorithm.BestTrainingSentence != null &&
          algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);
      }
      if (better || (equallyGood && shorter)) {
        bestR2Result.Value = r2;
        algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);

        algorithm.BestTrainingSentence = symbolString;
      }
    }
  }
}