Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Sliding Window GP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SlidingWindow/SlidingWindowBestSolutionsCollection.cs @ 10721

Last change on this file since 10721 was 10721, checked in by bburlacu, 9 years ago

#1837: Performance and usability improvements.

File size: 13.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.ComponentModel;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  [StorableClass]
33  [Item("SlidingWindowBestSolutionsCollection", "An object holding a collection of the best sliding window solutions.")]
34  public abstract class SlidingWindowBestSolutionsCollection : Item {
35    [Storable]
36    private List<SlidingWindowRange> slidingWindowRanges;
37
38    public List<SlidingWindowRange> SlidingWindowRanges {
39      get { return slidingWindowRanges; }
40      private set { slidingWindowRanges = value; }
41    }
42
43    [Storable(AllowOneWay = true, Name = "bestSolutions")]
44    private Dictionary<Tuple<int, int>, ISymbolicExpressionTree> StorableBestSolutions {
45      set {
46        var bestSolutions = value;
47        var ranges = bestSolutions.Keys.OrderBy(x => x.Item1).ToList();
48        slidingWindowRanges = ranges.Select(x => new SlidingWindowRange(x.Item1, x.Item2)).ToList();
49        slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
50        for (int i = 0; i < slidingWindowRanges.Count; ++i) {
51          slidingWindowBestSolutions.Add(slidingWindowRanges[i], bestSolutions[ranges[i]]);
52        }
53      }
54    }
55
56    [Storable]
57    private Dictionary<SlidingWindowRange, ISymbolicExpressionTree> slidingWindowBestSolutions;
58
59    public Dictionary<SlidingWindowRange, ISymbolicExpressionTree> SlidingWindowBestSolutions {
60      get { return slidingWindowBestSolutions; }
61      set { slidingWindowBestSolutions = value; }
62    }
63
64    [Storable]
65    private IDataAnalysisProblemData problemData;
66
67    public IDataAnalysisProblemData ProblemData {
68      get { return problemData; }
69      set { problemData = value; }
70    }
71
72    [Storable]
73    private ISymbolicDataAnalysisExpressionTreeInterpreter interpreter;
74
75    public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter {
76      get { return interpreter; }
77      set { interpreter = value; }
78    }
79
80    [Storable]
81    private bool applyLinearScaling;
82
83    public bool ApplyLinearScaling {
84      get { return applyLinearScaling; }
85      set { applyLinearScaling = value; }
86    }
87
88    [StorableHook(HookType.AfterDeserialization)]
89    private void AfterDeserialization() {
90      if (bw == null) {
91        bw = new BackgroundWorker();
92        bw.WorkerSupportsCancellation = true;
93        bw.WorkerReportsProgress = true;
94        bw.DoWork += CalculateQualities;
95      }
96    }
97
98    public double[,] SlidingWindowQualities { get; set; }
99
100    private BackgroundWorker bw;
101
102    public enum QualityMeasures { PEARSON, MSE };
103
104    private QualityMeasures qualityMeasure;
105
106    public QualityMeasures QualityMeasure {
107      get { return qualityMeasure; }
108      set {
109        if (qualityMeasure != value) {
110          qualityMeasure = value;
111          CalculateQualities();
112        }
113      }
114    }
115
116    public bool QualitiesCalculationInProgress {
117      get { return bw.IsBusy; }
118    }
119
120    public event ProgressChangedEventHandler QualitiesCalculationProgress {
121      add { bw.ProgressChanged += value; }
122      remove { bw.ProgressChanged -= value; }
123    }
124
125    public event RunWorkerCompletedEventHandler QualitiesCalculationCompleted {
126      add { bw.RunWorkerCompleted += value; }
127      remove { bw.RunWorkerCompleted -= value; }
128    }
129
130    public event EventHandler QualitiesCalculationStarted;
131
132    private void OnQualitiesCalculationStarted(object sender, EventArgs e) {
133      var started = QualitiesCalculationStarted;
134      if (started != null) {
135        started(sender, e);
136      }
137    }
138
139    public event EventHandler QualitiesUpdated;
140    private void OnQualitiesUpdated(object sender, EventArgs e) {
141      var updated = QualitiesUpdated;
142      if (updated != null) {
143        updated(sender, e);
144      }
145    }
146
147    [StorableConstructor]
148    protected SlidingWindowBestSolutionsCollection(bool deserializing)
149      : base(deserializing) {
150    }
151
152    protected SlidingWindowBestSolutionsCollection(SlidingWindowBestSolutionsCollection original, Cloner cloner)
153      : base(original, cloner) {
154      this.slidingWindowBestSolutions = original.slidingWindowBestSolutions;
155      this.problemData = original.problemData;
156      this.interpreter = original.interpreter;
157      this.applyLinearScaling = original.ApplyLinearScaling;
158    }
159
160    protected SlidingWindowBestSolutionsCollection() {
161      slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
162      slidingWindowRanges = new List<SlidingWindowRange>();
163      qualityMeasure = QualityMeasures.PEARSON;
164
165      bw = new BackgroundWorker();
166      bw.WorkerSupportsCancellation = true;
167      bw.WorkerReportsProgress = true;
168
169      bw.DoWork += CalculateQualities;
170    }
171
172    public bool ContainsKey(SlidingWindowRange key) {
173      return slidingWindowBestSolutions.ContainsKey(key);
174    }
175
176    public ISymbolicExpressionTree this[SlidingWindowRange key] {
177      get { return slidingWindowBestSolutions[key]; }
178      set {
179        AddSolution(key, value); // this should be fast so there's no need for a background worker
180        OnQualitiesUpdated(this, EventArgs.Empty);
181      }
182    }
183
184    public void Add(SlidingWindowRange range, ISymbolicExpressionTree solution) {
185      if (!slidingWindowBestSolutions.ContainsKey(range)) {
186        slidingWindowBestSolutions.Add(range, solution);
187        slidingWindowRanges.Add(range);
188      } else {
189        slidingWindowBestSolutions[range] = solution;
190      }
191    }
192
193    public void Clear() {
194      if (slidingWindowBestSolutions != null) slidingWindowBestSolutions.Clear();
195      if (slidingWindowRanges != null) slidingWindowRanges.Clear();
196    }
197
198    public abstract ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree,
199      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
200      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue);
201
202    public abstract ISymbolicDataAnalysisSolution CreateSolution(ISymbolicDataAnalysisModel model,
203      IDataAnalysisProblemData problemData);
204
205    private void AddSolution(SlidingWindowRange range, ISymbolicExpressionTree solution) {
206      Add(range, solution);
207
208      var solutions = slidingWindowRanges.Select(x => slidingWindowBestSolutions[x]).ToList();
209
210      var nRows = solutions.Count;
211      var nCols = nRows + 1; // an extra column corresponding to the whole trainig partition
212
213      var trainingIndices = problemData.TrainingIndices.ToList();
214      var matrix = new double[nRows, nCols];
215
216      // copy old qualities into the new matrix
217      for (int i = 0; i < nRows - 1; ++i) {
218        for (int j = 0; j < nCols - 1; ++j) {
219          matrix[i, j] = SlidingWindowQualities[i, j];
220        }
221      }
222      // copy qualities of new solution into the new matrix
223      var rows = Enumerable.Range(slidingWindowRanges.First().Start, slidingWindowRanges.Last().End - slidingWindowRanges.First().Start).ToList();
224      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(solution, problemData.Dataset, rows).ToList();
225      var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();
226      for (int i = 0; i < nCols; ++i) {
227        if (i == nCols - 1) {
228          matrix[nRows - 1, i] = CalculateQuality(solution, trainingIndices);
229        } else {
230          var indices = Enumerable.Range(slidingWindowRanges[i].Start, slidingWindowRanges[i].Size).ToList();
231          var estimated = indices.Select(x => estimatedValues[x]);
232          var original = indices.Select(x => originalValues[x]);
233          matrix[nRows - 1, i] = CalculateQuality(estimated, original);
234        }
235      }
236      // shift old training qualities one column to the right
237      rows = Enumerable.Range(range.Start, range.Size).ToList();
238      for (int i = 0; i < nRows; ++i) {
239        matrix[i, nCols - 1] = matrix[i, nCols - 2];
240        matrix[i, nCols - 2] = CalculateQuality(solutions[i], rows);
241      }
242      // replace old matrix with new matrix
243      SlidingWindowQualities = matrix;
244    }
245
246    private void CalculateQualities(object sender, DoWorkEventArgs e) {
247      var worker = sender as BackgroundWorker;
248      if (worker == null) return;
249      if (worker.CancellationPending) {
250        e.Cancel = true;
251        return;
252      }
253
254      OnQualitiesCalculationStarted(this, EventArgs.Empty);
255
256      var ranges = SlidingWindowRanges;
257      var solutions = ranges.Select(x => SlidingWindowBestSolutions[x]).ToList();
258
259      int nRows = solutions.Count;
260      int nCols = ranges.Count + 1;
261
262      SlidingWindowQualities = new double[nRows, nCols];
263      var rows = Enumerable.Range(ranges.First().Start, ranges.Last().End - ranges.First().Start).ToList();
264      var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();
265
266      for (int i = 0; i < nRows; ++i) {
267        if (worker.CancellationPending) {
268          e.Cancel = true;
269          return;
270        }
271
272        var solution = solutions[i];
273        var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(solution, problemData.Dataset, rows).ToList();
274
275        for (int j = 0; j < nCols; ++j) {
276          double q;
277          if (j == nCols - 1) {
278            q = CalculateQuality(solution, problemData.TrainingIndices);
279          } else {
280            var range = ranges[j];
281            var indices = Enumerable.Range(range.Start, range.Size).ToList();
282            var estimated = indices.Select(x => estimatedValues[x]);
283            var original = indices.Select(x => originalValues[x]);
284
285            q = CalculateQuality(estimated, original);
286          }
287
288          SlidingWindowQualities[i, j] = q;
289        }
290
291        worker.ReportProgress((int)Math.Round(i * 100.0 / nRows));
292      }
293    }
294
295    public void CalculateQualities() {
296      bw.RunWorkerAsync();
297    }
298
299    private string GetTargetVariable(IDataAnalysisProblemData problemData) {
300      var regressionProblemData = problemData as IRegressionProblemData;
301      var classificationProblemData = problemData as IClassificationProblemData;
302      if (regressionProblemData != null) return regressionProblemData.TargetVariable;
303      if (classificationProblemData != null) return classificationProblemData.TargetVariable;
304      throw new NotSupportedException();
305    }
306
307    private double CalculateQuality(IEnumerable<double> estimatedValues, IEnumerable<double> originalValues) {
308      var errorState = OnlineCalculatorError.None;
309      double quality = 0.0;
310      switch (QualityMeasure) {
311        case QualityMeasures.PEARSON:
312          quality = OnlinePearsonsRSquaredCalculator.Calculate(estimatedValues, originalValues, out errorState);
313          break;
314        case QualityMeasures.MSE:
315          quality = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, originalValues, out errorState);
316          break;
317      }
318      return errorState == OnlineCalculatorError.None ? quality : double.NaN;
319    }
320
321    private double CalculateQuality(ISymbolicExpressionTree tree, IEnumerable<int> rows) {
322      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, rows);
323      var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows);
324      double quality = 0;
325      var errorState = new OnlineCalculatorError();
326      switch (QualityMeasure) {
327        case QualityMeasures.PEARSON:
328          quality = OnlinePearsonsRSquaredCalculator.Calculate(estimatedValues, originalValues, out errorState);
329          break;
330        case QualityMeasures.MSE:
331          quality = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, originalValues, out errorState);
332          break;
333      }
334      return errorState == OnlineCalculatorError.None ? quality : double.NaN;
335    }
336  }
337
338  public class SlidingWindowRange : IEquatable<SlidingWindowRange> {
339    private readonly Tuple<int, int> tuple;
340
341    public int Start { get { return tuple.Item1; } }
342
343    public int End { get { return tuple.Item2; } }
344
345    public SlidingWindowRange(int start, int end) {
346      if (start > end) throw new ArgumentException("SlidingWindowRange: Start cannot be greater than End.");
347      tuple = new Tuple<int, int>(start, end);
348    }
349
350    public bool Equals(SlidingWindowRange other) {
351      return tuple.Equals(other.tuple);
352    }
353
354    public override int GetHashCode() {
355      return tuple.GetHashCode();
356    }
357
358    public int Size {
359      get { return End - Start; }
360    }
361  }
362}
Note: See TracBrowser for help on using the repository browser.