Free cookie consent management tool by TermsFeed Policy Generator

source: branches/1837_Sliding Window GP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SlidingWindow/SlidingWindowBestSolutionsCollection.cs @ 17717

Last change on this file since 17717 was 17687, checked in by fbaching, 4 years ago

#1837: merged changes from trunk

  • apply changes from Attic release to all SlidingWindow specific code files (replace StorableClass with StorableType)
File size: 13.8 KB
RevLine 
[10396]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[10720]22using System;
[10396]23using System.Collections.Generic;
[10720]24using System.ComponentModel;
25using System.Linq;
[10396]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[17687]30using HEAL.Attic;
[10396]31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Stores, for every sliding window position, the best symbolic expression tree found on that window,
/// and maintains a matrix with the quality of every stored tree on every stored window.
/// </summary>
/// <remarks>
/// <see cref="SlidingWindowQualities"/> has one row per stored solution (in the order of
/// <see cref="SlidingWindowRanges"/>), one column per window (same order) and one trailing column
/// holding the quality on the training partition of <see cref="ProblemData"/>.
/// </remarks>
[Item("SlidingWindowBestSolutionsCollection", "An object holding a collection of the best sliding window solutions.")]
[StorableType("08DA042D-9A0E-4D7A-8ED7-AED6918D8EF3")]
public abstract class SlidingWindowBestSolutionsCollection : Item {
  // Window positions in insertion order; the position of a range in this list is its row and
  // column index in SlidingWindowQualities.
  [Storable]
  private List<SlidingWindowRange> slidingWindowRanges;

  /// <summary>The stored window positions, in the order in which they were added.</summary>
  public List<SlidingWindowRange> SlidingWindowRanges {
    get { return slidingWindowRanges; }
    private set { slidingWindowRanges = value; }
  }

  // Write-only persistence member (AllowOneWay): accepts a dictionary keyed by (start, end) tuples
  // stored under the name "bestSolutions" and converts it into the range-keyed representation,
  // ordered by window start. NOTE(review): presumably this reads files written by an older
  // version of this class - confirm.
  [Storable(AllowOneWay = true, Name = "bestSolutions")]
  private Dictionary<Tuple<int, int>, ISymbolicExpressionTree> StorableBestSolutions {
    set {
      var bestSolutions = value;
      var ranges = bestSolutions.Keys.OrderBy(x => x.Item1).ToList();
      slidingWindowRanges = ranges.Select(x => new SlidingWindowRange(x.Item1, x.Item2)).ToList();
      slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
      for (int i = 0; i < slidingWindowRanges.Count; ++i) {
        slidingWindowBestSolutions.Add(slidingWindowRanges[i], bestSolutions[ranges[i]]);
      }
    }
  }

  [Storable]
  private Dictionary<SlidingWindowRange, ISymbolicExpressionTree> slidingWindowBestSolutions;

  /// <summary>The best solution stored for each window position.</summary>
  public Dictionary<SlidingWindowRange, ISymbolicExpressionTree> SlidingWindowBestSolutions {
    get { return slidingWindowBestSolutions; }
    set { slidingWindowBestSolutions = value; }
  }

  [Storable]
  private IDataAnalysisProblemData problemData;

  /// <summary>Problem data providing the dataset, target variable and training partition.</summary>
  public IDataAnalysisProblemData ProblemData {
    get { return problemData; }
    set { problemData = value; }
  }

  [Storable]
  private ISymbolicDataAnalysisExpressionTreeInterpreter interpreter;

  /// <summary>Interpreter used to evaluate the stored trees on the dataset.</summary>
  public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter {
    get { return interpreter; }
    set { interpreter = value; }
  }

  [Storable]
  private bool applyLinearScaling;

  /// <summary>Stored flag; it is not read by the quality calculations in this class.</summary>
  public bool ApplyLinearScaling {
    get { return applyLinearScaling; }
    set { applyLinearScaling = value; }
  }

