Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Sliding Window GP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SlidingWindow/SlidingWindowBestSolutionsCollection.cs @ 12062

Last change on this file since 12062 was 10829, checked in by bburlacu, 11 years ago

#1837: Sealed SlidingWindowRange class.

File size: 13.7 KB
RevLine 
[10396]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[10720]22using System;
[10396]23using System.Collections.Generic;
[10720]24using System.ComponentModel;
25using System.Linq;
[10396]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
[StorableClass]
[Item("SlidingWindowBestSolutionsCollection", "An object holding a collection of the best sliding window solutions.")]
public abstract class SlidingWindowBestSolutionsCollection : Item {
  // Ranges in insertion order. Row i and column i of SlidingWindowQualities refer to slidingWindowRanges[i].
  [Storable]
  private List<SlidingWindowRange> slidingWindowRanges;

  public List<SlidingWindowRange> SlidingWindowRanges {
    get { return slidingWindowRanges; }
    private set { slidingWindowRanges = value; }
  }

  // One-way backwards compatibility: older versions stored the best solutions under the name
  // "bestSolutions", keyed by Tuple<int, int>(start, end). The ranges are rebuilt ordered by start row.
  [Storable(AllowOneWay = true, Name = "bestSolutions")]
  private Dictionary<Tuple<int, int>, ISymbolicExpressionTree> StorableBestSolutions {
    set {
      var bestSolutions = value;
      var ranges = bestSolutions.Keys.OrderBy(x => x.Item1).ToList();
      slidingWindowRanges = ranges.Select(x => new SlidingWindowRange(x.Item1, x.Item2)).ToList();
      slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
      for (int i = 0; i < slidingWindowRanges.Count; ++i) {
        slidingWindowBestSolutions.Add(slidingWindowRanges[i], bestSolutions[ranges[i]]);
      }
    }
  }

  // Best solution (tree) found for each sliding window range.
  [Storable]
  private Dictionary<SlidingWindowRange, ISymbolicExpressionTree> slidingWindowBestSolutions;

  public Dictionary<SlidingWindowRange, ISymbolicExpressionTree> SlidingWindowBestSolutions {
    get { return slidingWindowBestSolutions; }
    set { slidingWindowBestSolutions = value; }
  }

  // Supplies the dataset, the target variable and the training partition used by the quality calculations.
  [Storable]
  private IDataAnalysisProblemData problemData;

  public IDataAnalysisProblemData ProblemData {
    get { return problemData; }
    set { problemData = value; }
  }

  // Used to evaluate the stored trees on dataset rows.
  [Storable]
  private ISymbolicDataAnalysisExpressionTreeInterpreter interpreter;

  public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter {
    get { return interpreter; }
    set { interpreter = value; }
  }

  // NOTE(review): persisted and cloned, but not read by the quality calculations in this class.
  [Storable]
  private bool applyLinearScaling;

