#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HEAL.Attic;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Holds the best solution found for each sliding window, together with a matrix of the
  /// qualities of every stored solution on every window and on the whole training partition.
  /// </summary>
  /// <remarks>
  /// <see cref="SlidingWindowQualities"/> has one row per entry of <see cref="SlidingWindowRanges"/>
  /// (in list order) and <c>SlidingWindowRanges.Count + 1</c> columns: column <c>j</c> holds the
  /// quality on <c>SlidingWindowRanges[j]</c>, the last column the quality on the training partition.
  /// </remarks>
  [Item("SlidingWindowBestSolutionsCollection", "An object holding a collection of the best sliding window solutions.")]
  [StorableType("08DA042D-9A0E-4D7A-8ED7-AED6918D8EF3")]
  public abstract class SlidingWindowBestSolutionsCollection : Item {
    [Storable]
    private List<SlidingWindowRange> slidingWindowRanges;
    /// <summary>The sliding windows; their order defines the row and column order of <see cref="SlidingWindowQualities"/>.</summary>
    public List<SlidingWindowRange> SlidingWindowRanges {
      get { return slidingWindowRanges; }
      private set { slidingWindowRanges = value; }
    }

    // Backwards compatibility: older files stored the solutions keyed by (start, end) tuples.
    // The setter converts them into SlidingWindowRange keys ordered by start.
    [Storable(AllowOneWay = true, Name = "bestSolutions")]
    private Dictionary<Tuple<int, int>, ISymbolicExpressionTree> StorableBestSolutions {
      set {
        var bestSolutions = value;
        var ranges = bestSolutions.Keys.OrderBy(x => x.Item1).ToList();
        slidingWindowRanges = ranges.Select(x => new SlidingWindowRange(x.Item1, x.Item2)).ToList();
        slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
        for (int i = 0; i < slidingWindowRanges.Count; ++i) {
          slidingWindowBestSolutions.Add(slidingWindowRanges[i], bestSolutions[ranges[i]]);
        }
      }
    }

    [Storable]
    private Dictionary<SlidingWindowRange, ISymbolicExpressionTree> slidingWindowBestSolutions;
    public Dictionary<SlidingWindowRange, ISymbolicExpressionTree> SlidingWindowBestSolutions {
      get { return slidingWindowBestSolutions; }
      set { slidingWindowBestSolutions = value; }
    }

    [Storable]
    private IDataAnalysisProblemData problemData;
    public IDataAnalysisProblemData ProblemData {
      get { return problemData; }
      set { problemData = value; }
    }

    [Storable]
    private ISymbolicDataAnalysisExpressionTreeInterpreter interpreter;
    public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter {
      get { return interpreter; }
      set { interpreter = value; }
    }

    [Storable]
    private bool applyLinearScaling;
    public bool ApplyLinearScaling {
      get { return applyLinearScaling; }
      set { applyLinearScaling = value; }
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Guard against files that contain neither the old nor the new storage format of the solutions.
      if (slidingWindowBestSolutions == null)
        slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
      if (slidingWindowRanges == null)
        slidingWindowRanges = slidingWindowBestSolutions.Keys.OrderBy(x => x.Start).ToList();
      // The background worker is not persisted and has to be recreated.
      if (bw == null) InitializeBackgroundWorker();
    }

    /// <summary>Quality matrix (see the class remarks for its layout). Not persisted.</summary>
    public double[,] SlidingWindowQualities { get; set; }

    private BackgroundWorker bw;

    public enum QualityMeasures { PEARSON, MSE };

    private QualityMeasures qualityMeasure;
    /// <summary>
    /// The measure used for all quality calculations. Changing it starts an asynchronous
    /// recalculation through <see cref="CalculateQualities()"/>; note that
    /// <see cref="BackgroundWorker.RunWorkerAsync()"/> throws if a calculation is still running.
    /// </summary>
    public QualityMeasures QualityMeasure {
      get { return qualityMeasure; }
      set {
        if (qualityMeasure != value) {
          qualityMeasure = value;
          CalculateQualities();
        }
      }
    }

    public bool QualitiesCalculationInProgress {
      get { return bw.IsBusy; }
    }

    public event ProgressChangedEventHandler QualitiesCalculationProgress {
      add { bw.ProgressChanged += value; }
      remove { bw.ProgressChanged -= value; }
    }

    public event RunWorkerCompletedEventHandler QualitiesCalculationCompleted {
      add { bw.RunWorkerCompleted += value; }
      remove { bw.RunWorkerCompleted -= value; }
    }

    public event EventHandler QualitiesCalculationStarted;
    private void OnQualitiesCalculationStarted(object sender, EventArgs e) {
      var started = QualitiesCalculationStarted;
      if (started != null) {
        started(sender, e);
      }
    }

    public event EventHandler QualitiesUpdated;
    private void OnQualitiesUpdated(object sender, EventArgs e) {
      var updated = QualitiesUpdated;
      if (updated != null) {
        updated(sender, e);
      }
    }

    [StorableConstructor]
    protected SlidingWindowBestSolutionsCollection(StorableConstructorFlag _) : base(_) { }

    protected SlidingWindowBestSolutionsCollection(SlidingWindowBestSolutionsCollection original, Cloner cloner)
      : base(original, cloner) {
      // Copy the containers so that modifying the clone cannot desynchronize the original;
      // the trees, problem data and interpreter are shared by reference as before.
      this.slidingWindowBestSolutions = original.slidingWindowBestSolutions != null
        ? new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>(original.slidingWindowBestSolutions)
        : new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
      this.slidingWindowRanges = original.slidingWindowRanges != null
        ? new List<SlidingWindowRange>(original.slidingWindowRanges)
        : new List<SlidingWindowRange>();
      this.problemData = original.problemData;
      this.interpreter = original.interpreter;
      this.applyLinearScaling = original.ApplyLinearScaling;
      this.qualityMeasure = original.qualityMeasure;
      if (original.SlidingWindowQualities != null)
        this.SlidingWindowQualities = (double[,])original.SlidingWindowQualities.Clone();
      // Without a worker every event accessor and CalculateQualities() would throw on the clone.
      InitializeBackgroundWorker();
    }

    protected SlidingWindowBestSolutionsCollection() {
      slidingWindowBestSolutions = new Dictionary<SlidingWindowRange, ISymbolicExpressionTree>();
      slidingWindowRanges = new List<SlidingWindowRange>();
      qualityMeasure = QualityMeasures.PEARSON;
      InitializeBackgroundWorker();
    }

    // Creates the worker that runs the full quality calculation (cancellable, reports progress).
    private void InitializeBackgroundWorker() {
      bw = new BackgroundWorker();
      bw.WorkerSupportsCancellation = true;
      bw.WorkerReportsProgress = true;
      bw.DoWork += CalculateQualities;
    }

    public bool ContainsKey(SlidingWindowRange key) {
      return slidingWindowBestSolutions.ContainsKey(key);
    }

    /// <summary>
    /// Gets the best solution of a window, or stores one and synchronously updates
    /// <see cref="SlidingWindowQualities"/>, then raises <see cref="QualitiesUpdated"/>.
    /// </summary>
    public ISymbolicExpressionTree this[SlidingWindowRange key] {
      get { return slidingWindowBestSolutions[key]; }
      set {
        AddSolution(key, value); // this should be fast so there's no need for a background worker
        OnQualitiesUpdated(this, EventArgs.Empty);
      }
    }

    /// <summary>
    /// Stores <paramref name="solution"/> for <paramref name="range"/> without touching the quality matrix.
    /// A new window is appended to <see cref="SlidingWindowRanges"/>; an existing one gets its solution replaced.
    /// </summary>
    public void Add(SlidingWindowRange range, ISymbolicExpressionTree solution) {
      if (!slidingWindowBestSolutions.ContainsKey(range)) {
        slidingWindowBestSolutions.Add(range, solution);
        slidingWindowRanges.Add(range);
      } else {
        slidingWindowBestSolutions[range] = solution;
      }
    }

    public void Clear() {
      if (slidingWindowBestSolutions != null) slidingWindowBestSolutions.Clear();
      if (slidingWindowRanges != null) slidingWindowRanges.Clear();
    }

    public abstract ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue);

    public abstract ISymbolicDataAnalysisSolution CreateSolution(ISymbolicDataAnalysisModel model, IDataAnalysisProblemData problemData);

    /// <summary>
    /// Stores <paramref name="solution"/> for <paramref name="range"/> and updates
    /// <see cref="SlidingWindowQualities"/> incrementally.
    /// </summary>
    /// <remarks>
    /// Rows of unchanged solutions are reused from the previous matrix. Only the row of the added or
    /// replaced solution, and for a new window that window's column, are computed. If the previous
    /// matrix is missing or does not have the expected shape, all rows are recomputed.
    /// </remarks>
    private void AddSolution(SlidingWindowRange range, ISymbolicExpressionTree solution) {
      bool isNewRange = !slidingWindowBestSolutions.ContainsKey(range);
      Add(range, solution);

      var ranges = slidingWindowRanges;
      var solutions = ranges.Select(x => slidingWindowBestSolutions[x]).ToList();
      int nRows = solutions.Count;
      int nCols = nRows + 1; // an extra column corresponding to the whole training partition
      int targetRow = ranges.IndexOf(range);

      // The previous matrix is only reusable if it matches the state before this call.
      var oldMatrix = SlidingWindowQualities;
      int nOldRows = isNewRange ? nRows - 1 : nRows;
      int oldTrainingCol = nOldRows; // the old matrix has nOldRows + 1 columns
      bool canReuse = oldMatrix != null
                      && oldMatrix.GetLength(0) == nOldRows
                      && oldMatrix.GetLength(1) == nOldRows + 1;

      var rows = GetCoveredRows(ranges);
      var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();
      var newWindowRows = Enumerable.Range(range.Start, range.Size).ToList();

      var matrix = new double[nRows, nCols];
      for (int i = 0; i < nRows; ++i) {
        if (!canReuse || i == targetRow) {
          var qualities = CalculateQualityRow(solutions[i], ranges, rows, originalValues);
          for (int j = 0; j < nCols; ++j) matrix[i, j] = qualities[j];
          continue;
        }
        // Reuse the window qualities; the training quality always goes to the last column.
        for (int j = 0; j < oldTrainingCol; ++j) matrix[i, j] = oldMatrix[i, j];
        matrix[i, nCols - 1] = oldMatrix[i, oldTrainingCol];
        // A new window adds column nCols - 2, which the previous matrix did not contain.
        if (isNewRange) matrix[i, nCols - 2] = CalculateQuality(solutions[i], newWindowRows);
      }

      // replace old matrix with new matrix
      SlidingWindowQualities = matrix;
    }

    // DoWork handler of the background worker: recomputes the complete quality matrix,
    // honouring cancellation between rows and reporting progress per finished row.
    private void CalculateQualities(object sender, DoWorkEventArgs e) {
      var worker = sender as BackgroundWorker;
      if (worker == null) return;
      if (worker.CancellationPending) {
        e.Cancel = true;
        return;
      }
      OnQualitiesCalculationStarted(this, EventArgs.Empty);

      var ranges = SlidingWindowRanges;
      var solutions = ranges.Select(x => SlidingWindowBestSolutions[x]).ToList();
      int nRows = solutions.Count;
      int nCols = ranges.Count + 1;
      // The matrix is published before it is filled; rows are written in place as they are computed.
      var matrix = new double[nRows, nCols];
      SlidingWindowQualities = matrix;
      if (nRows == 0) return; // no windows, nothing to evaluate

      var rows = GetCoveredRows(ranges);
      var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows).ToList();

      for (int i = 0; i < nRows; ++i) {
        if (worker.CancellationPending) {
          e.Cancel = true;
          return;
        }
        var qualities = CalculateQualityRow(solutions[i], ranges, rows, originalValues);
        for (int j = 0; j < nCols; ++j) matrix[i, j] = qualities[j];
        // Report after row i is complete so that the final report is 100.
        worker.ReportProgress((int)Math.Round((i + 1) * 100.0 / nRows));
      }
    }

    /// <summary>Starts the asynchronous recalculation of <see cref="SlidingWindowQualities"/>.</summary>
    public void CalculateQualities() {
      bw.RunWorkerAsync();
    }

    private string GetTargetVariable(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;
      if (regressionProblemData != null) return regressionProblemData.TargetVariable;
      if (classificationProblemData != null) return classificationProblemData.TargetVariable;
      throw new NotSupportedException();
    }

    // Returns the contiguous row indices from the smallest window start up to (excluding) the
    // largest window end. Precondition: ranges is not empty.
    private static List<int> GetCoveredRows(IList<SlidingWindowRange> ranges) {
      int start = ranges.Min(x => x.Start);
      int end = ranges.Max(x => x.End);
      return Enumerable.Range(start, end - start).ToList();
    }

    // Computes one matrix row for 'tree': its quality on every window of 'ranges', followed by its
    // quality on the training partition. 'rows' must come from GetCoveredRows(ranges) and
    // 'originalValues' must hold the target values of exactly those rows.
    private double[] CalculateQualityRow(ISymbolicExpressionTree tree, IList<SlidingWindowRange> ranges,
                                         List<int> rows, IList<double> originalValues) {
      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, rows).ToList();
      // estimatedValues/originalValues are indexed relative to the first covered row, not by absolute row index.
      int offset = rows.Count > 0 ? rows[0] : 0;
      var qualities = new double[ranges.Count + 1];
      for (int j = 0; j < ranges.Count; ++j) {
        var indices = Enumerable.Range(ranges[j].Start - offset, ranges[j].Size).ToList();
        var estimated = indices.Select(x => estimatedValues[x]);
        var original = indices.Select(x => originalValues[x]);
        qualities[j] = CalculateQuality(estimated, original);
      }
      qualities[ranges.Count] = CalculateQuality(tree, ProblemData.TrainingIndices);
      return qualities;
    }

    // Applies the selected quality measure; returns NaN if the online calculator reports an error.
    private double CalculateQuality(IEnumerable<double> estimatedValues, IEnumerable<double> originalValues) {
      var errorState = OnlineCalculatorError.None;
      double quality = 0.0;
      switch (QualityMeasure) {
        case QualityMeasures.PEARSON:
          quality = OnlinePearsonsRSquaredCalculator.Calculate(estimatedValues, originalValues, out errorState);
          break;
        case QualityMeasures.MSE:
          quality = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, originalValues, out errorState);
          break;
      }
      return errorState == OnlineCalculatorError.None ? quality : double.NaN;
    }

    // Evaluates 'tree' on the given dataset rows and scores it with the selected quality measure.
    private double CalculateQuality(ISymbolicExpressionTree tree, IEnumerable<int> rows) {
      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(tree, ProblemData.Dataset, rows);
      var originalValues = ProblemData.Dataset.GetDoubleValues(GetTargetVariable(ProblemData), rows);
      return CalculateQuality(estimatedValues, originalValues);
    }
  }

  /// <summary>
  /// Immutable row interval [Start, End) with value equality; <see cref="Size"/> is End - Start.
  /// </summary>
  [StorableType("170F5739-8D8C-4A44-9EA2-B28B35E97A3F")]
  public sealed class SlidingWindowRange : IEquatable<SlidingWindowRange> {
    [Storable]
    private readonly Tuple<int, int> tuple;

    public int Start { get { return tuple.Item1; } }

    public int End { get { return tuple.Item2; } }

    [StorableConstructor]
    private SlidingWindowRange(StorableConstructorFlag _) { }

    private SlidingWindowRange() { }

    public SlidingWindowRange(int start, int end) {
      if (start > end) throw new ArgumentException("SlidingWindowRange: Start cannot be greater than End.");
      tuple = new Tuple<int, int>(start, end);
    }

    public bool Equals(SlidingWindowRange other) {
      if (ReferenceEquals(other, null)) return false;
      if (ReferenceEquals(this, other)) return true;
      return tuple.Equals(other.tuple);
    }

    // Keeps object.Equals consistent with IEquatable<T>.Equals and GetHashCode.
    public override bool Equals(object obj) {
      return Equals(obj as SlidingWindowRange);
    }

    public override int GetHashCode() {
      return tuple.GetHashCode();
    }

    public int Size { get { return End - Start; } }
  }
}