  /// <summary>
  /// Restores the members that are not persisted: the background worker and, if they are missing,
  /// the solution dictionary and the range list.
  /// </summary>
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    if (slidingWindowBestSolutions == null) {
      slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
    }
    if (slidingWindowRanges == null) {
      // same ordering as the one-way "bestSolutions" conversion uses
      slidingWindowRanges = slidingWindowBestSolutions.Keys.OrderBy(x => x.Start).ToList();
    }
    if (bw == null) {
      bw = CreateBackgroundWorker();
    }
  }

  /// <summary>
  /// Quality matrix: rows are solutions, columns are windows, the last column is the training
  /// partition. Not persisted, so it is null until qualities have been calculated.
  /// </summary>
  public double[,] SlidingWindowQualities { get; set; }

  private BackgroundWorker bw;

  public enum QualityMeasures { PEARSON, MSE };

  private QualityMeasures qualityMeasure;

  /// <summary>
  /// The measure used to fill <see cref="SlidingWindowQualities"/>. Assigning a different value
  /// starts an asynchronous recalculation of the whole matrix.
  /// </summary>
  public QualityMeasures QualityMeasure {
    get { return qualityMeasure; }
    set {
      if (qualityMeasure != value) {
        qualityMeasure = value;
        CalculateQualities();
      }
    }
  }

  /// <summary>True while the background worker is recalculating the quality matrix.</summary>
  public bool QualitiesCalculationInProgress {
    get { return bw.IsBusy; }
  }

  /// <summary>Forwards the progress notifications of the background worker.</summary>
  public event ProgressChangedEventHandler QualitiesCalculationProgress {
    add { bw.ProgressChanged += value; }
    remove { bw.ProgressChanged -= value; }
  }

  /// <summary>Forwards the completion notification of the background worker.</summary>
  public event RunWorkerCompletedEventHandler QualitiesCalculationCompleted {
    add { bw.RunWorkerCompleted += value; }
    remove { bw.RunWorkerCompleted -= value; }
  }

  /// <summary>Raised from the worker's DoWork handler before the matrix is recalculated.</summary>
  public event EventHandler QualitiesCalculationStarted;

  private void OnQualitiesCalculationStarted(object sender, EventArgs e) {
    var started = QualitiesCalculationStarted;
    if (started != null) {
      started(sender, e);
    }
  }

  /// <summary>Raised after a solution was assigned through the indexer and the matrix was updated.</summary>
  public event EventHandler QualitiesUpdated;
  private void OnQualitiesUpdated(object sender, EventArgs e) {
    var updated = QualitiesUpdated;
    if (updated != null) {
      updated(sender, e);
    }
  }

  [StorableConstructor]
  protected SlidingWindowBestSolutionsCollection(StorableConstructorFlag _) : base(_) {
  }

  /// <summary>
  /// Cloning constructor. The clone gets its own range list, solution dictionary, quality matrix
  /// and background worker; the trees, the problem data and the interpreter are shared by
  /// reference with <paramref name="original"/>.
  /// </summary>
  protected SlidingWindowBestSolutionsCollection(SlidingWindowBestSolutionsCollection original, Cloner cloner)
    : base(original, cloner) {
    // Separate containers, otherwise adding to the clone would silently modify the original.
    this.slidingWindowBestSolutions = original.slidingWindowBestSolutions != null
      ? new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>(original.slidingWindowBestSolutions)
      : new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
    // The range list must exist in the clone: Add, AddSolution and the worker all dereference it.
    this.slidingWindowRanges = original.slidingWindowRanges != null
      ? new List<SlidingWindowRange>(original.slidingWindowRanges)
      : new List<SlidingWindowRange>();
    this.problemData = original.problemData;
    this.interpreter = original.interpreter;
    this.applyLinearScaling = original.ApplyLinearScaling;
    this.qualityMeasure = original.qualityMeasure;
    if (original.SlidingWindowQualities != null) {
      this.SlidingWindowQualities = (double[,])original.SlidingWindowQualities.Clone();
    }
    // Each instance needs its own worker; without it the progress properties and events throw.
    this.bw = CreateBackgroundWorker();
  }

  protected SlidingWindowBestSolutionsCollection() {
    slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
    slidingWindowRanges = new List<SlidingWindowRange>();
    qualityMeasure = QualityMeasures.PEARSON;
    bw = CreateBackgroundWorker();
  }

  // Creates the worker that runs the DoWork overload of CalculateQualities.
  private BackgroundWorker CreateBackgroundWorker() {
    var worker = new BackgroundWorker();
    worker.WorkerSupportsCancellation = true;
    worker.WorkerReportsProgress = true;
    worker.DoWork += CalculateQualities;
    return worker;
  }

  /// <summary>Returns whether a solution is stored for <paramref name="key"/>.</summary>
  public bool ContainsKey(SlidingWindowRange key) {
    return slidingWindowBestSolutions.ContainsKey(key);
  }

  /// <summary>
  /// Gets the solution stored for a window. Setting stores the solution, updates
  /// <see cref="SlidingWindowQualities"/> synchronously and raises <see cref="QualitiesUpdated"/>.
  /// </summary>
  public ISymbolicExpressionTree this[SlidingWindowRange key] {
    get { return slidingWindowBestSolutions[key]; }
    set {
      AddSolution(key, value); // this should be fast so there's no need for a background worker
      OnQualitiesUpdated(this, EventArgs.Empty);
    }
  }

  /// <summary>
  /// Stores <paramref name="solution"/> for <paramref name="range"/>, replacing an existing entry.
  /// Does not touch <see cref="SlidingWindowQualities"/>.
  /// </summary>
  public void Add(SlidingWindowRange range, ISymbolicExpressionTree solution) {
    if (!slidingWindowBestSolutions.ContainsKey(range)) {
      slidingWindowBestSolutions.Add(range, solution);
      slidingWindowRanges.Add(range);
    } else {
      slidingWindowBestSolutions[range] = solution;
    }
  }

  /// <summary>Removes all stored solutions and ranges.</summary>
  public void Clear() {
    if (slidingWindowBestSolutions != null) slidingWindowBestSolutions.Clear();
    if (slidingWindowRanges != null) slidingWindowRanges.Clear();
  }

  public abstract ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree,
    ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue);

  public abstract ISymbolicDataAnalysisSolution CreateSolution(ISymbolicDataAnalysisModel model,
    IDataAnalysisProblemData problemData);

  // Stores the solution and brings the quality matrix up to date. When a new window is appended
  // to a matrix that matches the previous state, only the new row and the new column are
  // calculated; in every other case (no matrix yet, replaced entry, size mismatch) the whole
  // matrix is recalculated.
  private void AddSolution(SlidingWindowRange range, ISymbolicExpressionTree solution) {
    bool replaced = slidingWindowBestSolutions.ContainsKey(range);
    Add(range, solution);

    int n = slidingWindowRanges.Count;
    var previous = SlidingWindowQualities;
    bool canExtend = !replaced && previous != null
                     && previous.GetLength(0) == n - 1 && previous.GetLength(1) == n;
    if (!canExtend) {
      SlidingWindowQualities = CalculateQualityMatrix();
      return;
    }

    var matrix = new double[n, n + 1];
    var newWindowRows = Enumerable.Range(range.Start, range.Size).ToList();
    for (int i = 0; i < n - 1; ++i) {
      // window columns of the existing solutions are unchanged
      for (int j = 0; j < n - 1; ++j) {
        matrix[i, j] = previous[i, j];
      }
      // the new window is inserted before the training column, which moves one to the right
      matrix[i, n - 1] = CalculateQuality(slidingWindowBestSolutions[slidingWindowRanges[i]], newWindowRows);
      matrix[i, n] = previous[i, n - 1];
    }

    // The row of the new solution is complete as calculated and must not take part in the shift.
    var rows = GetCoveredRows(slidingWindowRanges);
    var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();
    SetRow(matrix, n - 1, CalculateQualityRow(solution, slidingWindowRanges, rows, originalValues));

    SlidingWindowQualities = matrix;
  }

  // Synchronously calculates the complete quality matrix for the stored solutions.
  private double[,] CalculateQualityMatrix() {
    var ranges = slidingWindowRanges;
    var matrix = new double[ranges.Count, ranges.Count + 1];
    if (ranges.Count == 0) return matrix;

    var rows = GetCoveredRows(ranges);
    var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();
    for (int i = 0; i < ranges.Count; ++i) {
      var solution = slidingWindowBestSolutions[ranges[i]];
      SetRow(matrix, i, CalculateQualityRow(solution, ranges, rows, originalValues));
    }
    return matrix;
  }

  // Returns the consecutive dataset rows from the smallest window start up to (excluding) the
  // largest window end. Requires at least one range.
  private static List<int> GetCoveredRows(IList<SlidingWindowRange> ranges) {
    int firstRow = ranges.Min(x => x.Start);
    int lastRow = ranges.Max(x => x.End);
    return Enumerable.Range(firstRow, lastRow - firstRow).ToList();
  }

  // Calculates one matrix row: the quality of the solution on every window followed by its
  // quality on the training partition. The tree is evaluated once on all covered rows and the
  // result is sliced per window.
  private double[] CalculateQualityRow(ISymbolicExpressionTree solution, IList<SlidingWindowRange> ranges,
    List<int> rows, List<double> originalValues) {
    var qualities = new double[ranges.Count + 1];
    var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(solution, ProblemData.Dataset, rows).ToList();
    int firstRow = rows.Count > 0 ? rows[0] : 0;

    for (int j = 0; j < ranges.Count; ++j) {
      var range = ranges[j];
      // Window bounds are absolute dataset rows, the value lists start at firstRow.
      var positions = Enumerable.Range(range.Start - firstRow, range.Size).ToList();
      var estimated = positions.Select(x => estimatedValues[x]);
      var original = positions.Select(x => originalValues[x]);
      qualities[j] = CalculateQuality(estimated, original);
    }
    qualities[ranges.Count] = CalculateQuality(solution, ProblemData.TrainingIndices);
    return qualities;
  }

  private static void SetRow(double[,] matrix, int rowIndex, double[] values) {
    for (int j = 0; j < values.Length; ++j) {
      matrix[rowIndex, j] = values[j];
    }
  }

  // DoWork handler of the background worker: recalculates the whole quality matrix row by row,
  // honouring cancellation between rows and reporting the percentage of finished rows.
  private void CalculateQualities(object sender, DoWorkEventArgs e) {
    var worker = sender as BackgroundWorker;
    if (worker == null) return;
    if (worker.CancellationPending) {
      e.Cancel = true;
      return;
    }

    OnQualitiesCalculationStarted(this, EventArgs.Empty);

    var ranges = SlidingWindowRanges;
    var solutions = ranges.Select(x => SlidingWindowBestSolutions[x]).ToList();

    int nRows = solutions.Count;
    int nCols = ranges.Count + 1;

    // The matrix is published before it is filled, so it can be read while rows are completed.
    var matrix = new double[nRows, nCols];
    SlidingWindowQualities = matrix;
    if (nRows == 0) return; // nothing stored, there are no rows to evaluate

    var rows = GetCoveredRows(ranges);
    var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();

    for (int i = 0; i < nRows; ++i) {
      if (worker.CancellationPending) {
        e.Cancel = true;
        return;
      }

      SetRow(matrix, i, CalculateQualityRow(solutions[i], ranges, rows, originalValues));

      // i + 1 rows are finished at this point, so the last report is 100
      worker.ReportProgress((int)Math.Round((i + 1) * 100.0 / nRows));
    }
  }

  /// <summary>
  /// Starts the asynchronous recalculation of <see cref="SlidingWindowQualities"/>.
  /// </summary>
  /// <exception cref="InvalidOperationException">
  /// Thrown by the background worker if a calculation is already running
  /// (see <see cref="QualitiesCalculationInProgress"/>).
  /// </exception>
  public void CalculateQualities() {
    bw.RunWorkerAsync();
  }

  // Returns the target variable of regression or classification problem data;
  // other kinds of problem data are not supported.
  private string GetTargetVariable(IDataAnalysisProblemData problemData) {
    var regressionProblemData = problemData as IRegressionProblemData;
    var classificationProblemData = problemData as IClassificationProblemData;
    if (regressionProblemData != null) return regressionProblemData.TargetVariable;
    if (classificationProblemData != null) return classificationProblemData.TargetVariable;
    throw new NotSupportedException();
  }

  // Applies the selected quality measure to paired estimated and original values.
  // Returns NaN if the calculator reports an error.
  private double CalculateQuality(IEnumerable<double> estimatedValues, IEnumerable<double> originalValues) {
    var errorState = OnlineCalculatorError.None;
    double quality = 0.0;
    switch (QualityMeasure) {
      case QualityMeasures.PEARSON:
        quality = OnlinePearsonsRSquaredCalculator.Calculate(estimatedValues, originalValues, out errorState);
        break;
      case QualityMeasures.MSE:
        quality = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, originalValues, out errorState);
        break;
    }
    return errorState == OnlineCalculatorError.None ? quality : double.NaN;
  }

  // Evaluates the tree on the given dataset rows and applies the selected quality measure.
  private double CalculateQuality(ISymbolicExpressionTree tree, IEnumerable<int> rows) {
    var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, rows);
    var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows);
    return CalculateQuality(estimatedValues, originalValues);
  }
}
[10720]337
/// <summary>
/// Immutable position of a sliding window, given by a start and an end value with
/// <c>Start &lt;= End</c>. Two ranges are equal when both values are equal.
/// </summary>
[StorableType("170F5739-8D8C-4A44-9EA2-B28B35E97A3F")]
public sealed class SlidingWindowRange : IEquatable<SlidingWindowRange> {
  [Storable]
  private readonly Tuple<int, int> tuple;

