1 | using System;
|
---|
2 | using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
|
---|
3 | using HeuristicLab.Common;
|
---|
4 | using HeuristicLab.Core;
|
---|
5 | using HeuristicLab.Data;
|
---|
6 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
7 | using HeuristicLab.Optimization;
|
---|
8 | using HeuristicLab.Problems.DataAnalysis;
|
---|
9 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
10 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
11 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that evaluates every distinct generated
  /// sentence as a symbolic regression model on the training partition and keeps track of the
  /// sentence with the highest squared Pearson correlation (R²) in the algorithm's results.
  /// </summary>
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    private const string BestTrainingQualityResultName = "Best R² (Training)";
    private const string BestTrainingModelResultName = "Best model (Training)";
    private const string BestTrainingSolutionResultName = "Best solution (Training)";

    // Interpreter used to compute the estimated values of each evaluated sentence's tree.
    private readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    public RSquaredEvaluator() { }

    /// <summary>
    /// Cloning constructor. Chains to the base cloning constructor so the clone is registered
    /// with <paramref name="cloner"/> and the inherited <see cref="Item"/> state is copied.
    /// </summary>
    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>
    /// Subscribes this analyzer to the start, stop and distinct-sentence events of <paramref name="algorithm"/>.
    /// </summary>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;

      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>
    /// Removes the event subscriptions added by <see cref="Register"/>.
    /// </summary>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;

      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    // Evaluates every distinct sentence reported by the algorithm (the sender).
    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
    }

    // Initializes the best-quality result with -1.0 so that the first evaluated sentence
    // (whose R² is always >= 0) replaces it, and clears any previously recorded best sentence.
    private void OnStarted(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      algorithm.Results.Add(new Result(BestTrainingQualityResultName, new DoubleValue(-1.0)));

      algorithm.BestTrainingSentence = null;
    }

    // When a best model was recorded during the run, wraps it into a regression solution
    // on the problem data and publishes it as a result.
    private void OnStopped(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      if (algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
        SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
        IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

        algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
      }
    }

    /// <summary>
    /// Parses <paramref name="symbolString"/> into an expression tree, computes the squared Pearson
    /// correlation (R²) between its estimated values and the target on the training indices, and
    /// updates the best R² result, best model result and the algorithm's best training sentence
    /// when the new R² is strictly higher.
    /// </summary>
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
      SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);

      var problemData = algorithm.Problem.ProblemData;
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      var target = problemData.TargetVariableTrainingValues;
      var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
      OnlineCalculatorError error;
      // The calculator returns Pearson's r; R² is its square. Squaring also makes strongly
      // anti-correlated models (r close to -1) rank as highly as strongly correlated ones.
      var r = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);
      var r2 = error == OnlineCalculatorError.None ? r * r : 0.0;

      var bestR2Result = (DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value;
      if (r2 > bestR2Result.Value) {
        bestR2Result.Value = r2;
        algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);

        algorithm.BestTrainingSentence = symbolString;
      }
    }
  }
}
|
---|