Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs @ 14905

Last change on this file since 14905 was 14904, checked in by bburlacu, 8 years ago

#2760: Reuse the shuffled data when creating the solution ensemble.

File size: 34.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Threading;
27using HeuristicLab.Collections;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Optimization;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis {
38  [Item("Cross Validation (CV)", "Cross-validation wrapper for data analysis algorithms.")]
39  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
40  [StorableClass]
41  public sealed class CrossValidation : ParameterizedNamedItem, IAlgorithm, IStorableContent {
// Copy of the template problem data whose dataset is replaced by the shuffled dataset that the
// fold algorithms were cloned with (see Start). The ensemble solutions are built from it when
// ShuffleSamples is set. It is persisted so that a saved paused/running cross validation can
// still build ensembles after it is loaded and resumed, because the fold algorithms already exist.
[Storable]
private IDataAnalysisProblemData shuffledProblemData;

/// <summary>
/// Creates a stopped cross validation without a template algorithm, using 2 folds, 1 worker,
/// an empty sample range and no shuffling.
/// </summary>
public CrossValidation()
  : base() {
  name = ItemName;
  description = ItemDescription;

  executionState = ExecutionState.Stopped;
  runs = new RunCollection { OptimizerName = name };
  runsCounter = 0;

  algorithm = null;
  clonedAlgorithms = new ItemCollection<IAlgorithm>();
  results = new ResultCollection();

  folds = new IntValue(2);
  numberOfWorkers = new IntValue(1);
  samplesStart = new IntValue(0);
  samplesEnd = new IntValue(0);
  shuffleSamples = new BoolValue(false);
  storeAlgorithmInEachRun = false;

  RegisterEvents();
  if (Algorithm != null) RegisterAlgorithmEvents();
}

/// <summary>File name this item was loaded from or saved to.</summary>
public string Filename { get; set; }
69
70    #region persistence and cloning
// Constructor used by the persistence layer; the [Storable] fields are filled in afterwards.
[StorableConstructor]
private CrossValidation(bool deserializing)
  : base(deserializing) {
}

// Re-attaches the event handlers, which are not persisted, once all fields are restored.
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  RegisterEvents();
  if (algorithm == null) return;
  RegisterAlgorithmEvents();
}
80
// Deep-copy constructor used by Clone. All referenced items go through the same cloner, so
// references shared between them (e.g. a shuffled dataset used by both the fold algorithms and
// shuffledProblemData) stay shared in the copy.
private CrossValidation(CrossValidation original, Cloner cloner)
  : base(original, cloner) {
  executionState = original.executionState;
  storeAlgorithmInEachRun = original.storeAlgorithmInEachRun;
  runs = cloner.Clone(original.runs);
  runsCounter = original.runsCounter;
  algorithm = cloner.Clone(original.algorithm);
  // When problem is the template algorithm's problem (the usual case), the cloner has already
  // mapped it while cloning the algorithm. This then resolves to the cloned algorithm's problem,
  // which Algorithm_ProblemChanged needs for detaching handlers and rolling back.
  problem = cloner.Clone(original.problem);
  clonedAlgorithms = cloner.Clone(original.clonedAlgorithms);
  // Cloned after the fold algorithms so that it references the same (cloned) shuffled dataset.
  // Without it, a cloned paused shuffled run could not build its ensemble solutions.
  shuffledProblemData = cloner.Clone(original.shuffledProblemData);
  results = cloner.Clone(original.results);

  folds = cloner.Clone(original.folds);
  numberOfWorkers = cloner.Clone(original.numberOfWorkers);
  samplesStart = cloner.Clone(original.samplesStart);
  samplesEnd = cloner.Clone(original.samplesEnd);
  shuffleSamples = cloner.Clone(original.shuffleSamples);
  RegisterEvents();
  if (Algorithm != null) RegisterAlgorithmEvents();
}
public override IDeepCloneable Clone(Cloner cloner) {
  return new CrossValidation(this, cloner);
}
102
103    #endregion
104
105    #region properties
[Storable]
private IAlgorithm algorithm;
/// <summary>
/// Template algorithm that is cloned once per fold when the cross validation starts.
/// It may only be changed while the cross validation is prepared or stopped, and its problem
/// (if any) must be an <see cref="IDataAnalysisProblem"/>. Setting it replaces this item's
/// parameters with the algorithm's parameters and prepares the cross validation.
/// </summary>
/// <exception cref="InvalidOperationException">The cross validation is started or paused.</exception>
/// <exception cref="ArgumentException">The algorithm's problem is not a data analysis problem.</exception>
public IAlgorithm Algorithm {
  get { return algorithm; }
  set {
    if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
      throw new InvalidOperationException("Changing the algorithm is only allowed if the CrossValidation is stopped or prepared.");
    if (algorithm == value) return;
    if (value != null && value.Problem != null && !(value.Problem is IDataAnalysisProblem))
      throw new ArgumentException("Only algorithms with a DataAnalysisProblem could be used for the cross validation.");

    if (algorithm != null) DeregisterAlgorithmEvents();
    algorithm = value;
    // Track the new algorithm's problem. Otherwise Algorithm_ProblemChanged would later detach its
    // handlers from (and roll back to) a problem that belongs to the previous algorithm, and the
    // problem subscribed by RegisterAlgorithmEvents below would never be detached.
    problem = algorithm != null ? (IDataAnalysisProblem)algorithm.Problem : null;
    Parameters.Clear();

    if (algorithm != null) {
      algorithm.StoreAlgorithmInEachRun = false;
      RegisterAlgorithmEvents();
      algorithm.Prepare(true);
      Parameters.AddRange(algorithm.Parameters);
    }
    OnAlgorithmChanged();
    Prepare();
  }
}
131
132
// Last problem accepted from the template algorithm. Used to detach problem event handlers and
// to roll back when an unsupported problem is assigned to the template.
[Storable]
private IDataAnalysisProblem problem;
/// <summary>
/// The template algorithm's problem, or null when no algorithm is set. Setting it forwards the
/// problem to the template algorithm; this is only allowed while prepared or stopped.
/// </summary>
/// <exception cref="InvalidOperationException">The cross validation is started or paused.</exception>
/// <exception cref="ArgumentNullException">No template algorithm has been set yet.</exception>
public IDataAnalysisProblem Problem {
  get {
    if (algorithm == null)
      return null;
    return (IDataAnalysisProblem)algorithm.Problem;
  }
  set {
    if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
      throw new InvalidOperationException("Changing the problem is only allowed if the CrossValidation is stopped or prepared.");
    // The single-string ArgumentNullException constructor interprets its argument as the
    // parameter name. Pass the text as the message and name the missing member instead.
    // The exception type is kept for callers that catch it.
    if (algorithm == null) throw new ArgumentNullException("Algorithm", "Could not set a problem before an algorithm was set.");
    algorithm.Problem = value;
    problem = value;
  }
}
149
// Explicit IAlgorithm implementation: rejects problems that are not data analysis problems
// before delegating to the strongly typed Problem property.
IProblem IAlgorithm.Problem {
  get { return Problem; }
  set {
    bool isSupported = value == null || ProblemType.IsInstanceOfType(value);
    if (!isSupported)
      throw new ArgumentException("Only DataAnalysisProblems could be used for the cross validation.");
    Problem = (IDataAnalysisProblem)value;
  }
}
/// <summary>The only problem type accepted by the cross validation.</summary>
public Type ProblemType {
  get {
    return typeof(IDataAnalysisProblem);
  }
}
161
// One clone of the template algorithm per fold; filled by Start, emptied by Prepare.
[Storable]
private ItemCollection<IAlgorithm> clonedAlgorithms;

/// <summary>Yields the template algorithm when one is set; otherwise yields nothing.</summary>
public IEnumerable<IOptimizer> NestedOptimizers {
  get {
    if (Algorithm != null)
      yield return Algorithm;
  }
}
171
[Storable]
private ResultCollection results;
/// <summary>Results of the cross validation; each started fold adds its own result collection here.</summary>
public ResultCollection Results {
  get {
    return results;
  }
}

[Storable]
private BoolValue shuffleSamples;
/// <summary>Whether the dataset is shuffled once before the folds are created.</summary>
public BoolValue ShuffleSamples {
  get {
    return shuffleSamples;
  }
}

[Storable]
private IntValue folds;
/// <summary>Number of folds (cloned algorithms) created on start.</summary>
public IntValue Folds {
  get {
    return folds;
  }
}

[Storable]
private IntValue samplesStart;
/// <summary>First row (inclusive) of the sample range that is split into folds.</summary>
public IntValue SamplesStart {
  get {
    return samplesStart;
  }
}

[Storable]
private IntValue samplesEnd;
/// <summary>End row (exclusive) of the sample range that is split into folds.</summary>
public IntValue SamplesEnd {
  get {
    return samplesEnd;
  }
}

[Storable]
private IntValue numberOfWorkers;
/// <summary>Maximum number of fold algorithms started by a single call to Start.</summary>
public IntValue NumberOfWorkers {
  get {
    return numberOfWorkers;
  }
}

[Storable]
private bool storeAlgorithmInEachRun;
/// <summary>Whether runs should keep a copy of the algorithm; raises a change notification when toggled.</summary>
public bool StoreAlgorithmInEachRun {
  get {
    return storeAlgorithmInEachRun;
  }
  set {
    if (storeAlgorithmInEachRun == value) return;
    storeAlgorithmInEachRun = value;
    OnStoreAlgorithmInEachRunChanged();
  }
}
214
// Run counter; starts at 0 and is copied when the item is cloned.
[Storable]
private int runsCounter;

[Storable]
private RunCollection runs;
/// <summary>Runs collected by this cross validation.</summary>
public RunCollection Runs {
  get {
    return runs;
  }
}
222
[Storable]
private ExecutionState executionState;
/// <summary>
/// Current execution state. Assigning a different value calls OnExecutionStateChanged and
/// OnItemImageChanged; assigning the current value does nothing.
/// </summary>
public ExecutionState ExecutionState {
  get { return executionState; }
  private set {
    if (executionState == value) return;
    executionState = value;
    OnExecutionStateChanged();
    OnItemImageChanged();
  }
}
public static new Image StaticItemImage {
  get { return HeuristicLab.Common.Resources.VSImageLibrary.Event; }
}
/// <summary>Image reflecting the execution state; other states fall back to the base image.</summary>
public override Image ItemImage {
  get {
    switch (ExecutionState) {
      case ExecutionState.Prepared:
        return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePrepared;
      case ExecutionState.Started:
        return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStarted;
      case ExecutionState.Paused:
        return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePaused;
      case ExecutionState.Stopped:
        return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStopped;
      default:
        return base.ItemImage;
    }
  }
}
247
/// <summary>
/// Sum of the execution times of all fold algorithms; zero while the cross validation is prepared.
/// </summary>
public TimeSpan ExecutionTime {
  get {
    if (ExecutionState == ExecutionState.Prepared) return TimeSpan.Zero;
    double totalMilliseconds = clonedAlgorithms.Sum(alg => alg.ExecutionTime.TotalMilliseconds);
    return TimeSpan.FromMilliseconds(totalMilliseconds);
  }
}
255    #endregion
256
// Propagates a rename to the run collection's optimizer name.
protected override void OnNameChanged() {
  base.OnNameChanged();
  runs.OptimizerName = Name;
}
261
/// <summary>
/// Discards the results and fold algorithms and prepares the template algorithm. OnPrepared is
/// called only if the template reports the Prepared state afterwards.
/// </summary>
/// <exception cref="InvalidOperationException">The cross validation is currently started.</exception>
public void Prepare() {
  if (ExecutionState == ExecutionState.Started)
    throw new InvalidOperationException(string.Format("Prepare not allowed in execution state \"{0}\".", ExecutionState));
  results.Clear();
  clonedAlgorithms.Clear();
  if (Algorithm == null) return;

  Algorithm.Prepare();
  if (Algorithm.ExecutionState == ExecutionState.Prepared)
    OnPrepared();
}
/// <summary>Like <see cref="Prepare()"/>, optionally clearing the collected runs first.</summary>
public void Prepare(bool clearRuns) {
  if (clearRuns)
    runs.Clear();
  Prepare();
}
276
/// <summary>
/// Starts (or resumes) the cross validation. On the first start after Prepare, one clone of the
/// template algorithm is created per fold. Afterwards, up to NumberOfWorkers prepared or paused
/// fold algorithms are started. Does nothing if no template algorithm is set.
/// </summary>
/// <exception cref="InvalidOperationException">The cross validation is neither prepared nor paused.</exception>
public void Start() {
  if ((ExecutionState != ExecutionState.Prepared) && (ExecutionState != ExecutionState.Paused))
    throw new InvalidOperationException(string.Format("Start not allowed in execution state \"{0}\".", ExecutionState));
  if (Algorithm == null) return;

  // Resuming a paused cross validation reuses the existing fold algorithms.
  if (clonedAlgorithms.Count == 0) CreateFoldAlgorithms();
  StartPendingFoldAlgorithms();
  OnStarted();
}

// Clones the template algorithm once per fold, assigns the fold partitions and records the
// shuffled problem data that matches these clones.
private void CreateFoldAlgorithms() {
  int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value;
  IDataset shuffledDataset = null;
  for (int i = 0; i < Folds.Value; i++) {
    var cloner = new Cloner();
    if (ShuffleSamples.Value) {
      var dataAnalysisProblem = (IDataAnalysisProblem)algorithm.Problem;
      var dataset = (Dataset)dataAnalysisProblem.ProblemData.Dataset;
      // Shuffle once; every fold clone sees the same permutation.
      shuffledDataset = shuffledDataset ?? dataset.Shuffle(new FastRandom());
      cloner.RegisterClonedObject(dataset, shuffledDataset);
    }
    IAlgorithm clonedAlgorithm = cloner.Clone(Algorithm);
    clonedAlgorithm.Name = algorithm.Name + " Fold " + i;
    ConfigureFoldPartitions(clonedAlgorithm, i, testSamplesCount);
    clonedAlgorithm.Prepare();
    clonedAlgorithms.Add(clonedAlgorithm);
  }

  // Save the shuffled problem data because it is needed for the ensemble solutions. It is always
  // refreshed here: a copy left over from before the last Prepare holds a different permutation
  // than the fold algorithms just created, so the ensembles would see mismatched rows.
  if (shuffledDataset != null) {
    var dataAnalysisProblem = (IDataAnalysisProblem)algorithm.Problem;
    var dataset = (Dataset)dataAnalysisProblem.ProblemData.Dataset;
    var cloner = new Cloner();
    cloner.RegisterClonedObject(dataset, shuffledDataset);
    shuffledProblemData = cloner.Clone(dataAnalysisProblem.ProblemData);
  }
}

// Sets the partitions of fold foldIndex on the cloned algorithm's problem. The training
// partition spans [SamplesStart, SamplesEnd) and the test partition is the fold's slice of it;
// the last fold's slice extends to SamplesEnd to absorb any remainder.
private void ConfigureFoldPartitions(IAlgorithm foldAlgorithm, int foldIndex, int testSamplesCount) {
  var foldProblem = (IDataAnalysisProblem)foldAlgorithm.Problem;
  var symbolicProblem = foldProblem as ISymbolicDataAnalysisProblem;

  int testStart = (foldIndex * testSamplesCount) + SamplesStart.Value;
  int testEnd = (foldIndex + 1) == Folds.Value ? SamplesEnd.Value : (foldIndex + 1) * testSamplesCount + SamplesStart.Value;

  foldProblem.ProblemData.TrainingPartition.Start = SamplesStart.Value;
  foldProblem.ProblemData.TrainingPartition.End = SamplesEnd.Value;
  foldProblem.ProblemData.TestPartition.Start = testStart;
  foldProblem.ProblemData.TestPartition.End = testEnd;
  // ConfigureProblem hides these parameters on the template; show them on the fold clones.
  var problemData = foldProblem.ProblemData as DataAnalysisProblemData;
  if (problemData != null) {
    problemData.TrainingPartitionParameter.Hidden = false;
    problemData.TestPartitionParameter.Hidden = false;
  }

  if (symbolicProblem != null) {
    symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
    symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
  }
}

// Starts prepared or paused fold algorithms, at most NumberOfWorkers of them, blocking until
// each one has raised its Started event.
private void StartPendingFoldAlgorithms() {
  int startedAlgorithms = 0;
  foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
    if (startedAlgorithms >= NumberOfWorkers.Value) break;
    if (clonedAlgorithm.ExecutionState != ExecutionState.Prepared &&
        clonedAlgorithm.ExecutionState != ExecutionState.Paused) continue;

    using (var signal = new ManualResetEvent(false)) {
      EventHandler signalSetter = (sender, args) => { signal.Set(); };
      clonedAlgorithm.Started += signalSetter;
      try {
        clonedAlgorithm.Start();
        signal.WaitOne();
      } finally {
        // Always detach, even if Start throws. Otherwise a later Started event would call Set
        // on the already disposed wait handle.
        clonedAlgorithm.Started -= signalSetter;
      }
    }
    startedAlgorithms++;
  }
}
352
// Set when a pause has been requested so that repeated Pause calls do not re-issue it.
// NOTE(review): not reset in the code visible here — presumably cleared once the pause completes; verify.
private bool pausePending;

