Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/BestSolutionAnalyzer.cs @ 16123

Last change on this file since 16123 was 16090, checked in by lkammere, 6 years ago

#2886: Explicitly store all Pareto-optimal RegressionSolution objects at the end of the algorithm.

File size: 9.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Diagnostics;
24using System.Linq;
25using HeuristicLab.Analysis;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
34
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that, for every distinct sentence generated,
  /// maintains a Pareto front of (complexity, quality) pairs and the overall best solution found so far,
  /// and publishes both in the algorithm's result collection.
  /// </summary>
  [Item("Best Solution Analyzer", "Returns the characteristics of the best solution so far.")]
  [StorableClass]
  public class BestSolutionAnalyzer : Item, IGrammarEnumerationAnalyzer {
    public static readonly string BestTrainingQualityResultName = "Best R² (Training)";
    public static readonly string BestTestQualityResultName = "Best R² (Test)";
    public static readonly string BestTrainingModelResultName = "Best model (Training)";
    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
    public static readonly string BestComplexityResultName = "Best solution complexity";
    public static readonly string BestSolutions = "Best solutions";
    public static readonly string ParetoFrontResultName = "Pareto Front";
    public static readonly string ParetoFrontAnalysisResultName = "Pareto Front Analysis";
    public static readonly string ParetoFrontSolutionsResultName = "Pareto Front Solutions";

    // Shared interpreter used by every model this analyzer creates.
    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    public BestSolutionAnalyzer() { }

    [StorableConstructor]
    protected BestSolutionAnalyzer(bool deserializing) : base(deserializing) { }

    protected BestSolutionAnalyzer(BestSolutionAnalyzer original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new BestSolutionAnalyzer(this, cloner);
    }

    /// <summary>
    /// Unsubscribes this analyzer from the <c>DistinctSentenceGenerated</c> and <c>Stopped</c> events of
    /// <paramref name="algorithm"/>.
    /// </summary>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.DistinctSentenceGenerated -= AlgorithmDistinctSentenceGenerated;
      algorithm.Stopped -= AlgorithmOnStopped;
    }

    /// <summary>
    /// Subscribes this analyzer to the <c>DistinctSentenceGenerated</c> and <c>Stopped</c> events of
    /// <paramref name="algorithm"/>.
    /// </summary>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.DistinctSentenceGenerated += AlgorithmDistinctSentenceGenerated;
      algorithm.Stopped += AlgorithmOnStopped;
    }

    /// <summary>
    /// When the algorithm stops, collects the regression solutions attached to the points of the
    /// Pareto front scatter plot and stores them as an explicit result list.
    /// </summary>
    private void AlgorithmOnStopped(object sender, EventArgs eventArgs) {
      var algorithm = (GrammarEnumerationAlgorithm)sender;

      IResult paretoFrontResult;
      if (!algorithm.Results.TryGetValue(ParetoFrontAnalysisResultName, out paretoFrontResult)) return;

      // Each scatter-plot point carries its regression solution in Tag (see AddToParetoFront).
      var plot = (ScatterPlot)paretoFrontResult.Value;
      var solutions = plot.Rows.First().Points.Select(p => (ISymbolicRegressionSolution)p.Tag);

      algorithm.Results.AddOrUpdateResult(ParetoFrontSolutionsResultName, new ItemList<ISymbolicRegressionSolution>(solutions));
    }

    /// <summary>
    /// Evaluates a newly generated sentence.
    /// <list type="bullet">
    /// <item>If the sentence is non-dominated, it is added to the Pareto front.</item>
    /// <item>If it also beats the recorded best solution, the best-solution results are updated.</item>
    /// <item>If its quality is (almost) 1, the algorithm is stopped.</item>
    /// </list>
    /// </summary>
    private void AlgorithmDistinctSentenceGenerated(object sender, PhraseAddedEventArgs args) {
      var algorithm = (GrammarEnumerationAlgorithm)sender;
      var sentence = args.NewPhrase;

      var results = algorithm.Results;
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(sentence);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      double r2 = algorithm.Evaluator.Evaluate(problemData, tree);
      int rank = GetRank(sentence);

      // Model and solution objects are only built for candidates that enter the Pareto front.
      if (IsParetoOptimal(algorithm, rank, r2)) {
        var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, expressionTreeLinearInterpreter);
        model.Scale(problemData);
        var bestSolution = model.CreateRegressionSolution(problemData);

        AddToParetoFront(algorithm, rank, r2, bestSolution);

        // Compare against the overall best solution recorded so far (if any).
        IResult bestR2Result;
        double bestR2 = results.TryGetValue(BestTrainingQualityResultName, out bestR2Result)
          ? GetValue<double>(bestR2Result.Value)
          : 0.0;
        IResult bestComplexityResult;
        int bestComplexity = results.TryGetValue(BestComplexityResultName, out bestComplexityResult)
          ? GetValue<int>(bestComplexityResult.Value)
          : int.MaxValue;
        var complexity = sentence.Complexity;

        // Better quality wins; on (almost) equal quality the simpler sentence wins.
        if (algorithm.BestTrainingSentence == null || r2 > bestR2 || (r2.IsAlmost(bestR2) && complexity < bestComplexity)) {
          algorithm.BestTrainingSentence = sentence;

          results.AddOrUpdateResult(BestTrainingQualityResultName, new DoubleValue(bestSolution.TrainingRSquared));
          results.AddOrUpdateResult(BestTestQualityResultName, new DoubleValue(bestSolution.TestRSquared));
          results.AddOrUpdateResult(BestTrainingModelResultName, bestSolution.Model);
          results.AddOrUpdateResult(BestTrainingSolutionResultName, bestSolution);
          results.AddOrUpdateResult(BestComplexityResultName, new IntValue(complexity));

          // Record the history of best sentence quality, length, complexity and time of discovery.
          IResult bestSolutionsResult;
          DataTable dt;
          if (results.TryGetValue(BestSolutions, out bestSolutionsResult)) {
            dt = (DataTable)bestSolutionsResult.Value;
          } else {
            var names = new[] { "Quality", "Length", "Complexity", "Timestamp" };
            dt = new DataTable();
            foreach (var name in names) {
              dt.Rows.Add(new DataRow(name) { VisualProperties = { StartIndexZero = true } });
            }
            results.AddOrUpdateResult(BestSolutions, dt);
          }
          dt.Rows["Quality"].Values.Add(r2);
          dt.Rows["Length"].Values.Add((double)sentence.Count);
          dt.Rows["Complexity"].Values.Add(complexity);
          dt.Rows["Timestamp"].Values.Add(algorithm.ExecutionTime.TotalMilliseconds / 1000d);
        }
      }

      // Stop the algorithm if the best possible quality was already achieved.
      if (r2.IsAlmost(1d)) {
        algorithm.Stop();
      }
    }

    /// <summary>
    /// Extracts the value of a <see cref="ValueTypeValue{T}"/> item.
    /// </summary>
    /// <exception cref="ArgumentException">If <paramref name="value"/> is not a <see cref="ValueTypeValue{T}"/>.</exception>
    private T GetValue<T>(IItem value) where T : struct {
      var v = value as ValueTypeValue<T>;
      if (v == null)
        throw new ArgumentException(string.Format("Item is not of type {0}", typeof(ValueTypeValue<T>)));
      return v.Value;
    }

    /// <summary>
    /// Returns the Pareto rank (x-coordinate) of a sentence, which is its complexity.
    /// </summary>
    private int GetRank(SymbolList s) {
      return s.Complexity;
    }

    /// <summary>
    /// Returns true when no entry of the current Pareto front dominates the candidate.
    /// An entry dominates the candidate when its rank is &lt;= <paramref name="currRank"/>
    /// and its quality is &gt;= <paramref name="currQuality"/>.
    /// </summary>
    /// <remarks>
    /// The front is kept sorted by ascending rank with increasing quality (see <see cref="AddToParetoFront"/>).
    /// The last entry whose rank does not exceed <paramref name="currRank"/> is therefore the best
    /// competitor, and it is the only one that needs to be checked.
    /// </remarks>
    private bool IsParetoOptimal(GrammarEnumerationAlgorithm algorithm, int currRank, double currQuality) {
      IResult paretoFrontResult;
      if (!algorithm.Results.TryGetValue(ParetoFrontResultName, out paretoFrontResult)) return true;

      var paretoFront = (ItemList<DoubleArray>)paretoFrontResult.Value;

      // Advance while the NEXT entry's rank (not its index) is still <= currRank.
      int precedingRankIndex = -1;
      while (precedingRankIndex + 1 < paretoFront.Count && paretoFront[precedingRankIndex + 1][0] <= currRank) {
        precedingRankIndex++;
      }

      return precedingRankIndex < 0 || paretoFront[precedingRankIndex][1] < currQuality;
    }

    /// <summary>
    /// Inserts the candidate into the Pareto front, creating the front and its scatter plot on first use.
    /// Entries (and their plot points) that are dominated by the candidate are removed.
    /// </summary>
    /// <param name="solution">Stored in the new scatter-plot point's Tag so it can be collected when the algorithm stops.</param>
    private void AddToParetoFront(GrammarEnumerationAlgorithm algorithm, int currRank, double currQuality, ISymbolicRegressionSolution solution) {
      if (!algorithm.Results.ContainsKey(ParetoFrontResultName)) {
        algorithm.Results.Add(new Result(ParetoFrontResultName, new ItemList<DoubleArray>()));

        var scatterPlot = new ScatterPlot(ParetoFrontAnalysisResultName, ParetoFrontAnalysisResultName);
        algorithm.Results.Add(new Result(ParetoFrontAnalysisResultName, scatterPlot));
        scatterPlot.Rows.Add(new ScatterPlotDataRow());

        scatterPlot.VisualProperties.XAxisTitle = "Complexity";
        scatterPlot.VisualProperties.YAxisTitle = "R²";
        scatterPlot.Rows.First().VisualProperties.PointSize = 10;
      }

      ItemList<DoubleArray> paretoFront = (ItemList<DoubleArray>)algorithm.Results[ParetoFrontResultName].Value;
      ScatterPlotDataRow plot = ((ScatterPlot)algorithm.Results[ParetoFrontAnalysisResultName].Value).Rows.First();

      // Skip entries of lower rank, then delete entries of equal or higher rank that are now dominated.
      int i = 0;
      while (i < paretoFront.Count) {
        double entryRank = paretoFront[i][0];
        if (entryRank < currRank) {
          i++;
          continue;
        }
        if (paretoFront[i][1] > currQuality) {
          // Following entries have even higher quality, so none of them is dominated.
          break;
        }
        // Remove the plot point of the dominated entry itself (its rank may exceed currRank).
        RemovePoint(plot, entryRank);
        paretoFront.RemoveAt(i);
      }

      paretoFront.Insert(i, new DoubleArray(new double[] { currRank, currQuality }));
      plot.Points.Add(new Point2D<double>(currRank, currQuality, solution));
    }

    /// <summary>
    /// Removes all points whose x-coordinate (rank) is almost equal to <paramref name="rank"/>.
    /// </summary>
    private void RemovePoint(ScatterPlotDataRow plot, double rank) {
      plot.Points.RemoveAll(p => p.X.IsAlmost(rank));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.