Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs @ 8738

Last change on this file since 8738 was 8738, checked in by mkommend, 12 years ago

#1673: Added a new property AlgorithmName to RunCollection and synchronized it with the name of the surrounding IOptimizer. If set, AlgorithmName is used by the RunCollectionViews as a prefix for their captions.

File size: 32.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using HeuristicLab.Collections;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Cross-validation wrapper for data analysis algorithms.
  /// </summary>
  /// <remarks>
  /// The wrapped template <see cref="Algorithm"/> is never executed itself. <see cref="Start"/> creates one clone of it per
  /// fold, gives every clone its own test partition between <see cref="SamplesStart"/> and <see cref="SamplesEnd"/>, starts
  /// at most <see cref="NumberOfWorkers"/> clones and starts the next waiting clone whenever one stops. Once all clones have
  /// stopped, the fold results are aggregated into <see cref="Results"/> and a new run is added to <see cref="Runs"/>.
  /// </remarks>
  [Item("Cross Validation", "Cross-validation wrapper for data analysis algorithms.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class CrossValidation : ParameterizedNamedItem, IAlgorithm, IStorableContent {
    /// <summary>
    /// Creates a stopped cross-validation without an algorithm, with two folds, one worker and an empty sample range.
    /// </summary>
    public CrossValidation()
      : base() {
      name = ItemName;
      description = ItemDescription;

      executionState = ExecutionState.Stopped;
      runs = new RunCollection { AlgorithmName = name };
      runsCounter = 0;

      algorithm = null;
      clonedAlgorithms = new ItemCollection<IAlgorithm>();
      results = new ResultCollection();

      folds = new IntValue(2);
      numberOfWorkers = new IntValue(1);
      samplesStart = new IntValue(0);
      samplesEnd = new IntValue(0);
      storeAlgorithmInEachRun = false;

      RegisterEvents();
      if (Algorithm != null) RegisterAlgorithmEvents();
    }

    /// <summary>File name this item was loaded from or saved to (<see cref="IStorableContent"/>).</summary>
    public string Filename { get; set; }

    #region persistence and cloning
    [StorableConstructor]
    private CrossValidation(bool deserializing)
      : base(deserializing) {
    }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // event subscriptions are not persisted and must be re-established after loading
      RegisterEvents();
      if (Algorithm != null) RegisterAlgorithmEvents();
    }

    private CrossValidation(CrossValidation original, Cloner cloner)
      : base(original, cloner) {
      executionState = original.executionState;
      storeAlgorithmInEachRun = original.storeAlgorithmInEachRun;
      runs = cloner.Clone(original.runs);
      runsCounter = original.runsCounter;
      algorithm = cloner.Clone(original.algorithm);
      // Cloned with the same cloner right after the algorithm so that, assuming the cloner maps each original to a
      // single clone, the field refers to the cloned algorithm's problem. It is needed to restore a rejected problem.
      problem = cloner.Clone(original.problem);
      clonedAlgorithms = cloner.Clone(original.clonedAlgorithms);
      results = cloner.Clone(original.results);

      folds = cloner.Clone(original.folds);
      numberOfWorkers = cloner.Clone(original.numberOfWorkers);
      samplesStart = cloner.Clone(original.samplesStart);
      samplesEnd = cloner.Clone(original.samplesEnd);
      RegisterEvents();
      if (Algorithm != null) RegisterAlgorithmEvents();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new CrossValidation(this, cloner);
    }

    #endregion

    #region properties
    [Storable]
    private IAlgorithm algorithm;
    /// <summary>
    /// The template algorithm that is cloned once per fold. Its parameters are exposed as the parameters of this item.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cross-validation is neither stopped nor prepared.</exception>
    /// <exception cref="ArgumentException">The algorithm has a problem that is not an <see cref="IDataAnalysisProblem"/>.</exception>
    public IAlgorithm Algorithm {
      get { return algorithm; }
      set {
        if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
          throw new InvalidOperationException("Changing the algorithm is only allowed if the CrossValidation is stopped or prepared.");
        if (algorithm == value) return;
        if (value != null && value.Problem != null && !(value.Problem is IDataAnalysisProblem))
          throw new ArgumentException("Only algorithms with a DataAnalysisProblem could be used for the cross validation.");

        if (algorithm != null) DeregisterAlgorithmEvents();
        algorithm = value;
        // forget the previous algorithm's problem; RegisterAlgorithmEvents picks up the new algorithm's problem
        problem = null;
        Parameters.Clear();

        if (algorithm != null) {
          algorithm.StoreAlgorithmInEachRun = false;
          RegisterAlgorithmEvents();
          algorithm.Prepare(true);
          Parameters.AddRange(algorithm.Parameters);
        }
        // OnAlgorithmChanged also raises ProblemChanged, so it is not raised a second time here
        OnAlgorithmChanged();
        Prepare();
      }
    }


    // Last accepted problem of the template algorithm. It is the instance whose Reset event is subscribed, and it is
    // restored when the template algorithm receives a problem that is not a data analysis problem.
    [Storable]
    private IDataAnalysisProblem problem;
    /// <summary>
    /// The data analysis problem of the template algorithm, or <c>null</c> if no algorithm is set.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cross-validation is neither stopped nor prepared.</exception>
    /// <exception cref="ArgumentNullException">No algorithm has been set yet.</exception>
    public IDataAnalysisProblem Problem {
      get {
        if (algorithm == null)
          return null;
        return (IDataAnalysisProblem)algorithm.Problem;
      }
      set {
        if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
          throw new InvalidOperationException("Changing the problem is only allowed if the CrossValidation is stopped or prepared.");
        if (algorithm == null) throw new ArgumentNullException("Algorithm", "Could not set a problem before an algorithm was set.");
        algorithm.Problem = value;
        // Normally Algorithm_ProblemChanged has already synchronized the tracked problem; this covers algorithms
        // that do not raise ProblemChanged for the assignment.
        if (!ReferenceEquals(problem, algorithm.Problem)) SynchronizeProblem();
      }
    }

    IProblem IAlgorithm.Problem {
      get { return Problem; }
      set {
        if (value != null && !ProblemType.IsInstanceOfType(value))
          throw new ArgumentException("Only DataAnalysisProblems could be used for the cross validation.");
        Problem = (IDataAnalysisProblem)value;
      }
    }
    /// <summary>The problem type accepted by this algorithm: <see cref="IDataAnalysisProblem"/>.</summary>
    public Type ProblemType {
      get { return typeof(IDataAnalysisProblem); }
    }

    // one clone of the template algorithm per fold; created lazily by Start and cleared by Prepare
    [Storable]
    private ItemCollection<IAlgorithm> clonedAlgorithms;

    /// <summary>The template algorithm, if one is set.</summary>
    public IEnumerable<IOptimizer> NestedOptimizers {
      get {
        if (Algorithm == null) yield break;
        yield return Algorithm;
      }
    }

    [Storable]
    private ResultCollection results;
    /// <summary>Per-fold result collections while running, plus the aggregated results after stopping.</summary>
    public ResultCollection Results {
      get { return results; }
    }

    [Storable]
    private IntValue folds;
    /// <summary>Number of folds; may only be changed while the cross-validation is prepared.</summary>
    public IntValue Folds {
      get { return folds; }
    }
    [Storable]
    private IntValue samplesStart;
    /// <summary>First sample index of the range that is split into folds.</summary>
    public IntValue SamplesStart {
      get { return samplesStart; }
    }
    [Storable]
    private IntValue samplesEnd;
    /// <summary>End index of the range that is split into folds.</summary>
    public IntValue SamplesEnd {
      get { return samplesEnd; }
    }
    [Storable]
    private IntValue numberOfWorkers;
    /// <summary>Maximum number of fold clones that are started by a single call to <see cref="Start"/>.</summary>
    public IntValue NumberOfWorkers {
      get { return numberOfWorkers; }
    }

    [Storable]
    private bool storeAlgorithmInEachRun;
    public bool StoreAlgorithmInEachRun {
      get { return storeAlgorithmInEachRun; }
      set {
        if (storeAlgorithmInEachRun != value) {
          storeAlgorithmInEachRun = value;
          OnStoreAlgorithmInEachRunChanged();
        }
      }
    }

    // number of runs created so far; used to name new runs
    [Storable]
    private int runsCounter;
    [Storable]
    private RunCollection runs;
    /// <summary>One run per completed cross-validation.</summary>
    public RunCollection Runs {
      get { return runs; }
    }

    [Storable]
    private ExecutionState executionState;
    public ExecutionState ExecutionState {
      get { return executionState; }
      private set {
        if (executionState != value) {
          executionState = value;
          OnExecutionStateChanged();
          OnItemImageChanged();
        }
      }
    }
    public static new Image StaticItemImage {
      get { return HeuristicLab.Common.Resources.VSImageLibrary.Event; }
    }
    public override Image ItemImage {
      get {
        if (ExecutionState == ExecutionState.Prepared) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePrepared;
        else if (ExecutionState == ExecutionState.Started) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStarted;
        else if (ExecutionState == ExecutionState.Paused) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePaused;
        else if (ExecutionState == ExecutionState.Stopped) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStopped;
        else return base.ItemImage;
      }
    }

    /// <summary>
    /// Sum of the execution times of all fold clones; zero while prepared.
    /// </summary>
    public TimeSpan ExecutionTime {
      get {
        // Summing ticks keeps the result exact; TimeSpan.FromMilliseconds rounds to whole milliseconds on .NET Framework.
        if (ExecutionState != ExecutionState.Prepared)
          return TimeSpan.FromTicks(clonedAlgorithms.Sum(x => x.ExecutionTime.Ticks));
        return TimeSpan.Zero;
      }
    }
    #endregion

    protected override void OnNameChanged() {
      base.OnNameChanged();
      // keep the run collection's caption prefix in sync with this item's name
      Runs.AlgorithmName = Name;
    }

    /// <summary>
    /// Clears the results and the fold clones and prepares the template algorithm. Raises <see cref="Prepared"/>
    /// if the template algorithm is prepared afterwards.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cross-validation is started.</exception>
    public void Prepare() {
      if (ExecutionState == ExecutionState.Started)
        throw new InvalidOperationException(string.Format("Prepare not allowed in execution state \"{0}\".", ExecutionState));
      results.Clear();
      clonedAlgorithms.Clear();
      if (Algorithm != null) {
        Algorithm.Prepare();
        if (Algorithm.ExecutionState == ExecutionState.Prepared) OnPrepared();
      }
    }
    /// <summary>Like <see cref="Prepare()"/>, but additionally clears <see cref="Runs"/> if <paramref name="clearRuns"/> is true.</summary>
    public void Prepare(bool clearRuns) {
      if (clearRuns) runs.Clear();
      Prepare();
    }

    // true from the beginning of Start until enough fold clones report that they have started;
    // pause and stop requests issued in the meantime are deferred (see ClonedAlgorithm_Started)
    private bool startPending;
    /// <summary>
    /// Creates the fold clones if necessary and starts up to <see cref="NumberOfWorkers"/> prepared or paused clones.
    /// Does nothing if no algorithm is set or a start is already pending.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The cross-validation is neither prepared nor paused, or fold clones have to be created and <see cref="Folds"/> is not positive.
    /// </exception>
    public void Start() {
      if ((ExecutionState != ExecutionState.Prepared) && (ExecutionState != ExecutionState.Paused))
        throw new InvalidOperationException(string.Format("Start not allowed in execution state \"{0}\".", ExecutionState));
      if (Algorithm == null || startPending) return;
      // validate before startPending is set, otherwise a failure would leave it set and silently ignore every later Start
      if (clonedAlgorithms.Count == 0 && Folds.Value <= 0)
        throw new InvalidOperationException("The number of folds must be at least one.");

      startPending = true;
      if (clonedAlgorithms.Count == 0) CreateClonedAlgorithms();
      StartClonedAlgorithms();
      OnStarted();
    }

    // Adds one prepared clone of the template algorithm per fold to clonedAlgorithms.
    private void CreateClonedAlgorithms() {
      int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value;
      for (int i = 0; i < Folds.Value; i++) {
        int testStart = (i * testSamplesCount) + SamplesStart.Value;
        // the last fold also takes the remainder of the integer division
        int testEnd = (i + 1) == Folds.Value ? SamplesEnd.Value : (i + 1) * testSamplesCount + SamplesStart.Value;
        clonedAlgorithms.Add(CreateFoldAlgorithm(i, testStart, testEnd));
      }
    }

    // Clones the template algorithm for one fold and configures the partitions of its problem.
    private IAlgorithm CreateFoldAlgorithm(int fold, int testStart, int testEnd) {
      IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone();
      clonedAlgorithm.Name = algorithm.Name + " Fold " + fold;
      IDataAnalysisProblem foldProblem = (IDataAnalysisProblem)clonedAlgorithm.Problem;

      // The training partition spans the whole sample range and the test partition is this fold's slice.
      // NOTE(review): presumably the problem data excludes test rows from training; confirm in DataAnalysisProblemData.
      foldProblem.ProblemData.TrainingPartition.Start = SamplesStart.Value;
      foldProblem.ProblemData.TrainingPartition.End = SamplesEnd.Value;
      foldProblem.ProblemData.TestPartition.Start = testStart;
      foldProblem.ProblemData.TestPartition.End = testEnd;
      // the template hides these parameters (see OnProblemChanged); the per-fold clones show them again
      DataAnalysisProblemData problemData = foldProblem.ProblemData as DataAnalysisProblemData;
      if (problemData != null) {
        problemData.TrainingPartitionParameter.Hidden = false;
        problemData.TestPartitionParameter.Hidden = false;
      }

      ISymbolicDataAnalysisProblem symbolicProblem = foldProblem as ISymbolicDataAnalysisProblem;
      if (symbolicProblem != null) {
        symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
        symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
      }
      clonedAlgorithm.Prepare();
      return clonedAlgorithm;
    }

    // Starts prepared or paused fold clones until NumberOfWorkers of them have been started by this call.
    private void StartClonedAlgorithms() {
      int startedAlgorithms = 0;
      foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
        if (startedAlgorithms >= NumberOfWorkers.Value) break;
        if (clonedAlgorithm.ExecutionState == ExecutionState.Prepared ||
            clonedAlgorithm.ExecutionState == ExecutionState.Paused) {
          clonedAlgorithm.Start();
          startedAlgorithms++;
        }
      }
    }

    private bool pausePending;
    /// <summary>
    /// Requests all started fold clones to pause. While a start is still pending the request is deferred.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cross-validation is not started.</exception>
    public void Pause() {
      if (ExecutionState != ExecutionState.Started)
        throw new InvalidOperationException(string.Format("Pause not allowed in execution state \"{0}\".", ExecutionState));
      if (!pausePending) {
        pausePending = true;
        if (!startPending) PauseAllClonedAlgorithms();
      }
    }
    private void PauseAllClonedAlgorithms() {
      foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
        if (clonedAlgorithm.ExecutionState == ExecutionState.Started)
          clonedAlgorithm.Pause();
      }
    }

    private bool stopPending;
    /// <summary>
    /// Requests all started or paused fold clones to stop. While a start is still pending the request is deferred.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cross-validation is neither started nor paused.</exception>
    public void Stop() {
      if ((ExecutionState != ExecutionState.Started) && (ExecutionState != ExecutionState.Paused))
        throw new InvalidOperationException(string.Format("Stop not allowed in execution state \"{0}\".",
                                                          ExecutionState));
      if (!stopPending) {
        stopPending = true;
        if (!startPending) StopAllClonedAlgorithms();
      }
    }
    private void StopAllClonedAlgorithms() {
      foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
        if (clonedAlgorithm.ExecutionState == ExecutionState.Started ||
            clonedAlgorithm.ExecutionState == ExecutionState.Paused)
          clonedAlgorithm.Stop();
      }
    }

    #region collect parameters and results
    /// <summary>
    /// Adds the cross-validation settings, the template algorithm's name, type and parameters, and the problem's name,
    /// type and parameters to <paramref name="values"/>.
    /// </summary>
    public override void CollectParameterValues(IDictionary<string, IItem> values) {
      values.Add("Algorithm Name", new StringValue(Name));
      values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName()));
      values.Add("Folds", new IntValue(Folds.Value));

      if (algorithm != null) {
        values.Add("CrossValidation Algorithm Name", new StringValue(Algorithm.Name));
        values.Add("CrossValidation Algorithm Type", new StringValue(Algorithm.GetType().GetPrettyName()));
        // this item's parameters are the template algorithm's parameters (see the Algorithm setter)
        base.CollectParameterValues(values);
      }
      if (Problem != null) {
        values.Add("Problem Name", new StringValue(Problem.Name));
        values.Add("Problem Type", new StringValue(Problem.GetType().GetPrettyName()));
        Problem.CollectParameterValues(values);
      }
    }

    /// <summary>Adds clones of all current results to <paramref name="results"/>.</summary>
    public void CollectResultValues(IDictionary<string, IItem> results) {
      var clonedResults = (ResultCollection)this.results.Clone();
      foreach (var result in clonedResults) {
        results.Add(result.Name, result.Value);
      }
    }

    // Aggregates the results of the first run of every fold clone: numeric results become average/std.dev. pairs,
    // solutions become ensembles; the total execution time and the fold runs themselves are added as well.
    private void AggregateResultValues(IDictionary<string, IItem> results) {
      // materialized once because it is enumerated twice below
      List<IRun> foldRuns = clonedAlgorithms
        .Select(alg => alg.Runs.FirstOrDefault())
        .Where(run => run != null)
        .ToList();
      List<KeyValuePair<string, IItem>> resultCollections = foldRuns.SelectMany(run => run.Results).ToList();

      IEnumerable<IResult> aggregatedResults = ExtractAndAggregateResults<IntValue>(resultCollections)
        .Concat(ExtractAndAggregateResults<DoubleValue>(resultCollections))
        .Concat(ExtractAndAggregateResults<PercentValue>(resultCollections))
        .Concat(ExtractAndAggregateRegressionSolutions(resultCollections))
        .Concat(ExtractAndAggregateClassificationSolutions(resultCollections));
      foreach (IResult result in aggregatedResults)
        results.Add(result.Name, result.Value);

      results.Add("Execution Time", new TimeSpanValue(this.ExecutionTime));
      results.Add("CrossValidation Folds", new RunCollection(foldRuns));
    }

    // Groups all result values of type T by result name, preserving the order in which names first occur.
    private static Dictionary<string, List<T>> GroupResultsByName<T>(IEnumerable<KeyValuePair<string, IItem>> resultCollections)
      where T : class {
      Dictionary<string, List<T>> grouped = new Dictionary<string, List<T>>();
      foreach (var result in resultCollections) {
        T value = result.Value as T;
        if (value == null) continue;
        List<T> values;
        if (!grouped.TryGetValue(result.Key, out values)) {
          values = new List<T>();
          grouped.Add(result.Key, values);
        }
        values.Add(value);
      }
      return grouped;
    }

    // Combines the fold regression solutions stored under the same result name into one ensemble solution each.
    private IEnumerable<IResult> ExtractAndAggregateRegressionSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
      List<IResult> aggregatedResults = new List<IResult>();
      foreach (KeyValuePair<string, List<IRegressionSolution>> solutions in GroupResultsByName<IRegressionSolution>(resultCollections)) {
        // clone manually to correctly clone references between cloned root objects
        Cloner cloner = new Cloner();
        var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData);
        // both partitions of the ensemble's problem data span the whole cross-validation sample range
        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
        var ensembleSolution = new RegressionEnsembleSolution(problemDataClone);
        ensembleSolution.AddRegressionSolutions(solutions.Value);

        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
      }
      List<IResult> flattenedResults = new List<IResult>();
      CollectResultsRecursively("", aggregatedResults, flattenedResults);
      return flattenedResults;
    }

    // Combines the fold classification solutions stored under the same result name into one ensemble solution each.
    private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
      List<IResult> aggregatedResults = new List<IResult>();
      foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in GroupResultsByName<IClassificationSolution>(resultCollections)) {
        // clone manually to correctly clone references between cloned root objects
        Cloner cloner = new Cloner();
        var problemDataClone = (IClassificationProblemData)cloner.Clone(Problem.ProblemData);
        // both partitions of the ensemble's problem data span the whole cross-validation sample range
        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
        var ensembleSolution = new ClassificationEnsembleSolution(problemDataClone);
        ensembleSolution.AddClassificationSolutions(solutions.Value);

        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
      }
      List<IResult> flattenedResults = new List<IResult>();
      CollectResultsRecursively("", aggregatedResults, flattenedResults);
      return flattenedResults;
    }

    // Adds every result and, recursively, the contents of nested result collections, using dot-separated names.
    private void CollectResultsRecursively(string path, IEnumerable<IResult> results, IList<IResult> flattenedResults) {
      foreach (IResult result in results) {
        flattenedResults.Add(new Result(path + result.Name, result.Value));
        ResultCollection childCollection = result.Value as ResultCollection;
        if (childCollection != null) {
          CollectResultsRecursively(path + result.Name + ".", childCollection, flattenedResults);
        }
      }
    }

    // For every result name whose values have exactly type T, returns "<name> (average)" and "<name> (std.dev.)".
    // IntValue and DoubleValue results are reported as DoubleValue, PercentValue results as PercentValue.
    private static IEnumerable<IResult> ExtractAndAggregateResults<T>(IEnumerable<KeyValuePair<string, IItem>> results)
  where T : class, IItem, new() {
      Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();
      // exact type match, so e.g. PercentValue results are not also aggregated as DoubleValue results
      foreach (var resultValue in results.Where(r => r.Value.GetType() == typeof(T))) {
        List<double> values;
        if (!resultValues.TryGetValue(resultValue.Key, out values)) {
          values = new List<double>();
          resultValues[resultValue.Key] = values;
        }
        values.Add(ConvertToDouble(resultValue.Value));
      }

      DoubleValue doubleValue;
      if (typeof(T) == typeof(PercentValue))
        doubleValue = new PercentValue();
      else if (typeof(T) == typeof(DoubleValue))
        doubleValue = new DoubleValue();
      else if (typeof(T) == typeof(IntValue))
        doubleValue = new DoubleValue();
      else
        throw new NotSupportedException();

      List<IResult> aggregatedResults = new List<IResult>();
      foreach (KeyValuePair<string, List<double>> resultValue in resultValues) {
        // doubleValue is reused as a template; each result receives its own clone
        doubleValue.Value = resultValue.Value.Average();
        aggregatedResults.Add(new Result(resultValue.Key + " (average)", (IItem)doubleValue.Clone()));
        doubleValue.Value = resultValue.Value.StandardDeviation();
        aggregatedResults.Add(new Result(resultValue.Key + " (std.dev.)", (IItem)doubleValue.Clone()));
      }
      return aggregatedResults;
    }

    private static double ConvertToDouble(IItem item) {
      if (item is DoubleValue) return ((DoubleValue)item).Value;
      else if (item is IntValue) return ((IntValue)item).Value;
      else throw new NotSupportedException("Could not convert any item type to double");
    }
    #endregion

    #region events
    private void RegisterEvents() {
      Folds.ValueChanged += new EventHandler(Folds_ValueChanged);
      SamplesStart.ValueChanged += new EventHandler(SamplesStart_ValueChanged);
      SamplesEnd.ValueChanged += new EventHandler(SamplesEnd_ValueChanged);
      RegisterClonedAlgorithmsEvents();
    }
    private void Folds_ValueChanged(object sender, EventArgs e) {
      if (ExecutionState != ExecutionState.Prepared)
        throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared.");
    }

    // While true, OnProblemChanged only raises ProblemChanged and skips resetting the sample range. This prevents the
    // partition updates below from resetting the samples in case they cause a problem Reset.
    private bool samplesChanged = false;
    private void SamplesStart_ValueChanged(object sender, EventArgs e) {
      samplesChanged = true;
      try {
        if (Problem != null) Problem.ProblemData.TrainingPartition.Start = SamplesStart.Value;
      } finally {
        // must not stay set if the partition update throws, otherwise sample resets would be suppressed for good
        samplesChanged = false;
      }
    }
    private void SamplesEnd_ValueChanged(object sender, EventArgs e) {
      samplesChanged = true;
      try {
        if (Problem != null) Problem.ProblemData.TrainingPartition.End = SamplesEnd.Value;
      } finally {
        samplesChanged = false;
      }
    }

    #region template algorithms events
    public event EventHandler AlgorithmChanged;
    private void OnAlgorithmChanged() {
      EventHandler handler = AlgorithmChanged;
      if (handler != null) handler(this, EventArgs.Empty);
      OnProblemChanged();
      if (Problem == null) ExecutionState = ExecutionState.Stopped;
    }
    private void RegisterAlgorithmEvents() {
      algorithm.ProblemChanged += new EventHandler(Algorithm_ProblemChanged);
      algorithm.ExecutionStateChanged += new EventHandler(Algorithm_ExecutionStateChanged);
      SynchronizeProblem();
    }
    private void DeregisterAlgorithmEvents() {
      algorithm.ProblemChanged -= new EventHandler(Algorithm_ProblemChanged);
      algorithm.ExecutionStateChanged -= new EventHandler(Algorithm_ExecutionStateChanged);
      if (problem != null) problem.Reset -= Problem_Reset;
    }
    // Moves the Reset subscription from the tracked problem to the template algorithm's current problem (which may be null).
    // Unsubscribing first ensures each problem carries at most one handler and old problems do not keep this item alive.
    private void SynchronizeProblem() {
      if (problem != null) problem.Reset -= Problem_Reset;
      problem = algorithm != null ? algorithm.Problem as IDataAnalysisProblem : null;
      if (problem != null) problem.Reset += Problem_Reset;
    }
    private void Problem_Reset(object sender, EventArgs e) {
      OnProblemChanged();
    }
    private void Algorithm_ProblemChanged(object sender, EventArgs e) {
      if (algorithm.Problem != null && !(algorithm.Problem is IDataAnalysisProblem)) {
        // restore the last accepted problem before rejecting the new one
        algorithm.Problem = problem;
        throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");
      }
      SynchronizeProblem();
      OnProblemChanged();
    }
    public event EventHandler ProblemChanged;
    // Raises ProblemChanged and, unless a sample update is in progress, resets the sample range to the whole dataset,
    // hides the partition parameters of the template problem and keeps them aligned with the sample range.
    private void OnProblemChanged() {
      EventHandler handler = ProblemChanged;
      if (handler != null) handler(this, EventArgs.Empty);
      if (samplesChanged) return;

      SamplesStart.Value = 0;
      if (Problem != null) {
        SamplesEnd.Value = Problem.ProblemData.Dataset.Rows;

        DataAnalysisProblemData problemData = Problem.ProblemData as DataAnalysisProblemData;
        if (problemData != null) {
          problemData.TrainingPartitionParameter.Hidden = true;
          problemData.TestPartitionParameter.Hidden = true;
        }
        ISymbolicDataAnalysisProblem symbolicProblem = Problem as ISymbolicDataAnalysisProblem;
        if (symbolicProblem != null) {
          symbolicProblem.FitnessCalculationPartitionParameter.Hidden = true;
          symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
          symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
          symbolicProblem.ValidationPartitionParameter.Hidden = true;
          symbolicProblem.ValidationPartition.Start = 0;
          symbolicProblem.ValidationPartition.End = 0;
        }
      } else
        SamplesEnd.Value = 0;

      // push the sample range to the training partition even if the values above did not change
      SamplesStart_ValueChanged(this, EventArgs.Empty);
      SamplesEnd_ValueChanged(this, EventArgs.Empty);
    }

    // The template algorithm is only ever prepared or stopped; only its clones run.
    private void Algorithm_ExecutionStateChanged(object sender, EventArgs e) {
      switch (Algorithm.ExecutionState) {
        case ExecutionState.Prepared:
          OnPrepared();
          break;
        case ExecutionState.Started:
          throw new InvalidOperationException("Algorithm template can not be started.");
        case ExecutionState.Paused:
          throw new InvalidOperationException("Algorithm template can not be paused.");
        case ExecutionState.Stopped:
          OnStopped();
          break;
      }
    }
    #endregion

    #region clonedAlgorithms events
    private void RegisterClonedAlgorithmsEvents() {
      clonedAlgorithms.ItemsAdded += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
      clonedAlgorithms.ItemsRemoved += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
      clonedAlgorithms.CollectionReset += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
      foreach (IAlgorithm algorithm in clonedAlgorithms)
        RegisterClonedAlgorithmEvents(algorithm);
    }
    private void DeregisterClonedAlgorithmsEvents() {
      clonedAlgorithms.ItemsAdded -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
      clonedAlgorithms.ItemsRemoved -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
      clonedAlgorithms.CollectionReset -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
      foreach (IAlgorithm algorithm in clonedAlgorithms)
        DeregisterClonedAlgorithmEvents(algorithm);
    }
    private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
      foreach (IAlgorithm algorithm in e.Items)
        RegisterClonedAlgorithmEvents(algorithm);
    }
    private void ClonedAlgorithms_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
      foreach (IAlgorithm algorithm in e.Items)
        DeregisterClonedAlgorithmEvents(algorithm);
    }
    private void ClonedAlgorithms_CollectionReset(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
      foreach (IAlgorithm algorithm in e.OldItems)
        DeregisterClonedAlgorithmEvents(algorithm);
      foreach (IAlgorithm algorithm in e.Items)
        RegisterClonedAlgorithmEvents(algorithm);
    }
    private void RegisterClonedAlgorithmEvents(IAlgorithm algorithm) {
      algorithm.ExceptionOccurred += new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
      algorithm.ExecutionTimeChanged += new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
      algorithm.Started += new EventHandler(ClonedAlgorithm_Started);
      algorithm.Paused += new EventHandler(ClonedAlgorithm_Paused);
      algorithm.Stopped += new EventHandler(ClonedAlgorithm_Stopped);
    }
    private void DeregisterClonedAlgorithmEvents(IAlgorithm algorithm) {
      algorithm.ExceptionOccurred -= new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
      algorithm.ExecutionTimeChanged -= new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
      algorithm.Started -= new EventHandler(ClonedAlgorithm_Started);
      algorithm.Paused -= new EventHandler(ClonedAlgorithm_Paused);
      algorithm.Stopped -= new EventHandler(ClonedAlgorithm_Stopped);
    }
    private void ClonedAlgorithm_ExceptionOccurred(object sender, EventArgs<Exception> e) {
      OnExceptionOccurred(e.Value);
    }
    private void ClonedAlgorithm_ExecutionTimeChanged(object sender, EventArgs e) {
      OnExecutionTimeChanged();
    }

    // serializes the handling of the Started/Paused/Stopped notifications of the fold clones
    private readonly object locker = new object();
    private void ClonedAlgorithm_Started(object sender, EventArgs e) {
      lock (locker) {
        // expose the fold's live result collection under the fold's name
        IAlgorithm algorithm = sender as IAlgorithm;
        if (algorithm != null && !results.ContainsKey(algorithm.Name))
          results.Add(new Result(algorithm.Name, "Contains results for the specific fold.", algorithm.Results));

        if (startPending) {
          int startedAlgorithms = clonedAlgorithms.Count(alg => alg.ExecutionState == ExecutionState.Started);
          if (startedAlgorithms == NumberOfWorkers.Value ||
             clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Prepared))
            startPending = false;

          // carry out pause/stop requests that were deferred while the start was pending
          if (pausePending) PauseAllClonedAlgorithms();
          if (stopPending) StopAllClonedAlgorithms();
        }
      }
    }

    private void ClonedAlgorithm_Paused(object sender, EventArgs e) {
      lock (locker) {
        if (pausePending && clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Started))
          OnPaused();
      }
    }

    private void ClonedAlgorithm_Stopped(object sender, EventArgs e) {
      lock (locker) {
        // during a normal run, a finished fold frees a worker for the next waiting fold
        if (!stopPending && ExecutionState == ExecutionState.Started) {
          IAlgorithm waitingAlgorithm = clonedAlgorithms.FirstOrDefault(alg => alg.ExecutionState == ExecutionState.Prepared ||
                                                                              alg.ExecutionState == ExecutionState.Paused);
          if (waitingAlgorithm != null) waitingAlgorithm.Start();
        }
        if (ExecutionState != ExecutionState.Stopped) {
          if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped))
            OnStopped();
          // on a requested stop, folds that were never started (still prepared) do not block completion
          else if (stopPending &&
                   clonedAlgorithms.All(
                     alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped))
            OnStopped();
        }
      }
    }
    #endregion
    #endregion

    #region event firing
    public event EventHandler ExecutionStateChanged;
    private void OnExecutionStateChanged() {
      EventHandler handler = ExecutionStateChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    public event EventHandler ExecutionTimeChanged;
    private void OnExecutionTimeChanged() {
      EventHandler handler = ExecutionTimeChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    public event EventHandler Prepared;
    private void OnPrepared() {
      ExecutionState = ExecutionState.Prepared;
      EventHandler handler = Prepared;
      if (handler != null) handler(this, EventArgs.Empty);
      OnExecutionTimeChanged();
    }
    public event EventHandler Started;
    private void OnStarted() {
      startPending = false;
      ExecutionState = ExecutionState.Started;
      EventHandler handler = Started;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    public event EventHandler Paused;
    private void OnPaused() {
      pausePending = false;
      ExecutionState = ExecutionState.Paused;
      EventHandler handler = Paused;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    public event EventHandler Stopped;
    // Aggregates the fold results, records a new run and switches to the stopped state.
    private void OnStopped() {
      stopPending = false;
      Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>();
      AggregateResultValues(collectedResults);
      results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray());
      runsCounter++;
      runs.Add(new Run(string.Format("{0} Run {1}", Name, runsCounter), this));
      ExecutionState = ExecutionState.Stopped;
      EventHandler handler = Stopped;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    public event EventHandler<EventArgs<Exception>> ExceptionOccurred;
    private void OnExceptionOccurred(Exception exception) {
      EventHandler<EventArgs<Exception>> handler = ExceptionOccurred;
      if (handler != null) handler(this, new EventArgs<Exception>(exception));
    }
    public event EventHandler StoreAlgorithmInEachRunChanged;
    private void OnStoreAlgorithmInEachRunChanged() {
      EventHandler handler = StoreAlgorithmInEachRunChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.