  public bool ApplyLinearScaling {
    get { return applyLinearScaling; }
    set { applyLinearScaling = value; }
  }

  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    // The background worker is not persisted; recreate it after loading.
    if (bw == null) InitializeBackgroundWorker();
  }

  /// <summary>
  /// Quality matrix: element [i, j] is the quality of the best solution of window i evaluated on
  /// window j; the last column holds its quality on the whole training partition.
  /// </summary>
  public double[,] SlidingWindowQualities { get; set; }

  private BackgroundWorker bw;

  public enum QualityMeasures { PEARSON, MSE };

  private QualityMeasures qualityMeasure;

  /// <summary>
  /// The measure used for all qualities. Changing it starts an asynchronous recalculation;
  /// BackgroundWorker.RunWorkerAsync throws InvalidOperationException if a calculation is already running.
  /// </summary>
  public QualityMeasures QualityMeasure {
    get { return qualityMeasure; }
    set {
      if (qualityMeasure != value) {
        qualityMeasure = value;
        CalculateQualities();
      }
    }
  }

  public bool QualitiesCalculationInProgress {
    get { return bw.IsBusy; }
  }

  public event ProgressChangedEventHandler QualitiesCalculationProgress {
    add { bw.ProgressChanged += value; }
    remove { bw.ProgressChanged -= value; }
  }

  public event RunWorkerCompletedEventHandler QualitiesCalculationCompleted {
    add { bw.RunWorkerCompleted += value; }
    remove { bw.RunWorkerCompleted -= value; }
  }

  // Raised from the worker thread when an asynchronous calculation starts.
  public event EventHandler QualitiesCalculationStarted;

  private void OnQualitiesCalculationStarted(object sender, EventArgs e) {
    var started = QualitiesCalculationStarted;
    if (started != null) {
      started(sender, e);
    }
  }

  // Raised after a solution was set through the indexer and the matrix was updated synchronously.
  public event EventHandler QualitiesUpdated;
  private void OnQualitiesUpdated(object sender, EventArgs e) {
    var updated = QualitiesUpdated;
    if (updated != null) {
      updated(sender, e);
    }
  }

  [StorableConstructor]
  protected SlidingWindowBestSolutionsCollection(bool deserializing)
    : base(deserializing) {
  }

  protected SlidingWindowBestSolutionsCollection(SlidingWindowBestSolutionsCollection original, Cloner cloner)
    : base(original, cloner) {
    // Ranges are immutable and may be shared, but the containers must not be: otherwise adding a
    // solution to the clone would silently modify the original.
    this.slidingWindowRanges = original.slidingWindowRanges != null
      ? new List<SlidingWindowRange>(original.slidingWindowRanges)
      : original.slidingWindowBestSolutions.Keys.OrderBy(x => x.Start).ToList();
    this.slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
    foreach (var pair in original.slidingWindowBestSolutions) {
      this.slidingWindowBestSolutions.Add(pair.Key, cloner.Clone(pair.Value));
    }
    this.problemData = original.problemData;
    this.interpreter = original.interpreter;
    this.applyLinearScaling = original.applyLinearScaling;
    this.qualityMeasure = original.qualityMeasure;
    if (original.SlidingWindowQualities != null) {
      this.SlidingWindowQualities = (double[,])original.SlidingWindowQualities.Clone();
    }
    // Every instance needs its own worker; without it the clone throws on any calculation request.
    InitializeBackgroundWorker();
  }

  protected SlidingWindowBestSolutionsCollection() {
    slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
    slidingWindowRanges = new List<SlidingWindowRange>();
    qualityMeasure = QualityMeasures.PEARSON;
    InitializeBackgroundWorker();
  }

  // Creates the worker that runs CalculateQualities(object, DoWorkEventArgs) asynchronously.
  private void InitializeBackgroundWorker() {
    bw = new BackgroundWorker();
    bw.WorkerSupportsCancellation = true;
    bw.WorkerReportsProgress = true;
    bw.DoWork += CalculateQualities;
  }

  public bool ContainsKey(SlidingWindowRange key) {
    return slidingWindowBestSolutions.ContainsKey(key);
  }

  /// <summary>
  /// Gets the best solution of a range, or sets it and synchronously updates SlidingWindowQualities.
  /// </summary>
  public ISymbolicExpressionTree this[SlidingWindowRange key] {
    get { return slidingWindowBestSolutions[key]; }
    set {
      // Appending a new range is updated incrementally, so no background worker is used here.
      AddSolution(key, value);
      OnQualitiesUpdated(this, EventArgs.Empty);
    }
  }

  /// <summary>
  /// Adds or replaces the solution for a range without touching SlidingWindowQualities.
  /// New ranges are appended to SlidingWindowRanges.
  /// </summary>
  public void Add(SlidingWindowRange range, ISymbolicExpressionTree solution) {
    if (!slidingWindowBestSolutions.ContainsKey(range)) {
      slidingWindowBestSolutions.Add(range, solution);
      slidingWindowRanges.Add(range);
    } else {
      slidingWindowBestSolutions[range] = solution;
    }
  }

  public void Clear() {
    if (slidingWindowBestSolutions != null) slidingWindowBestSolutions.Clear();
    if (slidingWindowRanges != null) slidingWindowRanges.Clear();
  }

  public abstract ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree,
    ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue);

  public abstract ISymbolicDataAnalysisSolution CreateSolution(ISymbolicDataAnalysisModel model,
    IDataAnalysisProblemData problemData);

  // Stores the solution and updates SlidingWindowQualities. When a new range is appended and the
  // current matrix still describes the previous ranges, only the new row and the new window column
  // are computed. Otherwise (a range was replaced, or the matrix is missing or stale) the complete
  // matrix is recomputed.
  private void AddSolution(SlidingWindowRange range, ISymbolicExpressionTree solution) {
    bool isNewRange = !slidingWindowBestSolutions.ContainsKey(range);
    Add(range, solution);

    var ranges = slidingWindowRanges;
    int n = ranges.Count;
    var oldQualities = SlidingWindowQualities;

    if (!isNewRange || oldQualities == null || oldQualities.GetLength(0) != n - 1 || oldQualities.GetLength(1) != n) {
      SlidingWindowQualities = CalculateQualityMatrix(ranges);
      return;
    }

    var coveringRows = GetCoveringRows(ranges);
    var originalValues = GetTargetValues(coveringRows);
    var newWindowRows = Enumerable.Range(range.Start, range.Size).ToList();

    // New layout: columns 0..n-1 are the windows (n-1 being the new one), column n is the training partition.
    var matrix = new double[n, n + 1];
    for (int i = 0; i < n - 1; ++i) {
      for (int j = 0; j < n - 1; ++j) {
        matrix[i, j] = oldQualities[i, j];
      }
      matrix[i, n - 1] = CalculateQuality(slidingWindowBestSolutions[ranges[i]], newWindowRows);
      matrix[i, n] = oldQualities[i, n - 1]; // the old training column moves one to the right
    }
    // The appended range is always last in slidingWindowRanges, so its row is the last row.
    var newRow = CalculateQualityRow(solution, ranges, coveringRows, originalValues);
    for (int j = 0; j <= n; ++j) {
      matrix[n - 1, j] = newRow[j];
    }

    SlidingWindowQualities = matrix;
  }

  // DoWork handler: recomputes the whole quality matrix, honouring cancellation and reporting progress.
  private void CalculateQualities(object sender, DoWorkEventArgs e) {
    var worker = sender as BackgroundWorker;
    if (worker == null) return;
    if (worker.CancellationPending) {
      e.Cancel = true;
      return;
    }

    OnQualitiesCalculationStarted(this, EventArgs.Empty);

    // Snapshot the ranges so that the matrix layout cannot change while the calculation runs.
    var ranges = SlidingWindowRanges.ToList();
    var solutions = ranges.Select(x => SlidingWindowBestSolutions[x]).ToList();

    int nRows = solutions.Count;
    int nCols = ranges.Count + 1;

    if (nRows == 0) {
      SlidingWindowQualities = new double[nRows, nCols];
      return;
    }

    var coveringRows = GetCoveringRows(ranges);
    var originalValues = GetTargetValues(coveringRows);

    // Fill a local matrix and publish it only once complete, so readers never see a half-filled matrix.
    var qualities = new double[nRows, nCols];
    for (int i = 0; i < nRows; ++i) {
      if (worker.CancellationPending) {
        e.Cancel = true;
        return;
      }

      var row = CalculateQualityRow(solutions[i], ranges, coveringRows, originalValues);
      for (int j = 0; j < nCols; ++j) {
        qualities[i, j] = row[j];
      }

      worker.ReportProgress((int)Math.Round((i + 1) * 100.0 / nRows));
    }
    SlidingWindowQualities = qualities;
  }

  /// <summary>
  /// Starts an asynchronous recalculation of SlidingWindowQualities.
  /// BackgroundWorker.RunWorkerAsync throws InvalidOperationException if a calculation is already running.
  /// </summary>
  public void CalculateQualities() {
    bw.RunWorkerAsync();
  }

  // Synchronously computes the complete quality matrix for the given ranges.
  private double[,] CalculateQualityMatrix(IList<SlidingWindowRange> ranges) {
    var coveringRows = GetCoveringRows(ranges);
    var originalValues = GetTargetValues(coveringRows);
    var matrix = new double[ranges.Count, ranges.Count + 1];
    for (int i = 0; i < ranges.Count; ++i) {
      var row = CalculateQualityRow(slidingWindowBestSolutions[ranges[i]], ranges, coveringRows, originalValues);
      for (int j = 0; j < row.Length; ++j) {
        matrix[i, j] = row[j];
      }
    }
    return matrix;
  }

  // Returns the contiguous rows [min Start, max End) that cover every range; empty if there are no ranges.
  private static List<int> GetCoveringRows(IList<SlidingWindowRange> ranges) {
    if (ranges.Count == 0) return new List<int>();
    int start = ranges.Min(r => r.Start);
    int end = ranges.Max(r => r.End);
    return Enumerable.Range(start, end - start).ToList();
  }

  // Target variable values of the problem data for the given rows.
  private List<double> GetTargetValues(IEnumerable<int> rows) {
    return ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();
  }

  // Evaluates the tree once on coveringRows and returns its quality on each of the ranges, followed
  // by its quality on the training partition. originalValues must be aligned with coveringRows.
  private double[] CalculateQualityRow(ISymbolicExpressionTree tree, IList<SlidingWindowRange> ranges,
    List<int> coveringRows, List<double> originalValues) {
    int offset = coveringRows.Count > 0 ? coveringRows[0] : 0;
    var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, coveringRows).ToList();

    var row = new double[ranges.Count + 1];
    for (int j = 0; j < ranges.Count; ++j) {
      // Ranges hold absolute dataset row indices, whereas both value lists start at 'offset'.
      var positions = Enumerable.Range(ranges[j].Start - offset, ranges[j].Size).ToList();
      var estimated = positions.Select(p => estimatedValues[p]);
      var original = positions.Select(p => originalValues[p]);
      row[j] = CalculateQuality(estimated, original);
    }
    row[ranges.Count] = CalculateQuality(tree, ProblemData.TrainingIndices);
    return row;
  }

  private string GetTargetVariable(IDataAnalysisProblemData problemData) {
    var regressionProblemData = problemData as IRegressionProblemData;
    var classificationProblemData = problemData as IClassificationProblemData;
    if (regressionProblemData != null) return regressionProblemData.TargetVariable;
    if (classificationProblemData != null) return classificationProblemData.TargetVariable;
    throw new NotSupportedException();
  }

  // Quality of estimated vs. original values according to QualityMeasure; NaN if the calculator reports an error.
  private double CalculateQuality(IEnumerable<double> estimatedValues, IEnumerable<double> originalValues) {
    var errorState = OnlineCalculatorError.None;
    double quality = 0.0;
    switch (QualityMeasure) {
      case QualityMeasures.PEARSON:
        quality = OnlinePearsonsRSquaredCalculator.Calculate(estimatedValues, originalValues, out errorState);
        break;
      case QualityMeasures.MSE:
        quality = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, originalValues, out errorState);
        break;
    }
    return errorState == OnlineCalculatorError.None ? quality : double.NaN;
  }

  // Quality of the tree on the given absolute dataset rows.
  private double CalculateQuality(ISymbolicExpressionTree tree, IEnumerable<int> rows) {
    var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, rows);
    var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows);
    return CalculateQuality(estimatedValues, originalValues);
  }
}
[10720]337
/// <summary>
/// An immutable range of dataset rows [Start, End) with value equality, used as a dictionary key.
/// </summary>
[StorableClass]
public sealed class SlidingWindowRange : IEquatable<SlidingWindowRange> {
  // Persisted as a single (Start, End) tuple; the field name is part of the storable format.
  [Storable]
  private readonly Tuple<int, int> tuple;

