#region License Information
/* HeuristicLab
* Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HEAL.Attic;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
[Item("SlidingWindowBestSolutionsCollection", "An object holding a collection of the best sliding window solutions.")]
[StorableType("08DA042D-9A0E-4D7A-8ED7-AED6918D8EF3")]
public abstract class SlidingWindowBestSolutionsCollection : Item {
// Persisted list of the sliding window ranges. The index of a range in this list is the
// index of its solution's row (and of its window's column) in SlidingWindowQualities.
[Storable]
private List<SlidingWindowRange> slidingWindowRanges;
/// <summary>
/// The sliding window ranges for which a best solution is stored.
/// The list can only be replaced from within this class.
/// </summary>
public List<SlidingWindowRange> SlidingWindowRanges {
  get { return slidingWindowRanges; }
  private set { slidingWindowRanges = value; }
}
/// <summary>
/// Write-only storable property registered under the name "bestSolutions".
/// It converts a dictionary keyed by (start, end) tuples into
/// <see cref="SlidingWindowRange"/> keys and rebuilds both the range list and the
/// solution dictionary, ordered by window start.
/// </summary>
/// <remarks>
/// NOTE(review): AllowOneWay together with the explicit name suggests this exists to load
/// data persisted by an older version of the class - confirm.
/// </remarks>
[Storable(AllowOneWay = true, Name = "bestSolutions")]
private Dictionary<Tuple<int, int>, ISymbolicExpressionTree> StorableBestSolutions {
  set {
    var orderedKeys = value.Keys.OrderBy(x => x.Item1).ToList();
    slidingWindowRanges = new List<SlidingWindowRange>(orderedKeys.Count);
    slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>(orderedKeys.Count);
    foreach (var key in orderedKeys) {
      // the same range instance is used in the list and as dictionary key
      var range = new SlidingWindowRange(key.Item1, key.Item2);
      slidingWindowRanges.Add(range);
      slidingWindowBestSolutions.Add(range, value[key]);
    }
  }
}
// Persisted mapping from a sliding window range to the best solution found for it.
[Storable]
private Dictionary<SlidingWindowRange, ISymbolicExpressionTree> slidingWindowBestSolutions;
/// <summary>
/// The best solution (expression tree) stored for each sliding window range.
/// </summary>
/// <remarks>
/// The setter replaces only the dictionary; <see cref="SlidingWindowRanges"/> is not
/// updated by it.
/// </remarks>
public Dictionary<SlidingWindowRange, ISymbolicExpressionTree> SlidingWindowBestSolutions {
  get { return slidingWindowBestSolutions; }
  set { slidingWindowBestSolutions = value; }
}
[Storable]
private IDataAnalysisProblemData problemData;
/// <summary>
/// The problem data whose dataset and training indices are used when the qualities of the
/// stored solutions are calculated.
/// </summary>
public IDataAnalysisProblemData ProblemData {
get { return problemData; }
set { problemData = value; }
}
[Storable]
private ISymbolicDataAnalysisExpressionTreeInterpreter interpreter;
/// <summary>
/// The interpreter used to evaluate the stored expression trees on the dataset of
/// <see cref="ProblemData"/>.
/// </summary>
public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter {
get { return interpreter; }
set { interpreter = value; }
}
[Storable]
private bool applyLinearScaling;
/// <summary>
/// Persisted flag. Within this class it is only stored and copied; the quality
/// calculations in this class do not read it.
/// NOTE(review): presumably consumed by derived classes when models are created - confirm.
/// </summary>
public bool ApplyLinearScaling {
get { return applyLinearScaling; }
set { applyLinearScaling = value; }
}
/// <summary>
/// Persistence hook: creates the background worker if none exists yet. The worker field
/// carries no Storable attribute, so it has to be set up again for a deserialized instance.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  if (bw != null) return;

  var worker = new BackgroundWorker {
    WorkerSupportsCancellation = true,
    WorkerReportsProgress = true
  };
  worker.DoWork += CalculateQualities;
  bw = worker;
}
/// <summary>
/// Quality matrix: one row per stored solution (in the order of the range list), one column
/// per sliding window range, plus a last column for the quality on the training partition.
/// </summary>
public double[,] SlidingWindowQualities { get; set; }
// Worker that runs the full quality calculation in the background; not persisted.
private BackgroundWorker bw;
/// <summary>
/// Available quality measures: PEARSON is calculated with OnlinePearsonsRSquaredCalculator,
/// MSE with OnlineMeanSquaredErrorCalculator.
/// </summary>
public enum QualityMeasures { PEARSON, MSE };
// Currently selected quality measure; this field carries no Storable attribute.
private QualityMeasures qualityMeasure;
/// <summary>
/// The measure used to calculate the entries of <see cref="SlidingWindowQualities"/>.
/// Assigning a different value starts a recalculation of all qualities.
/// </summary>
public QualityMeasures QualityMeasure {
  get { return qualityMeasure; }
  set {
    // nothing to recalculate when the measure stays the same
    if (qualityMeasure == value) return;
    qualityMeasure = value;
    CalculateQualities();
  }
}
/// <summary>
/// True while the internal background worker reports that it is busy.
/// </summary>
public bool QualitiesCalculationInProgress {
get { return bw.IsBusy; }
}
/// <summary>
/// Progress notifications of the quality calculation. Subscriptions are forwarded to the
/// ProgressChanged event of the internal background worker.
/// </summary>
public event ProgressChangedEventHandler QualitiesCalculationProgress {
add { bw.ProgressChanged += value; }
remove { bw.ProgressChanged -= value; }
}
/// <summary>
/// Completion notification of the quality calculation. Subscriptions are forwarded to the
/// RunWorkerCompleted event of the internal background worker.
/// </summary>
public event RunWorkerCompletedEventHandler QualitiesCalculationCompleted {
add { bw.RunWorkerCompleted += value; }
remove { bw.RunWorkerCompleted -= value; }
}
/// <summary>
/// Raised by the background calculation before it starts evaluating the solutions.
/// </summary>
public event EventHandler QualitiesCalculationStarted;
/// <summary>
/// Raises <see cref="QualitiesCalculationStarted"/> if there is at least one subscriber.
/// </summary>
private void OnQualitiesCalculationStarted(object sender, EventArgs e) {
  // work on a local copy of the delegate so a concurrent unsubscribe cannot null it
  EventHandler handler = QualitiesCalculationStarted;
  if (handler == null) return;
  handler(sender, e);
}
/// <summary>
/// Raised after a solution was assigned through the indexer and the quality matrix was updated.
/// </summary>
public event EventHandler QualitiesUpdated;
/// <summary>
/// Raises <see cref="QualitiesUpdated"/> if there is at least one subscriber.
/// </summary>
private void OnQualitiesUpdated(object sender, EventArgs e) {
  // work on a local copy of the delegate so a concurrent unsubscribe cannot null it
  EventHandler handler = QualitiesUpdated;
  if (handler == null) return;
  handler(sender, e);
}
/// <summary>
/// Constructor marked for the persistence framework; it only forwards the flag to the base class.
/// </summary>
[StorableConstructor]
protected SlidingWindowBestSolutionsCollection(StorableConstructorFlag _) : base(_) {
}
/// <summary>
/// Cloning constructor.
/// </summary>
/// <remarks>
/// The clone gets its own range list, solution dictionary and quality matrix, so that
/// adding or clearing solutions on one instance does not affect the other. The trees,
/// problem data and interpreter themselves are still shared by reference. A background
/// worker is created for the clone; without it every member that uses the worker would
/// fail with a NullReferenceException.
/// </remarks>
/// <param name="original">The instance to copy from.</param>
/// <param name="cloner">The cloner, forwarded to the base class.</param>
protected SlidingWindowBestSolutionsCollection(SlidingWindowBestSolutionsCollection original, Cloner cloner)
  : base(original, cloner) {
  this.slidingWindowBestSolutions = original.slidingWindowBestSolutions == null
    ? new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>()
    : new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>(original.slidingWindowBestSolutions);
  this.slidingWindowRanges = original.slidingWindowRanges == null
    ? new List<SlidingWindowRange>()
    : new List<SlidingWindowRange>(original.slidingWindowRanges);
  this.problemData = original.problemData;
  this.interpreter = original.interpreter;
  this.applyLinearScaling = original.ApplyLinearScaling;
  this.qualityMeasure = original.qualityMeasure;
  if (original.SlidingWindowQualities != null) {
    this.SlidingWindowQualities = (double[,])original.SlidingWindowQualities.Clone();
  }

  // same worker configuration as in the default constructor
  this.bw = new BackgroundWorker {
    WorkerSupportsCancellation = true,
    WorkerReportsProgress = true
  };
  this.bw.DoWork += CalculateQualities;
}
/// <summary>
/// Creates an empty collection that uses the PEARSON quality measure and owns a background
/// worker (cancellable, progress-reporting) for the quality calculation.
/// </summary>
protected SlidingWindowBestSolutionsCollection() {
  slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
  slidingWindowRanges = new List<SlidingWindowRange>();
  qualityMeasure = QualityMeasures.PEARSON;

  bw = new BackgroundWorker {
    WorkerSupportsCancellation = true,
    WorkerReportsProgress = true
  };
  bw.DoWork += CalculateQualities;
}
/// <summary>
/// Returns whether a solution is stored for the given sliding window range.
/// </summary>
public bool ContainsKey(SlidingWindowRange key) {
return slidingWindowBestSolutions.ContainsKey(key);
}
/// <summary>
/// Gets or sets the best solution of a sliding window range. The getter reads the solution
/// dictionary directly; the setter stores the solution, updates the quality matrix
/// synchronously and then raises <see cref="QualitiesUpdated"/>.
/// </summary>
public ISymbolicExpressionTree this[SlidingWindowRange key] {
get { return slidingWindowBestSolutions[key]; }
set {
AddSolution(key, value); // this should be fast so there's no need for a background worker
OnQualitiesUpdated(this, EventArgs.Empty);
}
}
/// <summary>
/// Stores <paramref name="solution"/> for <paramref name="range"/>. An unknown range is also
/// appended to the range list; for a known range only the stored solution is replaced.
/// The quality matrix is not touched by this method.
/// </summary>
public void Add(SlidingWindowRange range, ISymbolicExpressionTree solution) {
  bool isKnownRange = slidingWindowBestSolutions.ContainsKey(range);
  slidingWindowBestSolutions[range] = solution;
  if (!isKnownRange) {
    slidingWindowRanges.Add(range);
  }
}
/// <summary>
/// Removes all stored solutions and ranges. Either collection is skipped when it is null.
/// The quality matrix is left as it is.
/// </summary>
public void Clear() {
  var solutions = slidingWindowBestSolutions;
  if (solutions != null) {
    solutions.Clear();
  }

  var ranges = slidingWindowRanges;
  if (ranges != null) {
    ranges.Clear();
  }
}
/// <summary>
/// Creates a model for the given tree and interpreter; implemented by derived classes.
/// The estimation limits default to double.MinValue and double.MaxValue.
/// </summary>
public abstract ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree,
ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue);
/// <summary>
/// Creates a solution from a model and problem data; implemented by derived classes.
/// </summary>
public abstract ISymbolicDataAnalysisSolution CreateSolution(ISymbolicDataAnalysisModel model,
IDataAnalysisProblemData problemData);
/// <summary>
/// Stores <paramref name="solution"/> for <paramref name="range"/> and updates
/// <see cref="SlidingWindowQualities"/> synchronously.
/// </summary>
/// <remarks>
/// Row i of the matrix belongs to the solution of the i-th range. The first columns hold the
/// qualities on the individual windows and the last column the quality on the training
/// partition. Entries of unchanged solutions are copied from the previous matrix when its
/// dimensions match the state before this call; otherwise they are recalculated. Only the
/// row of the added or replaced solution is always recalculated.
/// </remarks>
/// <param name="range">The sliding window the solution belongs to.</param>
/// <param name="solution">The best solution for that window.</param>
private void AddSolution(SlidingWindowRange range, ISymbolicExpressionTree solution) {
  // must be determined before Add, which inserts the range when it is unknown
  bool isNewRange = !slidingWindowBestSolutions.ContainsKey(range);
  Add(range, solution);

  var solutions = slidingWindowRanges.Select(x => slidingWindowBestSolutions[x]).ToList();
  int nRows = solutions.Count;
  int nCols = nRows + 1; // an extra column corresponding to the whole training partition
  var trainingIndices = problemData.TrainingIndices.ToList();
  var matrix = new double[nRows, nCols];

  // row of the added/replaced solution; Add appends new ranges at the end of the list
  int changedRow = isNewRange ? nRows - 1 : slidingWindowRanges.IndexOf(range);

  // the previous matrix is reusable only if it has the dimensions expected before this call
  var previous = SlidingWindowQualities;
  int previousRows = isNewRange ? nRows - 1 : nRows;
  bool reusePrevious = changedRow >= 0
    && previous != null
    && previous.GetLength(0) == previousRows
    && previous.GetLength(1) == previousRows + 1;

  var newWindowRows = Enumerable.Range(range.Start, range.Size).ToList();
  for (int i = 0; i < nRows; ++i) {
    if (i == changedRow || !reusePrevious) {
      FillQualityRow(matrix, i, solutions[i], trainingIndices);
      continue;
    }

    // window qualities of an unchanged solution stay valid
    for (int j = 0; j < previousRows; ++j) {
      matrix[i, j] = previous[i, j];
    }
    // the training quality moves to (or stays in) the last column
    matrix[i, nCols - 1] = previous[i, previousRows];
    if (isNewRange) {
      // quality of an old solution on the newly added window
      matrix[i, nCols - 2] = CalculateQuality(solutions[i], newWindowRows);
    }
  }

  // replace old matrix with new matrix
  SlidingWindowQualities = matrix;
}

// Calculates all entries of one matrix row: the quality of the tree on every sliding
// window and, in the last column, its quality on the training partition.
private void FillQualityRow(double[,] matrix, int row, ISymbolicExpressionTree tree, IEnumerable<int> trainingIndices) {
  int nWindows = slidingWindowRanges.Count;
  int firstRow = slidingWindowRanges.Min(r => r.Start);
  int lastRow = slidingWindowRanges.Max(r => r.End);

  // evaluate the tree once over the rows spanned by all windows
  var rows = Enumerable.Range(firstRow, lastRow - firstRow).ToList();
  var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, rows).ToList();
  var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();

  for (int j = 0; j < nWindows; ++j) {
    var window = slidingWindowRanges[j];
    // the value lists start at firstRow, so dataset rows are translated into list offsets
    var offsets = Enumerable.Range(window.Start - firstRow, window.Size).ToList();
    var estimated = offsets.Select(k => estimatedValues[k]);
    var original = offsets.Select(k => originalValues[k]);
    matrix[row, j] = CalculateQuality(estimated, original);
  }
  matrix[row, nWindows] = CalculateQuality(tree, trainingIndices);
}
/// <summary>
/// DoWork handler of the background worker: recalculates the complete quality matrix.
/// </summary>
/// <remarks>
/// Row i holds the qualities of the solution of the i-th range on every window, followed by
/// its quality on the training partition in the last column. Cancellation is checked before
/// the calculation starts and before every row; when it is pending, e.Cancel is set and the
/// matrix is left partially filled. Progress is reported after every row.
/// </remarks>
/// <param name="sender">The background worker; any other sender is ignored.</param>
/// <param name="e">The event arguments used to signal cancellation.</param>
private void CalculateQualities(object sender, DoWorkEventArgs e) {
  var worker = sender as BackgroundWorker;
  if (worker == null) return;
  if (worker.CancellationPending) {
    e.Cancel = true;
    return;
  }

  OnQualitiesCalculationStarted(this, EventArgs.Empty);

  var ranges = SlidingWindowRanges;
  var solutions = ranges.Select(x => SlidingWindowBestSolutions[x]).ToList();
  int nRows = solutions.Count;
  int nCols = ranges.Count + 1;
  SlidingWindowQualities = new double[nRows, nCols];
  if (nRows == 0) return; // no solutions stored, nothing to evaluate

  // evaluate every tree once over the rows spanned by all windows
  int firstRow = ranges.Min(r => r.Start);
  int lastRow = ranges.Max(r => r.End);
  var rows = Enumerable.Range(firstRow, lastRow - firstRow).ToList();
  var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();
  var trainingIndices = problemData.TrainingIndices.ToList();

  for (int i = 0; i < nRows; ++i) {
    if (worker.CancellationPending) {
      e.Cancel = true;
      return;
    }

    var solution = solutions[i];
    var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(solution, problemData.Dataset, rows).ToList();

    for (int j = 0; j < nCols; ++j) {
      double q;
      if (j == nCols - 1) {
        q = CalculateQuality(solution, trainingIndices);
      } else {
        var range = ranges[j];
        // the value lists start at firstRow, so dataset rows are translated into list offsets
        var offsets = Enumerable.Range(range.Start - firstRow, range.Size).ToList();
        var estimated = offsets.Select(k => estimatedValues[k]);
        var original = offsets.Select(k => originalValues[k]);
        q = CalculateQuality(estimated, original);
      }
      SlidingWindowQualities[i, j] = q;
    }

    // i + 1 rows are finished at this point, so the last report is 100
    worker.ReportProgress((int)Math.Round((i + 1) * 100.0 / nRows));
  }
}
/// <summary>
/// Starts the recalculation of <see cref="SlidingWindowQualities"/> on the internal
/// background worker and returns without waiting for the result.
/// </summary>
/// <remarks>
/// NOTE(review): BackgroundWorker.RunWorkerAsync is documented to throw an
/// InvalidOperationException while the worker is busy, and no IsBusy check is made here -
/// confirm that callers check <see cref="QualitiesCalculationInProgress"/> first.
/// </remarks>
public void CalculateQualities() {
bw.RunWorkerAsync();
}
/// <summary>
/// Returns the target variable of regression or classification problem data.
/// </summary>
/// <exception cref="NotSupportedException">
/// The problem data is neither regression nor classification problem data.
/// </exception>
private string GetTargetVariable(IDataAnalysisProblemData problemData) {
  // regression takes precedence when both interfaces are implemented
  var regression = problemData as IRegressionProblemData;
  if (regression != null) {
    return regression.TargetVariable;
  }

  var classification = problemData as IClassificationProblemData;
  if (classification != null) {
    return classification.TargetVariable;
  }

  throw new NotSupportedException();
}
/// <summary>
/// Calculates the quality of estimated values with respect to the original values using
/// the currently selected <see cref="QualityMeasure"/>.
/// </summary>
/// <param name="estimatedValues">The values produced by a solution.</param>
/// <param name="originalValues">The corresponding target values.</param>
/// <returns>
/// The calculated quality, or double.NaN when the calculator reports an error.
/// </returns>
private double CalculateQuality(IEnumerable<double> estimatedValues, IEnumerable<double> originalValues) {
  var errorState = OnlineCalculatorError.None;
  double quality = 0.0;
  switch (QualityMeasure) {
    case QualityMeasures.PEARSON:
      quality = OnlinePearsonsRSquaredCalculator.Calculate(estimatedValues, originalValues, out errorState);
      break;
    case QualityMeasures.MSE:
      quality = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, originalValues, out errorState);
      break;
  }
  return errorState == OnlineCalculatorError.None ? quality : double.NaN;
}
/// <summary>
/// Evaluates <paramref name="tree"/> on the given dataset rows and calculates its quality
/// with respect to the target variable using the currently selected quality measure.
/// </summary>
/// <param name="tree">The solution to evaluate.</param>
/// <param name="rows">The dataset rows to evaluate the solution on.</param>
/// <returns>
/// The calculated quality, or double.NaN when the calculator reports an error.
/// </returns>
private double CalculateQuality(ISymbolicExpressionTree tree, IEnumerable<int> rows) {
  var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, rows);
  var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows);
  // the measure selection lives in the value-based overload only
  return CalculateQuality(estimatedValues, originalValues);
}
}
/// <summary>
/// A sliding window described by a start and an end row index. <see cref="Size"/> is
/// End - Start, i.e. the window covers the rows from Start up to, but excluding, End.
/// Two ranges are equal when they have the same start and end.
/// </summary>
[StorableType("170F5739-8D8C-4A44-9EA2-B28B35E97A3F")]
public sealed class SlidingWindowRange : IEquatable<SlidingWindowRange> {
  // (start, end) pair; persisted, hence kept as a single storable field
  [Storable]
  private readonly Tuple<int, int> tuple;

  /// <summary>The first row index of the window.</summary>
  public int Start { get { return tuple.Item1; } }
  /// <summary>The end row index of the window.</summary>
  public int End { get { return tuple.Item2; } }
  /// <summary>The number of rows in the window (End - Start).</summary>
  public int Size {
    get { return End - Start; }
  }

  [StorableConstructor]
  private SlidingWindowRange(StorableConstructorFlag _) { }
  // leaves the tuple unassigned; not reachable from outside this class
  private SlidingWindowRange() { }

  /// <summary>Creates a range from <paramref name="start"/> to <paramref name="end"/>.</summary>
  /// <exception cref="ArgumentException"><paramref name="start"/> is greater than <paramref name="end"/>.</exception>
  public SlidingWindowRange(int start, int end) {
    if (start > end) throw new ArgumentException("SlidingWindowRange: Start cannot be greater than End.");
    tuple = new Tuple<int, int>(start, end);
  }

  /// <summary>Value equality on start and end; a null argument is never equal.</summary>
  public bool Equals(SlidingWindowRange other) {
    if (ReferenceEquals(other, null)) return false;
    if (ReferenceEquals(other, this)) return true;
    return object.Equals(tuple, other.tuple);
  }

  /// <summary>Keeps object equality consistent with the typed Equals and GetHashCode.</summary>
  public override bool Equals(object obj) {
    return Equals(obj as SlidingWindowRange);
  }

  public override int GetHashCode() {
    return tuple == null ? 0 : tuple.GetHashCode();
  }
}
}