[15824] | 1 | using System;
|
---|
[15859] | 2 | using System.Diagnostics;
|
---|
[15824] | 3 | using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
|
---|
| 4 | using HeuristicLab.Common;
|
---|
| 5 | using HeuristicLab.Core;
|
---|
| 6 | using HeuristicLab.Data;
|
---|
| 7 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 8 | using HeuristicLab.Optimization;
|
---|
| 9 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 10 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 11 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
| 12 |
|
---|
| 13 | namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
|
---|
/// <summary>
/// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that scores every distinct
/// sentence on the training partition and records the best quality, model and
/// (once the algorithm stops) solution in the algorithm's result collection.
/// </summary>
public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
  // Keys under which this analyzer stores its entries in algorithm.Results.
  public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
  public static readonly string BestTrainingModelResultName = "Best model (Training)";
  public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";

  // One interpreter instance is shared by all evaluations and all created models.
  private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

  /// <summary>
  /// When true, sentences are scored via constant optimization; otherwise the parsed
  /// tree is evaluated as-is (see <see cref="Evaluate"/>).
  /// </summary>
  public bool OptimizeConstants { get; set; }

  public RSquaredEvaluator() { }

  /// <summary>Cloning constructor; copies the <see cref="OptimizeConstants"/> flag.</summary>
  protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
    this.OptimizeConstants = original.OptimizeConstants;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new RSquaredEvaluator(this, cloner);
  }

  /// <summary>Subscribes this analyzer to the lifecycle and sentence events of <paramref name="algorithm"/>.</summary>
  public void Register(GrammarEnumerationAlgorithm algorithm) {
    algorithm.Started += OnStarted;
    algorithm.Stopped += OnStopped;

    algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
  }

  /// <summary>Removes exactly the subscriptions added by <see cref="Register"/>.</summary>
  public void Deregister(GrammarEnumerationAlgorithm algorithm) {
    algorithm.Started -= OnStarted;
    algorithm.Stopped -= OnStopped;

    algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
  }

  // Every newly generated distinct sentence is scored immediately.
  private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
  }

  // A new run starts without a remembered best sentence.
  private void OnStarted(object sender, EventArgs eventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;

    algorithm.BestTrainingSentence = null;
  }

  // On stop, wrap the best model (if one was recorded) into a regression solution.
  private void OnStopped(object sender, EventArgs eventArgs) {
    GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    if (algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
      SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
      IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

      algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
    }
  }

  /// <summary>
  /// Parses <paramref name="symbolString"/> into a tree, scores it, and replaces the stored
  /// best result when the new sentence is better, or equally good but less complex.
  /// </summary>
  private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
    var problemData = algorithm.Problem.ProblemData;

    SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
    Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

    double r2 = Evaluate(problemData, tree, OptimizeConstants);

    // Without a stored quality the baseline is 0.0.
    double bestR2 = 0.0;
    if (algorithm.Results.ContainsKey(BestTrainingQualityResultName))
      bestR2 = ((DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value).Value;
    bool better = r2 > bestR2;
    bool equallyGood = r2.IsAlmost(bestR2);
    bool shorter = false;

    // Complexity is only used as a tie-breaker, and only once a best sentence exists.
    if (!better && equallyGood) {
      shorter = algorithm.BestTrainingSentence != null &&
        algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);
    }
    if (better || (equallyGood && shorter)) {
      algorithm.Results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(r2));

      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);

      algorithm.BestTrainingSentence = symbolString;
    }
  }

  /// <summary>
  /// Scores <paramref name="tree"/> on the training rows of <paramref name="problemData"/>.
  /// </summary>
  /// <param name="problemData">Regression problem supplying the dataset, target and training indices.</param>
  /// <param name="tree">Tree to score; with <paramref name="optimizeConstants"/> its constants are updated in place.</param>
  /// <param name="optimizeConstants">
  /// True to run constant optimization and use the quality it returns; false to evaluate the
  /// tree unchanged and compute the squared Pearson correlation with the target.
  /// </param>
  /// <returns>The training quality; 0.0 when the correlation could not be calculated.</returns>
  public static double Evaluate(IRegressionProblemData problemData, SymbolicExpressionTree tree, bool optimizeConstants = true) {
    double r2;

    // TODO: Initialize constant values randomly
    // TODO: Restarts
    if (optimizeConstants) {
      r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
        tree,
        problemData,
        problemData.TrainingIndices,
        applyLinearScaling: false,
        maxIterations: 10,
        updateVariableWeights: true,
        updateConstantsInTree: true);

      // Snap constants that are numerically indistinguishable from zero to exactly 0.0.
      foreach (var symbolicExpressionTreeNode in tree.IterateNodesPostfix()) {
        ConstantTreeNode constTreeNode = symbolicExpressionTreeNode as ConstantTreeNode;
        if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
          constTreeNode.Value = 0.0;
        }
      }

    } else {
      var target = problemData.TargetVariableTrainingValues;

      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
      OnlineCalculatorError error;
      // The calculator yields Pearson's R; square it so the value is a true R² and
      // negatively correlated models do not produce a negative quality.
      double r = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);
      r2 = error == OnlineCalculatorError.None ? r * r : 0.0;
    }

    return r2;
  }
}
|
---|
| 140 | }
|
---|