  /// <summary>First row of the range (inclusive).</summary>
  public int Start { get { return tuple.Item1; } }
  /// <summary>End row of the range (exclusive; rows are Start .. End - 1).</summary>
  public int End { get { return tuple.Item2; } }

  [StorableConstructor]
  private SlidingWindowRange(bool deserializable) { }
  // Leaves the tuple unset; kept for compatibility.
  private SlidingWindowRange() { }

  public SlidingWindowRange(int start, int end) {
    if (start > end) throw new ArgumentException("SlidingWindowRange: Start cannot be greater than End.");
    tuple = new Tuple<int, int>(start, end);
  }

  /// <summary>Two ranges are equal when their Start and End are equal; a null argument is never equal.</summary>
  public bool Equals(SlidingWindowRange other) {
    if (ReferenceEquals(other, null)) return false;
    if (ReferenceEquals(this, other)) return true;
    return object.Equals(tuple, other.tuple);
  }

  // Keeps object.Equals consistent with the value-based GetHashCode below.
  public override bool Equals(object obj) {
    return Equals(obj as SlidingWindowRange);
  }

  public override int GetHashCode() {
    return tuple.GetHashCode();
  }

  /// <summary>Number of rows in the range (End - Start).</summary>
  public int Size {
    get { return End - Start; }
  }
}
[10396]367}
Note: See TracBrowser for help on using the repository browser.