  /// <summary>First value of the range.</summary>
  public int Start { get { return tuple.Item1; } }
  /// <summary>Second value of the range; never smaller than <see cref="Start"/>.</summary>
  public int End { get { return tuple.Item2; } }

  [StorableConstructor]
  private SlidingWindowRange(StorableConstructorFlag _) { }
  private SlidingWindowRange() { }

  /// <summary>Creates a range from <paramref name="start"/> to <paramref name="end"/>.</summary>
  /// <exception cref="ArgumentException">Thrown if <paramref name="start"/> is greater than <paramref name="end"/>.</exception>
  public SlidingWindowRange(int start, int end) {
    if (start > end) throw new ArgumentException("SlidingWindowRange: Start cannot be greater than End.");
    tuple = new Tuple<int, int>(start, end);
  }

  /// <summary>Value equality on start and end; a null argument is never equal.</summary>
  public bool Equals(SlidingWindowRange other) {
    if (ReferenceEquals(other, null)) return false;
    if (ReferenceEquals(this, other)) return true;
    // static object.Equals also copes with a tuple that has not been set
    return object.Equals(tuple, other.tuple);
  }

  /// <summary>Keeps the non-generic comparison consistent with <see cref="GetHashCode"/>.</summary>
  public override bool Equals(object obj) {
    return Equals(obj as SlidingWindowRange);
  }

  public override int GetHashCode() {
    return tuple != null ? tuple.GetHashCode() : 0;
  }

  /// <summary>Difference between <see cref="End"/> and <see cref="Start"/>.</summary>
  public int Size {
    get { return End - Start; }
  }
}
[10396]367}
Note: See TracBrowser for help on using the repository browser.