/// <summary>Requests all running fold algorithms to pause.</summary>
/// <exception cref="InvalidOperationException">The cross validation is not started.</exception>
public void Pause() {
  if (ExecutionState != ExecutionState.Started)
    throw new InvalidOperationException(string.Format("Pause not allowed in execution state \"{0}\".", ExecutionState));
  if (pausePending) return;
  pausePending = true;
  PauseAllClonedAlgorithms();
}

// Pauses every fold algorithm that is currently in the Started state.
private void PauseAllClonedAlgorithms() {
  var runningAlgorithms = clonedAlgorithms.Where(alg => alg.ExecutionState == ExecutionState.Started);
  foreach (IAlgorithm running in runningAlgorithms)
    running.Pause();
}
368
// Set when a stop has been requested so that repeated Stop calls do not re-issue it.
private bool stopPending;

/// <summary>Requests all started or paused fold algorithms to stop.</summary>
/// <exception cref="InvalidOperationException">The cross validation is neither started nor paused.</exception>
public void Stop() {
  bool canStop = ExecutionState == ExecutionState.Started || ExecutionState == ExecutionState.Paused;
  if (!canStop)
    throw new InvalidOperationException(string.Format("Stop not allowed in execution state \"{0}\".",
                                                      ExecutionState));
  if (stopPending) return;
  stopPending = true;
  StopAllClonedAlgorithms();
}

