Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs @ 15263

Last change on this file since 15263 was 15077, checked in by mkommend, 7 years ago

#2760: Reordered the backwards compatibility code and the event registration in the after-deserialization hook of CrossValidation.

File size: 33.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Threading;
27using HeuristicLab.Collections;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Optimization;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis {
38  [Item("Cross Validation (CV)", "Cross-validation wrapper for data analysis algorithms.")]
39  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
40  [StorableClass]
41  public sealed class CrossValidation : ParameterizedNamedItem, IAlgorithm, IStorableContent {
42    [Storable]
43    private int seed;
44
45    public CrossValidation()
46      : base() {
47      name = ItemName;
48      description = ItemDescription;
49
50      executionState = ExecutionState.Stopped;
51      runs = new RunCollection { OptimizerName = name };
52      runsCounter = 0;
53
54      algorithm = null;
55      clonedAlgorithms = new ItemCollection<IAlgorithm>();
56      results = new ResultCollection();
57
58      folds = new IntValue(2);
59      numberOfWorkers = new IntValue(1);
60      samplesStart = new IntValue(0);
61      samplesEnd = new IntValue(0);
62      shuffleSamples = new BoolValue(false);
63      storeAlgorithmInEachRun = false;
64
65      RegisterEvents();
66      if (Algorithm != null) RegisterAlgorithmEvents();
67    }
68
69    public string Filename { get; set; }
70
71    #region persistence and cloning
72    [StorableConstructor]
73    private CrossValidation(bool deserializing)
74      : base(deserializing) {
75    }
76    [StorableHook(HookType.AfterDeserialization)]
77    private void AfterDeserialization() {
78      // BackwardsCompatibility3.3
79      #region Backwards compatible code, remove with 3.4
80      if (shuffleSamples == null) shuffleSamples = new BoolValue(false);
81      #endregion
82
83      RegisterEvents();
84      if (Algorithm != null) RegisterAlgorithmEvents();
85    }
86
87    private CrossValidation(CrossValidation original, Cloner cloner)
88      : base(original, cloner) {
89      executionState = original.executionState;
90      storeAlgorithmInEachRun = original.storeAlgorithmInEachRun;
91      runs = cloner.Clone(original.runs);
92      runsCounter = original.runsCounter;
93      algorithm = cloner.Clone(original.algorithm);
94      clonedAlgorithms = cloner.Clone(original.clonedAlgorithms);
95      results = cloner.Clone(original.results);
96
97      folds = cloner.Clone(original.folds);
98      numberOfWorkers = cloner.Clone(original.numberOfWorkers);
99      samplesStart = cloner.Clone(original.samplesStart);
100      samplesEnd = cloner.Clone(original.samplesEnd);
101      shuffleSamples = cloner.Clone(original.shuffleSamples);
102      seed = original.seed;
103
104      RegisterEvents();
105      if (Algorithm != null) RegisterAlgorithmEvents();
106    }
107    public override IDeepCloneable Clone(Cloner cloner) {
108      return new CrossValidation(this, cloner);
109    }
110
111    #endregion
112
113    #region properties
114    [Storable]
115    private IAlgorithm algorithm;
116    public IAlgorithm Algorithm {
117      get { return algorithm; }
118      set {
119        if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
120          throw new InvalidOperationException("Changing the algorithm is only allowed if the CrossValidation is stopped or prepared.");
121        if (algorithm != value) {
122          if (value != null && value.Problem != null && !(value.Problem is IDataAnalysisProblem))
123            throw new ArgumentException("Only algorithms with a DataAnalysisProblem could be used for the cross validation.");
124          if (algorithm != null) DeregisterAlgorithmEvents();
125          algorithm = value;
126          Parameters.Clear();
127
128          if (algorithm != null) {
129            algorithm.StoreAlgorithmInEachRun = false;
130            RegisterAlgorithmEvents();
131            algorithm.Prepare(true);
132            Parameters.AddRange(algorithm.Parameters);
133          }
134          OnAlgorithmChanged();
135          Prepare();
136        }
137      }
138    }
139
140
141    [Storable]
142    private IDataAnalysisProblem problem;
143    public IDataAnalysisProblem Problem {
144      get {
145        if (algorithm == null)
146          return null;
147        return (IDataAnalysisProblem)algorithm.Problem;
148      }
149      set {
150        if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
151          throw new InvalidOperationException("Changing the problem is only allowed if the CrossValidation is stopped or prepared.");
152        if (algorithm == null) throw new ArgumentNullException("Could not set a problem before an algorithm was set.");
153        algorithm.Problem = value;
154        problem = value;
155      }
156    }
157
158    IProblem IAlgorithm.Problem {
159      get { return Problem; }
160      set {
161        if (value != null && !ProblemType.IsInstanceOfType(value))
162          throw new ArgumentException("Only DataAnalysisProblems could be used for the cross validation.");
163        Problem = (IDataAnalysisProblem)value;
164      }
165    }
166    public Type ProblemType {
167      get { return typeof(IDataAnalysisProblem); }
168    }
169
170    [Storable]
171    private ItemCollection<IAlgorithm> clonedAlgorithms;
172
173    public IEnumerable<IOptimizer> NestedOptimizers {
174      get {
175        if (Algorithm == null) yield break;
176        yield return Algorithm;
177      }
178    }
179
180    [Storable]
181    private ResultCollection results;
182    public ResultCollection Results {
183      get { return results; }
184    }
185    [Storable]
186    private BoolValue shuffleSamples;
187    public BoolValue ShuffleSamples {
188      get { return shuffleSamples; }
189    }
190    [Storable]
191    private IntValue folds;
192    public IntValue Folds {
193      get { return folds; }
194    }
195    [Storable]
196    private IntValue samplesStart;
197    public IntValue SamplesStart {
198      get { return samplesStart; }
199    }
200    [Storable]
201    private IntValue samplesEnd;
202    public IntValue SamplesEnd {
203      get { return samplesEnd; }
204    }
205    [Storable]
206    private IntValue numberOfWorkers;
207    public IntValue NumberOfWorkers {
208      get { return numberOfWorkers; }
209    }
210
211    [Storable]
212    private bool storeAlgorithmInEachRun;
213    public bool StoreAlgorithmInEachRun {
214      get { return storeAlgorithmInEachRun; }
215      set {
216        if (storeAlgorithmInEachRun != value) {
217          storeAlgorithmInEachRun = value;
218          OnStoreAlgorithmInEachRunChanged();
219        }
220      }
221    }
222
223    [Storable]
224    private int runsCounter;
225    [Storable]
226    private RunCollection runs;
227    public RunCollection Runs {
228      get { return runs; }
229    }
230
231    [Storable]
232    private ExecutionState executionState;
233    public ExecutionState ExecutionState {
234      get { return executionState; }
235      private set {
236        if (executionState != value) {
237          executionState = value;
238          OnExecutionStateChanged();
239          OnItemImageChanged();
240        }
241      }
242    }
243    public static new Image StaticItemImage {
244      get { return HeuristicLab.Common.Resources.VSImageLibrary.Event; }
245    }
246    public override Image ItemImage {
247      get {
248        if (ExecutionState == ExecutionState.Prepared) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePrepared;
249        else if (ExecutionState == ExecutionState.Started) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStarted;
250        else if (ExecutionState == ExecutionState.Paused) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePaused;
251        else if (ExecutionState == ExecutionState.Stopped) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStopped;
252        else return base.ItemImage;
253      }
254    }
255
256    public TimeSpan ExecutionTime {
257      get {
258        if (ExecutionState != ExecutionState.Prepared)
259          return TimeSpan.FromMilliseconds(clonedAlgorithms.Select(x => x.ExecutionTime.TotalMilliseconds).Sum());
260        return TimeSpan.Zero;
261      }
262    }
263    #endregion
264
265    protected override void OnNameChanged() {
266      base.OnNameChanged();
267      Runs.OptimizerName = Name;
268    }
269
270    public void Prepare() {
271      if (ExecutionState == ExecutionState.Started)
272        throw new InvalidOperationException(string.Format("Prepare not allowed in execution state \"{0}\".", ExecutionState));
273      results.Clear();
274      clonedAlgorithms.Clear();
275      if (Algorithm != null) {
276        Algorithm.Prepare();
277        if (Algorithm.ExecutionState == ExecutionState.Prepared) OnPrepared();
278      }
279    }
280    public void Prepare(bool clearRuns) {
281      if (clearRuns) runs.Clear();
282      Prepare();
283    }
284
285    public void Start() {
286      if ((ExecutionState != ExecutionState.Prepared) && (ExecutionState != ExecutionState.Paused))
287        throw new InvalidOperationException(string.Format("Start not allowed in execution state \"{0}\".", ExecutionState));
288
289      seed = new FastRandom().NextInt();
290
291      if (Algorithm != null) {
292        //create cloned algorithms
293        if (clonedAlgorithms.Count == 0) {
294          int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value;
295          IDataset shuffledDataset = null;
296          for (int i = 0; i < Folds.Value; i++) {
297            var cloner = new Cloner();
298            if (ShuffleSamples.Value) {
299              var random = new FastRandom(seed);
300              var dataAnalysisProblem = (IDataAnalysisProblem)algorithm.Problem;
301              var dataset = (Dataset)dataAnalysisProblem.ProblemData.Dataset;
302              shuffledDataset = shuffledDataset ?? dataset.Shuffle(random);
303              cloner.RegisterClonedObject(dataset, shuffledDataset);
304            }
305            IAlgorithm clonedAlgorithm = cloner.Clone(Algorithm);
306            clonedAlgorithm.Name = algorithm.Name + " Fold " + i;
307            IDataAnalysisProblem problem = clonedAlgorithm.Problem as IDataAnalysisProblem;
308            ISymbolicDataAnalysisProblem symbolicProblem = problem as ISymbolicDataAnalysisProblem;
309
310            int testStart = (i * testSamplesCount) + SamplesStart.Value;
311            int testEnd = (i + 1) == Folds.Value ? SamplesEnd.Value : (i + 1) * testSamplesCount + SamplesStart.Value;
312
313            problem.ProblemData.TrainingPartition.Start = SamplesStart.Value;
314            problem.ProblemData.TrainingPartition.End = SamplesEnd.Value;
315            problem.ProblemData.TestPartition.Start = testStart;
316            problem.ProblemData.TestPartition.End = testEnd;
317            DataAnalysisProblemData problemData = problem.ProblemData as DataAnalysisProblemData;
318            if (problemData != null) {
319              problemData.TrainingPartitionParameter.Hidden = false;
320              problemData.TestPartitionParameter.Hidden = false;
321            }
322
323            if (symbolicProblem != null) {
324              symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
325              symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
326            }
327            clonedAlgorithm.Prepare();
328            clonedAlgorithms.Add(clonedAlgorithm);
329          }
330        }
331
332        //start prepared or paused cloned algorithms
333        int startedAlgorithms = 0;
334        foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
335          if (startedAlgorithms < NumberOfWorkers.Value) {
336            if (clonedAlgorithm.ExecutionState == ExecutionState.Prepared ||
337                clonedAlgorithm.ExecutionState == ExecutionState.Paused) {
338
339              // start and wait until the alg is started
340              using (var signal = new ManualResetEvent(false)) {
341                EventHandler signalSetter = (sender, args) => { signal.Set(); };
342                clonedAlgorithm.Started += signalSetter;
343                clonedAlgorithm.Start();
344                signal.WaitOne();
345                clonedAlgorithm.Started -= signalSetter;
346
347                startedAlgorithms++;
348              }
349            }
350          }
351        }
352        OnStarted();
353      }
354    }
355
356    private bool pausePending;
357    public void Pause() {
358      if (ExecutionState != ExecutionState.Started)
359        throw new InvalidOperationException(string.Format("Pause not allowed in execution state \"{0}\".", ExecutionState));
360      if (!pausePending) {
361        pausePending = true;
362        PauseAllClonedAlgorithms();
363      }
364    }
365    private void PauseAllClonedAlgorithms() {
366      foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
367        if (clonedAlgorithm.ExecutionState == ExecutionState.Started)
368          clonedAlgorithm.Pause();
369      }
370    }
371
372    private bool stopPending;
373    public void Stop() {
374      if ((ExecutionState != ExecutionState.Started) && (ExecutionState != ExecutionState.Paused))
375        throw new InvalidOperationException(string.Format("Stop not allowed in execution state \"{0}\".",
376                                                          ExecutionState));
377      if (!stopPending) {
378        stopPending = true;
379        StopAllClonedAlgorithms();
380      }
381    }
382    private void StopAllClonedAlgorithms() {
383      foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
384        if (clonedAlgorithm.ExecutionState == ExecutionState.Started ||
385            clonedAlgorithm.ExecutionState == ExecutionState.Paused)
386          clonedAlgorithm.Stop();
387      }
388    }
389
390    #region collect parameters and results
391    public override void CollectParameterValues(IDictionary<string, IItem> values) {
392      values.Add("Algorithm Name", new StringValue(Name));
393      values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName()));
394      values.Add("Folds", new IntValue(Folds.Value));
395
396      if (algorithm != null) {
397        values.Add("CrossValidation Algorithm Name", new StringValue(Algorithm.Name));
398        values.Add("CrossValidation Algorithm Type", new StringValue(Algorithm.GetType().GetPrettyName()));
399        base.CollectParameterValues(values);
400      }
401      if (Problem != null) {
402        values.Add("Problem Name", new StringValue(Problem.Name));
403        values.Add("Problem Type", new StringValue(Problem.GetType().GetPrettyName()));
404        Problem.CollectParameterValues(values);
405      }
406    }
407
408    public void CollectResultValues(IDictionary<string, IItem> results) {
409      var clonedResults = (ResultCollection)this.results.Clone();
410      foreach (var result in clonedResults) {
411        results.Add(result.Name, result.Value);
412      }
413    }
414
415    private void AggregateResultValues(IDictionary<string, IItem> results) {
416      IEnumerable<IRun> runs = clonedAlgorithms.Select(alg => alg.Runs.FirstOrDefault()).Where(run => run != null);
417      IEnumerable<KeyValuePair<string, IItem>> resultCollections = runs.Where(x => x != null).SelectMany(x => x.Results).ToList();
418
419      foreach (IResult result in ExtractAndAggregateResults<IntValue>(resultCollections))
420        results.Add(result.Name, result.Value);
421      foreach (IResult result in ExtractAndAggregateResults<DoubleValue>(resultCollections))
422        results.Add(result.Name, result.Value);
423      foreach (IResult result in ExtractAndAggregateResults<PercentValue>(resultCollections))
424        results.Add(result.Name, result.Value);
425      foreach (IResult result in ExtractAndAggregateRegressionSolutions(resultCollections)) {
426        results.Add(result.Name, result.Value);
427      }
428      foreach (IResult result in ExtractAndAggregateClassificationSolutions(resultCollections)) {
429        results.Add(result.Name, result.Value);
430      }
431      results.Add("Execution Time", new TimeSpanValue(this.ExecutionTime));
432      results.Add("CrossValidation Folds", new RunCollection(runs));
433    }
434
435    private IEnumerable<IResult> ExtractAndAggregateRegressionSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
436      Dictionary<string, List<IRegressionSolution>> resultSolutions = new Dictionary<string, List<IRegressionSolution>>();
437      foreach (var result in resultCollections) {
438        var regressionSolution = result.Value as IRegressionSolution;
439        if (regressionSolution != null) {
440          if (resultSolutions.ContainsKey(result.Key)) {
441            resultSolutions[result.Key].Add(regressionSolution);
442          } else {
443            resultSolutions.Add(result.Key, new List<IRegressionSolution>() { regressionSolution });
444          }
445        }
446      }
447      List<IResult> aggregatedResults = new List<IResult>();
448      foreach (KeyValuePair<string, List<IRegressionSolution>> solutions in resultSolutions) {
449        // clone manually to correctly clone references between cloned root objects
450        Cloner cloner = new Cloner();
451        if (ShuffleSamples.Value) {
452          var dataset = (Dataset)Problem.ProblemData.Dataset;
453          var random = new FastRandom(seed);
454          var shuffledDataset = dataset.Shuffle(random);
455          cloner.RegisterClonedObject(dataset, shuffledDataset);
456        }
457        var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData);
458        // set partitions of problem data clone correctly
459        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
460        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
461        // clone models
462        var ensembleSolution = new RegressionEnsembleSolution(problemDataClone);
463        ensembleSolution.AddRegressionSolutions(solutions.Value);
464
465        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
466      }
467      List<IResult> flattenedResults = new List<IResult>();
468      CollectResultsRecursively("", aggregatedResults, flattenedResults);
469      return flattenedResults;
470    }
471
472    private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
473      Dictionary<string, List<IClassificationSolution>> resultSolutions = new Dictionary<string, List<IClassificationSolution>>();
474      foreach (var result in resultCollections) {
475        var classificationSolution = result.Value as IClassificationSolution;
476        if (classificationSolution != null) {
477          if (resultSolutions.ContainsKey(result.Key)) {
478            resultSolutions[result.Key].Add(classificationSolution);
479          } else {
480            resultSolutions.Add(result.Key, new List<IClassificationSolution>() { classificationSolution });
481          }
482        }
483      }
484      var aggregatedResults = new List<IResult>();
485      foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) {
486        // at least one algorithm (GBT with logistic regression loss) produces a classification solution even though the original problem is a regression problem.
487        var targetVariable = solutions.Value.First().ProblemData.TargetVariable;
488        var dataset = (Dataset)Problem.ProblemData.Dataset;
489        if (ShuffleSamples.Value) {
490          var random = new FastRandom(seed);
491          dataset = dataset.Shuffle(random);
492        }
493        var problemDataClone = new ClassificationProblemData(dataset, Problem.ProblemData.AllowedInputVariables, targetVariable);
494        // set partitions of problem data clone correctly
495        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
496        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
497        // clone models
498        var ensembleSolution = new ClassificationEnsembleSolution(problemDataClone);
499        ensembleSolution.AddClassificationSolutions(solutions.Value);
500
501        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
502      }
503      List<IResult> flattenedResults = new List<IResult>();
504      CollectResultsRecursively("", aggregatedResults, flattenedResults);
505      return flattenedResults;
506    }
507
508    private void CollectResultsRecursively(string path, IEnumerable<IResult> results, IList<IResult> flattenedResults) {
509      foreach (IResult result in results) {
510        flattenedResults.Add(new Result(path + result.Name, result.Value));
511        ResultCollection childCollection = result.Value as ResultCollection;
512        if (childCollection != null) {
513          CollectResultsRecursively(path + result.Name + ".", childCollection, flattenedResults);
514        }
515      }
516    }
517
518    private static IEnumerable<IResult> ExtractAndAggregateResults<T>(IEnumerable<KeyValuePair<string, IItem>> results)
519  where T : class, IItem, new() {
520      Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();
521      foreach (var resultValue in results.Where(r => r.Value.GetType() == typeof(T))) {
522        if (!resultValues.ContainsKey(resultValue.Key))
523          resultValues[resultValue.Key] = new List<double>();
524        resultValues[resultValue.Key].Add(ConvertToDouble(resultValue.Value));
525      }
526
527      DoubleValue doubleValue;
528      if (typeof(T) == typeof(PercentValue))
529        doubleValue = new PercentValue();
530      else if (typeof(T) == typeof(DoubleValue))
531        doubleValue = new DoubleValue();
532      else if (typeof(T) == typeof(IntValue))
533        doubleValue = new DoubleValue();
534      else
535        throw new NotSupportedException();
536
537      List<IResult> aggregatedResults = new List<IResult>();
538      foreach (KeyValuePair<string, List<double>> resultValue in resultValues) {
539        doubleValue.Value = resultValue.Value.Average();
540        aggregatedResults.Add(new Result(resultValue.Key + " (average)", (IItem)doubleValue.Clone()));
541        doubleValue.Value = resultValue.Value.StandardDeviation();
542        aggregatedResults.Add(new Result(resultValue.Key + " (std.dev.)", (IItem)doubleValue.Clone()));
543      }
544      return aggregatedResults;
545    }
546
547    private static double ConvertToDouble(IItem item) {
548      if (item is DoubleValue) return ((DoubleValue)item).Value;
549      else if (item is IntValue) return ((IntValue)item).Value;
550      else throw new NotSupportedException("Could not convert any item type to double");
551    }
552    #endregion
553
554    #region events
555    private void RegisterEvents() {
556      Folds.ValueChanged += new EventHandler(Folds_ValueChanged);
557      RegisterClonedAlgorithmsEvents();
558    }
559    private void Folds_ValueChanged(object sender, EventArgs e) {
560      if (ExecutionState != ExecutionState.Prepared)
561        throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared.");
562    }
563
564
565    #region template algorithms events
566    public event EventHandler AlgorithmChanged;
567    private void OnAlgorithmChanged() {
568      EventHandler handler = AlgorithmChanged;
569      if (handler != null) handler(this, EventArgs.Empty);
570      OnProblemChanged();
571      if (Problem == null) ExecutionState = ExecutionState.Stopped;
572    }
573    private void RegisterAlgorithmEvents() {
574      algorithm.ProblemChanged += new EventHandler(Algorithm_ProblemChanged);
575      algorithm.ExecutionStateChanged += new EventHandler(Algorithm_ExecutionStateChanged);
576      if (Problem != null) {
577        Problem.Reset += new EventHandler(Problem_Reset);
578      }
579    }
580    private void DeregisterAlgorithmEvents() {
581      algorithm.ProblemChanged -= new EventHandler(Algorithm_ProblemChanged);
582      algorithm.ExecutionStateChanged -= new EventHandler(Algorithm_ExecutionStateChanged);
583      if (Problem != null) {
584        Problem.Reset -= new EventHandler(Problem_Reset);
585      }
586    }
587    private void Algorithm_ProblemChanged(object sender, EventArgs e) {
588      if (algorithm.Problem != null && !(algorithm.Problem is IDataAnalysisProblem)) {
589        algorithm.Problem = problem;
590        throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");
591      }
592      if (problem != null) problem.Reset -= new EventHandler(Problem_Reset);
593      problem = (IDataAnalysisProblem)algorithm.Problem;
594      if (problem != null) problem.Reset += new EventHandler(Problem_Reset);
595      OnProblemChanged();
596    }
597    public event EventHandler ProblemChanged;
598    private void OnProblemChanged() {
599      EventHandler handler = ProblemChanged;
600      if (handler != null) handler(this, EventArgs.Empty);
601      ConfigureProblem();
602    }
603    private void Problem_Reset(object sender, EventArgs e) {
604      ConfigureProblem();
605    }
606    private void ConfigureProblem() {
607      SamplesStart.Value = 0;
608      if (Problem != null) {
609        SamplesEnd.Value = Problem.ProblemData.Dataset.Rows;
610
611        DataAnalysisProblemData problemData = Problem.ProblemData as DataAnalysisProblemData;
612        if (problemData != null) {
613          problemData.TrainingPartitionParameter.Hidden = true;
614          problemData.TestPartitionParameter.Hidden = true;
615        }
616        ISymbolicDataAnalysisProblem symbolicProblem = Problem as ISymbolicDataAnalysisProblem;
617        if (symbolicProblem != null) {
618          symbolicProblem.FitnessCalculationPartitionParameter.Hidden = true;
619          symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
620          symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
621          symbolicProblem.ValidationPartitionParameter.Hidden = true;
622          symbolicProblem.ValidationPartition.Start = 0;
623          symbolicProblem.ValidationPartition.End = 0;
624        }
625      } else
626        SamplesEnd.Value = 0;
627    }
628
629    private void Algorithm_ExecutionStateChanged(object sender, EventArgs e) {
630      switch (Algorithm.ExecutionState) {
631        case ExecutionState.Prepared:
632          OnPrepared();
633          break;
634        case ExecutionState.Started: throw new InvalidOperationException("Algorithm template can not be started.");
635        case ExecutionState.Paused: throw new InvalidOperationException("Algorithm template can not be paused.");
636        case ExecutionState.Stopped:
637          OnStopped();
638          break;
639      }
640    }
641    #endregion
642
643    #region clonedAlgorithms events
644    private void RegisterClonedAlgorithmsEvents() {
645      clonedAlgorithms.ItemsAdded += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
646      clonedAlgorithms.ItemsRemoved += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
647      clonedAlgorithms.CollectionReset += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
648      foreach (IAlgorithm algorithm in clonedAlgorithms)
649        RegisterClonedAlgorithmEvents(algorithm);
650    }
651    private void DeregisterClonedAlgorithmsEvents() {
652      clonedAlgorithms.ItemsAdded -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
653      clonedAlgorithms.ItemsRemoved -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
654      clonedAlgorithms.CollectionReset -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
655      foreach (IAlgorithm algorithm in clonedAlgorithms)
656        DeregisterClonedAlgorithmEvents(algorithm);
657    }
658    private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
659      foreach (IAlgorithm algorithm in e.Items)
660        RegisterClonedAlgorithmEvents(algorithm);
661    }
662    private void ClonedAlgorithms_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
663      foreach (IAlgorithm algorithm in e.Items)
664        DeregisterClonedAlgorithmEvents(algorithm);
665    }
666    private void ClonedAlgorithms_CollectionReset(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
667      foreach (IAlgorithm algorithm in e.OldItems)
668        DeregisterClonedAlgorithmEvents(algorithm);
669      foreach (IAlgorithm algorithm in e.Items)
670        RegisterClonedAlgorithmEvents(algorithm);
671    }
672    private void RegisterClonedAlgorithmEvents(IAlgorithm algorithm) {
673      algorithm.ExceptionOccurred += new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
674      algorithm.ExecutionTimeChanged += new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
675      algorithm.Started += new EventHandler(ClonedAlgorithm_Started);
676      algorithm.Paused += new EventHandler(ClonedAlgorithm_Paused);
677      algorithm.Stopped += new EventHandler(ClonedAlgorithm_Stopped);
678    }
679    private void DeregisterClonedAlgorithmEvents(IAlgorithm algorithm) {
680      algorithm.ExceptionOccurred -= new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
681      algorithm.ExecutionTimeChanged -= new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
682      algorithm.Started -= new EventHandler(ClonedAlgorithm_Started);
683      algorithm.Paused -= new EventHandler(ClonedAlgorithm_Paused);
684      algorithm.Stopped -= new EventHandler(ClonedAlgorithm_Stopped);
685    }
686    private void ClonedAlgorithm_ExceptionOccurred(object sender, EventArgs<Exception> e) {
687      OnExceptionOccurred(e.Value);
688    }
689    private void ClonedAlgorithm_ExecutionTimeChanged(object sender, EventArgs e) {
690      OnExecutionTimeChanged();
691    }
692
693    private readonly object locker = new object();
694    private readonly object resultLocker = new object();
695    private void ClonedAlgorithm_Started(object sender, EventArgs e) {
696      IAlgorithm algorithm = sender as IAlgorithm;
697      lock (resultLocker) {
698        if (algorithm != null && !results.ContainsKey(algorithm.Name))
699          results.Add(new Result(algorithm.Name, "Contains results for the specific fold.", algorithm.Results));
700      }
701    }
702
703    private void ClonedAlgorithm_Paused(object sender, EventArgs e) {
704      lock (locker) {
705        if (pausePending && clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Started))
706          OnPaused();
707      }
708    }
709
710    private void ClonedAlgorithm_Stopped(object sender, EventArgs e) {
711      lock (locker) {
712        if (!stopPending && ExecutionState == ExecutionState.Started) {
713          IAlgorithm preparedAlgorithm = clonedAlgorithms.FirstOrDefault(alg => alg.ExecutionState == ExecutionState.Prepared ||
714                                                                                alg.ExecutionState == ExecutionState.Paused);
715          if (preparedAlgorithm != null) preparedAlgorithm.Start();
716        }
717        if (ExecutionState != ExecutionState.Stopped) {
718          if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped))
719            OnStopped();
720          else if (stopPending &&
721                   clonedAlgorithms.All(
722                     alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped))
723            OnStopped();
724        }
725      }
726    }
727    #endregion
728    #endregion
729
730    #region event firing
731    public event EventHandler ExecutionStateChanged;
732    private void OnExecutionStateChanged() {
733      EventHandler handler = ExecutionStateChanged;
734      if (handler != null) handler(this, EventArgs.Empty);
735    }
736    public event EventHandler ExecutionTimeChanged;
737    private void OnExecutionTimeChanged() {
738      EventHandler handler = ExecutionTimeChanged;
739      if (handler != null) handler(this, EventArgs.Empty);
740    }
741    public event EventHandler Prepared;
742    private void OnPrepared() {
743      ExecutionState = ExecutionState.Prepared;
744      EventHandler handler = Prepared;
745      if (handler != null) handler(this, EventArgs.Empty);
746      OnExecutionTimeChanged();
747    }
748    public event EventHandler Started;
749    private void OnStarted() {
750      ExecutionState = ExecutionState.Started;
751      EventHandler handler = Started;
752      if (handler != null) handler(this, EventArgs.Empty);
753    }
754    public event EventHandler Paused;
755    private void OnPaused() {
756      pausePending = false;
757      ExecutionState = ExecutionState.Paused;
758      EventHandler handler = Paused;
759      if (handler != null) handler(this, EventArgs.Empty);
760    }
761    public event EventHandler Stopped;
762    private void OnStopped() {
763      stopPending = false;
764      Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>();
765      AggregateResultValues(collectedResults);
766      results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray());
767      clonedAlgorithms.Clear();
768      runsCounter++;
769      runs.Add(new Run(string.Format("{0} Run {1}", Name, runsCounter), this));
770      ExecutionState = ExecutionState.Stopped;
771      EventHandler handler = Stopped;
772      if (handler != null) handler(this, EventArgs.Empty);
773    }
774    public event EventHandler<EventArgs<Exception>> ExceptionOccurred;
775    private void OnExceptionOccurred(Exception exception) {
776      EventHandler<EventArgs<Exception>> handler = ExceptionOccurred;
777      if (handler != null) handler(this, new EventArgs<Exception>(exception));
778    }
779    public event EventHandler StoreAlgorithmInEachRunChanged;
780    private void OnStoreAlgorithmInEachRunChanged() {
781      EventHandler handler = StoreAlgorithmInEachRunChanged;
782      if (handler != null) handler(this, EventArgs.Empty);
783    }
784    #endregion
785  }
786}
Note: See TracBrowser for help on using the repository browser.