[15824] | 1 | using System;
|
---|
[15859] | 2 | using System.Diagnostics;
|
---|
[15824] | 3 | using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
|
---|
| 4 | using HeuristicLab.Common;
|
---|
| 5 | using HeuristicLab.Core;
|
---|
| 6 | using HeuristicLab.Data;
|
---|
| 7 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 8 | using HeuristicLab.Optimization;
|
---|
| 9 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 10 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 11 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
| 12 |
|
---|
| 13 | namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
|
---|
/// <summary>
/// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that evaluates every distinct
/// sentence the algorithm reports on the training partition and keeps the best quality,
/// model and (once the algorithm stops) solution in the algorithm's result collection.
/// </summary>
public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
  // Names of the entries this analyzer maintains in the algorithm's results.
  private const string BestTrainingQualityResultName = "Best R² (Training)";
  private const string BestTrainingModelResultName = "Best model (Training)";
  private const string BestTrainingSolutionResultName = "Best solution (Training)";

  // Interpreter shared by all evaluations performed by this instance.
  private readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

  /// <summary>
  /// When <c>true</c>, the constants of each sentence are tuned with
  /// <see cref="SymbolicRegressionConstantOptimizationEvaluator"/> before its quality is taken;
  /// when <c>false</c>, the sentence is evaluated as parsed.
  /// </summary>
  public bool OptimizeConstants { get; set; }

  public RSquaredEvaluator() { }

  /// <summary>Cloning constructor used by <see cref="Clone"/>.</summary>
  /// <param name="original">The instance to copy settings from.</param>
  /// <param name="cloner">The cloner driving the current deep-clone operation.</param>
  protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner)
    : base(original, cloner) {
    // Chaining to the base cloning constructor lets the base class take part in the clone;
    // without it the parameterless base constructor would run instead.
    this.OptimizeConstants = original.OptimizeConstants;
  }

  /// <summary>Creates a deep clone of this evaluator.</summary>
  public override IDeepCloneable Clone(Cloner cloner) {
    return new RSquaredEvaluator(this, cloner);
  }

  /// <summary>Subscribes this analyzer to the events of <paramref name="algorithm"/>.</summary>
  public void Register(GrammarEnumerationAlgorithm algorithm) {
    algorithm.Started += OnStarted;
    algorithm.Stopped += OnStopped;

    algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
  }

  /// <summary>Removes the subscriptions made by <see cref="Register"/>.</summary>
  public void Deregister(GrammarEnumerationAlgorithm algorithm) {
    algorithm.Started -= OnStarted;
    algorithm.Stopped -= OnStopped;

    algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
  }

  // Evaluates each newly reported distinct sentence.
  private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
  }

  // Seeds the best-quality result with -1.0 (below any R², which is >= 0) and clears the best sentence.
  private void OnStarted(object sender, EventArgs eventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    algorithm.Results.Add(new Result(BestTrainingQualityResultName, new DoubleValue(-1.0)));

    algorithm.BestTrainingSentence = null;
  }

  // Wraps the best model found (if any) into a regression solution and publishes it.
  private void OnStopped(object sender, EventArgs eventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    if (!algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
      return;
    }

    SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
    IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

    algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
  }

  /// <summary>
  /// Parses <paramref name="symbolString"/> into an expression tree, determines its R² on the
  /// training partition and, if it beats the current best (or ties it with lower grammar
  /// complexity), records it as the new best model and sentence.
  /// </summary>
  private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
    var problemData = algorithm.Problem.ProblemData;

    SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
    Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

    // TODO: Initialize constant values randomly
    // TODO: Restarts

    double r2;

    // The model is built on the same tree instance that is optimized below.
    SymbolicRegressionModel model = new SymbolicRegressionModel(
      problemData.TargetVariable,
      tree,
      expressionTreeLinearInterpreter);

    if (OptimizeConstants) {
      r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
        tree,
        problemData,
        problemData.TrainingIndices,
        applyLinearScaling: false,
        maxIterations: 50,
        updateVariableWeights: true,
        updateConstantsInTree: true);

      // Snap constants that are numerically indistinguishable from zero to exactly 0.0.
      // NOTE(review): r2 is not re-evaluated after snapping; presumably the effect is negligible.
      foreach (var node in tree.IterateNodesPostfix()) {
        ConstantTreeNode constTreeNode = node as ConstantTreeNode;
        if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
          constTreeNode.Value = 0.0;
        }
      }
    } else {
      var target = problemData.TargetVariableTrainingValues;
      var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
      OnlineCalculatorError error;
      // The calculator yields Pearson's R; square it so the stored quality is an R²
      // (previously the raw R was stored, so negatively correlated models got a negative score).
      double r = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);
      r2 = error == OnlineCalculatorError.None ? r * r : 0.0;
    }

    var bestR2Result = (DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value;
    bool better = r2 > bestR2Result.Value;
    bool equallyGood = r2.IsAlmost(bestR2Result.Value);

    // On a tie, prefer the sentence with the lower grammar complexity.
    bool shorter = !better && equallyGood
      && algorithm.BestTrainingSentence != null
      && algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);

    if (better || shorter) {
      bestR2Result.Value = r2;
      algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);

      algorithm.BestTrainingSentence = symbolString;
    }
  }
}
|
---|
| 128 | }
|
---|