// Stops every fold algorithm that is started or paused.
private void StopAllClonedAlgorithms() {
  var activeAlgorithms = clonedAlgorithms.Where(alg => alg.ExecutionState == ExecutionState.Started ||
                                                       alg.ExecutionState == ExecutionState.Paused);
  foreach (IAlgorithm active in activeAlgorithms)
    active.Stop();
}
386
387    #region collect parameters and results
/// <summary>
/// Adds the cross validation's name, type and fold count. If a template algorithm is set, also
/// adds its name/type and the parameters; if a problem is set, also adds its name/type and the
/// problem's own parameter values.
/// </summary>
public override void CollectParameterValues(IDictionary<string, IItem> values) {
  values.Add("Algorithm Name", new StringValue(Name));
  values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName()));
  values.Add("Folds", new IntValue(Folds.Value));

  IAlgorithm template = algorithm;
  if (template != null) {
    values.Add("CrossValidation Algorithm Name", new StringValue(template.Name));
    values.Add("CrossValidation Algorithm Type", new StringValue(template.GetType().GetPrettyName()));
    // this item's parameters mirror the template algorithm's parameters (see Algorithm setter)
    base.CollectParameterValues(values);
  }

  IDataAnalysisProblem currentProblem = Problem;
  if (currentProblem == null) return;
  values.Add("Problem Name", new StringValue(currentProblem.Name));
  values.Add("Problem Type", new StringValue(currentProblem.GetType().GetPrettyName()));
  currentProblem.CollectParameterValues(values);
}
404
// Copies every entry of a clone of this item's result collection into the given dictionary.
public void CollectResultValues(IDictionary<string, IItem> results) {
  var snapshot = (ResultCollection)this.results.Clone();
  foreach (var entry in snapshot)
    results.Add(entry.Name, entry.Value);
}
411
// Aggregates the first run of every fold algorithm into the given dictionary: average and
// std. dev. of int/double/percent results, ensemble solutions for regression and
// classification solutions, the total execution time, and the fold runs themselves.
private void AggregateResultValues(IDictionary<string, IItem> results) {
  // Materialize once. The previous deferred query was enumerated twice (for the result entries
  // and for the "CrossValidation Folds" collection) and filtered out null runs twice.
  List<IRun> foldRuns = clonedAlgorithms
    .Select(alg => alg.Runs.FirstOrDefault())
    .Where(run => run != null)
    .ToList();
  List<KeyValuePair<string, IItem>> resultCollections = foldRuns.SelectMany(run => run.Results).ToList();

  foreach (IResult result in ExtractAndAggregateResults<IntValue>(resultCollections))
    results.Add(result.Name, result.Value);
  foreach (IResult result in ExtractAndAggregateResults<DoubleValue>(resultCollections))
    results.Add(result.Name, result.Value);
  foreach (IResult result in ExtractAndAggregateResults<PercentValue>(resultCollections))
    results.Add(result.Name, result.Value);
  foreach (IResult result in ExtractAndAggregateRegressionSolutions(resultCollections))
    results.Add(result.Name, result.Value);
  foreach (IResult result in ExtractAndAggregateClassificationSolutions(resultCollections))
    results.Add(result.Name, result.Value);
  results.Add("Execution Time", new TimeSpanValue(this.ExecutionTime));
  results.Add("CrossValidation Folds", new RunCollection(foldRuns));
}
431
// Groups regression solutions by result name and builds one RegressionEnsembleSolution per
// name ("<name> (ensemble)"). The ensemble uses a clone of the shuffled problem data (when
// ShuffleSamples is set) or of the problem's data, with both partitions set to the full sample
// range. Nested result collections are flattened into dotted names.
private IEnumerable<IResult> ExtractAndAggregateRegressionSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
  var solutionsByName = new Dictionary<string, List<IRegressionSolution>>();
  foreach (var pair in resultCollections) {
    var solution = pair.Value as IRegressionSolution;
    if (solution == null) continue;
    List<IRegressionSolution> group;
    if (!solutionsByName.TryGetValue(pair.Key, out group)) {
      group = new List<IRegressionSolution>();
      solutionsByName.Add(pair.Key, group);
    }
    group.Add(solution);
  }

  var ensembles = new List<IResult>();
  foreach (var entry in solutionsByName) {
    // a fresh cloner per ensemble so that references inside the problem data are cloned consistently
    var cloner = new Cloner();
    IRegressionProblemData ensembleData;
    if (ShuffleSamples.Value)
      ensembleData = (IRegressionProblemData)cloner.Clone(shuffledProblemData);
    else
      ensembleData = (IRegressionProblemData)cloner.Clone(Problem.ProblemData);

    ensembleData.TrainingPartition.Start = SamplesStart.Value;
    ensembleData.TrainingPartition.End = SamplesEnd.Value;
    ensembleData.TestPartition.Start = SamplesStart.Value;
    ensembleData.TestPartition.End = SamplesEnd.Value;

    var ensembleSolution = new RegressionEnsembleSolution(ensembleData);
    ensembleSolution.AddRegressionSolutions(entry.Value);
    ensembles.Add(new Result(entry.Key + " (ensemble)", ensembleSolution));
  }

  var flattened = new List<IResult>();
  CollectResultsRecursively("", ensembles, flattened);
  return flattened;
}
464
// Groups classification solutions by result name and builds one ClassificationEnsembleSolution
// per name ("<name> (ensemble)"). Each ensemble gets new ClassificationProblemData over the
// shuffled dataset (when ShuffleSamples is set) or the problem's dataset, with both partitions
// set to the full sample range.
private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
  var solutionsByName = new Dictionary<string, List<IClassificationSolution>>();
  foreach (var pair in resultCollections) {
    var solution = pair.Value as IClassificationSolution;
    if (solution == null) continue;
    List<IClassificationSolution> group;
    if (!solutionsByName.TryGetValue(pair.Key, out group)) {
      group = new List<IClassificationSolution>();
      solutionsByName.Add(pair.Key, group);
    }
    group.Add(solution);
  }

  var ensembles = new List<IResult>();
  foreach (var entry in solutionsByName) {
    // The target is taken from the solutions themselves: at least one algorithm (GBT with logistic
    // regression loss) yields classification solutions for a regression problem.
    var targetVariable = entry.Value.First().ProblemData.TargetVariable;
    IDataAnalysisProblemData sourceData = ShuffleSamples.Value ? shuffledProblemData : Problem.ProblemData;
    var ensembleData = new ClassificationProblemData(sourceData.Dataset, sourceData.AllowedInputVariables, targetVariable);

    ensembleData.TrainingPartition.Start = SamplesStart.Value;
    ensembleData.TrainingPartition.End = SamplesEnd.Value;
    ensembleData.TestPartition.Start = SamplesStart.Value;
    ensembleData.TestPartition.End = SamplesEnd.Value;

    var ensembleSolution = new ClassificationEnsembleSolution(ensembleData);
    ensembleSolution.AddClassificationSolutions(entry.Value);
    ensembles.Add(new Result(entry.Key + " (ensemble)", ensembleSolution));
  }

  var flattened = new List<IResult>();
  CollectResultsRecursively("", ensembles, flattened);
  return flattened;
}
497
// Appends each result under "path + name". When a result's value is itself a ResultCollection,
// its entries are appended recursively with "path + name + '.'" as prefix.
private void CollectResultsRecursively(string path, IEnumerable<IResult> results, IList<IResult> flattenedResults) {
  foreach (IResult result in results) {
    string qualifiedName = path + result.Name;
    flattenedResults.Add(new Result(qualifiedName, result.Value));
    var nested = result.Value as ResultCollection;
    if (nested == null) continue;
    CollectResultsRecursively(qualifiedName + ".", nested, flattenedResults);
  }
}
507
// For every result whose value has exactly type T, groups the numeric values by result name and
// emits "<name> (average)" and "<name> (std.dev.)". Percent inputs produce PercentValue outputs;
// int and double inputs produce DoubleValue outputs. Any other T throws NotSupportedException.
private static IEnumerable<IResult> ExtractAndAggregateResults<T>(IEnumerable<KeyValuePair<string, IItem>> results)
  where T : class, IItem, new() {
  var valuesByName = new Dictionary<string, List<double>>();
  foreach (var entry in results) {
    if (entry.Value.GetType() != typeof(T)) continue;
    List<double> values;
    if (!valuesByName.TryGetValue(entry.Key, out values)) {
      values = new List<double>();
      valuesByName[entry.Key] = values;
    }
    values.Add(ConvertToDouble(entry.Value));
  }

  Type itemType = typeof(T);
  DoubleValue template;
  if (itemType == typeof(PercentValue))
    template = new PercentValue();
  else if (itemType == typeof(DoubleValue) || itemType == typeof(IntValue))
    template = new DoubleValue();
  else
    throw new NotSupportedException();

  var aggregated = new List<IResult>();
  foreach (var pair in valuesByName) {
    // the template is reused and cloned so that each result owns its own value item
    template.Value = pair.Value.Average();
    aggregated.Add(new Result(pair.Key + " (average)", (IItem)template.Clone()));
    template.Value = pair.Value.StandardDeviation();
    aggregated.Add(new Result(pair.Key + " (std.dev.)", (IItem)template.Clone()));
  }
  return aggregated;
}
536
// Reads the numeric value of a DoubleValue (including subclasses) or IntValue item.
private static double ConvertToDouble(IItem item) {
  var doubleItem = item as DoubleValue;
  if (doubleItem != null) return doubleItem.Value;
  var intItem = item as IntValue;
  if (intItem != null) return intItem.Value;
  throw new NotSupportedException("Could not convert any item type to double");
}
542    #endregion
543
544    #region events
// Wires the handlers that do not depend on the template algorithm.
private void RegisterEvents() {
  Folds.ValueChanged += Folds_ValueChanged;
  RegisterClonedAlgorithmsEvents();
}
// Rejects fold count changes unless the cross validation is prepared.
private void Folds_ValueChanged(object sender, EventArgs e) {
  if (ExecutionState == ExecutionState.Prepared) return;
  throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared.");
}
553
554
555    #region template algorithms events
/// <summary>Raised after the template algorithm has been replaced.</summary>
public event EventHandler AlgorithmChanged;
// Raises AlgorithmChanged, then the problem-changed notification; without a problem the
// cross validation is set to Stopped.
private void OnAlgorithmChanged() {
  var listeners = AlgorithmChanged;
  if (listeners != null)
    listeners(this, EventArgs.Empty);
  OnProblemChanged();
  if (Problem == null)
    ExecutionState = ExecutionState.Stopped;
}
// Subscribes to the template algorithm and, if present, to its current problem.
private void RegisterAlgorithmEvents() {
  algorithm.ProblemChanged += Algorithm_ProblemChanged;
  algorithm.ExecutionStateChanged += Algorithm_ExecutionStateChanged;
  IDataAnalysisProblem currentProblem = Problem;
  if (currentProblem == null) return;
  currentProblem.Reset += Problem_Reset;
  currentProblem.ProblemDataChanged += Problem_ProblemDataChanged;
}
// Mirror of RegisterAlgorithmEvents.
private void DeregisterAlgorithmEvents() {
  algorithm.ProblemChanged -= Algorithm_ProblemChanged;
  algorithm.ExecutionStateChanged -= Algorithm_ExecutionStateChanged;
  IDataAnalysisProblem currentProblem = Problem;
  if (currentProblem == null) return;
  currentProblem.Reset -= Problem_Reset;
  currentProblem.ProblemDataChanged -= Problem_ProblemDataChanged;
}
// Reacts to a new problem on the template algorithm. Unsupported problems are rolled back to the
// last accepted one and rejected. Otherwise the problem event subscriptions move from the old to
// the new problem and the problem-changed notification is raised.
private void Algorithm_ProblemChanged(object sender, EventArgs e) {
  if (algorithm.Problem != null && !(algorithm.Problem is IDataAnalysisProblem)) {
    algorithm.Problem = problem;
    throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");
  }
  // Move BOTH handlers that RegisterAlgorithmEvents attaches. Previously only Reset was moved, so
  // ProblemDataChanged (which invalidates shuffledProblemData) was never attached to a new problem.
  if (problem != null) {
    problem.Reset -= new EventHandler(Problem_Reset);
    problem.ProblemDataChanged -= Problem_ProblemDataChanged;
  }
  problem = (IDataAnalysisProblem)algorithm.Problem;
  if (problem != null) {
    problem.Reset += new EventHandler(Problem_Reset);
    problem.ProblemDataChanged += Problem_ProblemDataChanged;
  }
  OnProblemChanged();
}
/// <summary>Raised when the template algorithm's problem changes.</summary>
public event EventHandler ProblemChanged;
// Raises ProblemChanged, then re-applies the cross-validation configuration to the problem.
private void OnProblemChanged() {
  var listeners = ProblemChanged;
  if (listeners != null)
    listeners(this, EventArgs.Empty);
  ConfigureProblem();
}
/// <summary>Raised when the problem's data is replaced.</summary>
public event EventHandler ProblemDataChanged;
// Raises ProblemDataChanged and drops the cached shuffled copy of the old problem data.
private void OnProblemDataChanged() {
  var listeners = ProblemDataChanged;
  if (listeners != null)
    listeners(this, EventArgs.Empty);
  shuffledProblemData = null;
}
private void Problem_ProblemDataChanged(object sender, EventArgs e) {
  OnProblemDataChanged();
}
// A reset problem gets the cross-validation configuration applied again.
private void Problem_Reset(object sender, EventArgs e) {
  ConfigureProblem();
}
// Resets the sample range to cover the whole dataset (or to empty without a problem) and hides
// the partition parameters that the cross validation controls itself. For symbolic problems, the
// fitness partition is set to the sample range and the validation partition is emptied.
private void ConfigureProblem() {
  SamplesStart.Value = 0;
  if (Problem == null) {
    SamplesEnd.Value = 0;
    return;
  }
  SamplesEnd.Value = Problem.ProblemData.Dataset.Rows;

  var problemData = Problem.ProblemData as DataAnalysisProblemData;
  if (problemData != null) {
    problemData.TrainingPartitionParameter.Hidden = true;
    problemData.TestPartitionParameter.Hidden = true;
  }

  var symbolicProblem = Problem as ISymbolicDataAnalysisProblem;
  if (symbolicProblem == null) return;
  symbolicProblem.FitnessCalculationPartitionParameter.Hidden = true;
  symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
  symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
  symbolicProblem.ValidationPartitionParameter.Hidden = true;
  symbolicProblem.ValidationPartition.Start = 0;
  symbolicProblem.ValidationPartition.End = 0;
}
629
// The template algorithm must never run: Prepared/Stopped are forwarded, Started/Paused are errors.
private void Algorithm_ExecutionStateChanged(object sender, EventArgs e) {
  var templateState = Algorithm.ExecutionState;
  if (templateState == ExecutionState.Prepared)
    OnPrepared();
  else if (templateState == ExecutionState.Started)
    throw new InvalidOperationException("Algorithm template can not be started.");
  else if (templateState == ExecutionState.Paused)
    throw new InvalidOperationException("Algorithm template can not be paused.");
  else if (templateState == ExecutionState.Stopped)
    OnStopped();
}
642    #endregion
643
644    #region clonedAlgorithms events
// Subscribes to changes of the fold-algorithm collection and to every fold algorithm in it.
private void RegisterClonedAlgorithmsEvents() {
  clonedAlgorithms.ItemsAdded += ClonedAlgorithms_ItemsAdded;
  clonedAlgorithms.ItemsRemoved += ClonedAlgorithms_ItemsRemoved;
  clonedAlgorithms.CollectionReset += ClonedAlgorithms_CollectionReset;
  foreach (IAlgorithm foldAlgorithm in clonedAlgorithms) {
    RegisterClonedAlgorithmEvents(foldAlgorithm);
  }
}
// Mirror of RegisterClonedAlgorithmsEvents.
private void DeregisterClonedAlgorithmsEvents() {
  clonedAlgorithms.ItemsAdded -= ClonedAlgorithms_ItemsAdded;
  clonedAlgorithms.ItemsRemoved -= ClonedAlgorithms_ItemsRemoved;
  clonedAlgorithms.CollectionReset -= ClonedAlgorithms_CollectionReset;
  foreach (IAlgorithm foldAlgorithm in clonedAlgorithms) {
    DeregisterClonedAlgorithmEvents(foldAlgorithm);
  }
}
// Keep the per-fold subscriptions in sync with the contents of clonedAlgorithms.
private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
  foreach (IAlgorithm added in e.Items) {
    RegisterClonedAlgorithmEvents(added);
  }
}
private void ClonedAlgorithms_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
  foreach (IAlgorithm removed in e.Items) {
    DeregisterClonedAlgorithmEvents(removed);
  }
}
private void ClonedAlgorithms_CollectionReset(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
  foreach (IAlgorithm previous in e.OldItems) {
    DeregisterClonedAlgorithmEvents(previous);
  }
  foreach (IAlgorithm current in e.Items) {
    RegisterClonedAlgorithmEvents(current);
  }
}
// Forwards a fold algorithm's exception, time and lifecycle events to this cross validation.
private void RegisterClonedAlgorithmEvents(IAlgorithm algorithm) {
  algorithm.ExceptionOccurred += ClonedAlgorithm_ExceptionOccurred;
  algorithm.ExecutionTimeChanged += ClonedAlgorithm_ExecutionTimeChanged;
  algorithm.Started += ClonedAlgorithm_Started;
  algorithm.Paused += ClonedAlgorithm_Paused;
  algorithm.Stopped += ClonedAlgorithm_Stopped;
}
// Mirror of RegisterClonedAlgorithmEvents.
private void DeregisterClonedAlgorithmEvents(IAlgorithm algorithm) {
  algorithm.ExceptionOccurred -= ClonedAlgorithm_ExceptionOccurred;
  algorithm.ExecutionTimeChanged -= ClonedAlgorithm_ExecutionTimeChanged;
  algorithm.Started -= ClonedAlgorithm_Started;
  algorithm.Paused -= ClonedAlgorithm_Paused;
  algorithm.Stopped -= ClonedAlgorithm_Stopped;
}
687    private void ClonedAlgorithm_ExceptionOccurred(object sender, EventArgs<Exception> e) {
688      OnExceptionOccurred(e.Value);
689    }
690    private void ClonedAlgorithm_ExecutionTimeChanged(object sender, EventArgs e) {
691      OnExecutionTimeChanged();
692    }
693
694    private readonly object locker = new object();
695    private readonly object resultLocker = new object();
696    private void ClonedAlgorithm_Started(object sender, EventArgs e) {
697      IAlgorithm algorithm = sender as IAlgorithm;
698      lock (resultLocker) {
699        if (algorithm != null && !results.ContainsKey(algorithm.Name))
700          results.Add(new Result(algorithm.Name, "Contains results for the specific fold.", algorithm.Results));
701      }
702    }
703
704    private void ClonedAlgorithm_Paused(object sender, EventArgs e) {
705      lock (locker) {
706        if (pausePending && clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Started))
707          OnPaused();
708      }
709    }
710
711    private void ClonedAlgorithm_Stopped(object sender, EventArgs e) {
712      lock (locker) {
713        if (!stopPending && ExecutionState == ExecutionState.Started) {
714          IAlgorithm preparedAlgorithm = clonedAlgorithms.FirstOrDefault(alg => alg.ExecutionState == ExecutionState.Prepared ||
715                                                                                alg.ExecutionState == ExecutionState.Paused);
716          if (preparedAlgorithm != null) preparedAlgorithm.Start();
717        }
718        if (ExecutionState != ExecutionState.Stopped) {
719          if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped))
720            OnStopped();
721          else if (stopPending &&
722                   clonedAlgorithms.All(
723                     alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped))
724            OnStopped();
725        }
726      }
727    }
728    #endregion
729    #endregion
730
731    #region event firing
732    public event EventHandler ExecutionStateChanged;
733    private void OnExecutionStateChanged() {
734      EventHandler handler = ExecutionStateChanged;
735      if (handler != null) handler(this, EventArgs.Empty);
736    }
737    public event EventHandler ExecutionTimeChanged;
738    private void OnExecutionTimeChanged() {
739      EventHandler handler = ExecutionTimeChanged;
740      if (handler != null) handler(this, EventArgs.Empty);
741    }
742    public event EventHandler Prepared;
743    private void OnPrepared() {
744      ExecutionState = ExecutionState.Prepared;
745      EventHandler handler = Prepared;
746      if (handler != null) handler(this, EventArgs.Empty);
747      OnExecutionTimeChanged();
748    }
749    public event EventHandler Started;
750    private void OnStarted() {
751      ExecutionState = ExecutionState.Started;
752      EventHandler handler = Started;
753      if (handler != null) handler(this, EventArgs.Empty);
754    }
755    public event EventHandler Paused;
756    private void OnPaused() {
757      pausePending = false;
758      ExecutionState = ExecutionState.Paused;
759      EventHandler handler = Paused;
760      if (handler != null) handler(this, EventArgs.Empty);
761    }
762    public event EventHandler Stopped;
763    private void OnStopped() {
764      stopPending = false;
765      Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>();
766      AggregateResultValues(collectedResults);
767      results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray());
768      clonedAlgorithms.Clear();
769      runsCounter++;
770      runs.Add(new Run(string.Format("{0} Run {1}", Name, runsCounter), this));
771      ExecutionState = ExecutionState.Stopped;
772      EventHandler handler = Stopped;
773      if (handler != null) handler(this, EventArgs.Empty);
774    }
775    public event EventHandler<EventArgs<Exception>> ExceptionOccurred;
776    private void OnExceptionOccurred(Exception exception) {
777      EventHandler<EventArgs<Exception>> handler = ExceptionOccurred;
778      if (handler != null) handler(this, new EventArgs<Exception>(exception));
779    }
780    public event EventHandler StoreAlgorithmInEachRunChanged;
781    private void OnStoreAlgorithmInEachRunChanged() {
782      EventHandler handler = StoreAlgorithmInEachRunChanged;
783      if (handler != null) handler(this, EventArgs.Empty);
784    }
785    #endregion
786  }
787}
Note: See TracBrowser for help on using the repository browser.