Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Classification/HeuristicLab.Algorithms.DataAnalysis/3.3/CrossValidation.cs @ 4542

Last change on this file since 4542 was 4542, checked in by mkommend, 14 years ago

Added persistence, cloning to CrossValidation (ticket #1199)

File size: 21.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Drawing;
24using System.Linq;
25using HeuristicLab.Collections;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
34  [Item("Cross Validation", "Cross Validation wrapper for data analysis algorithms.")]
35  [Creatable("Data Analysis")]
36  [StorableClass]
37  public sealed class CrossValidation : NamedItem, IOptimizer {
38    public CrossValidation()
39      : base() {
40      name = ItemName;
41      description = ItemDescription;
42
43      executionState = ExecutionState.Stopped;
44      executionTime = TimeSpan.Zero;
45      runs = new RunCollection();
46
47      algorithm = null;
48      clonedAlgorithms = new ItemCollection<IAlgorithm>();
49      readOnlyClonedAlgorithms = null;
50
51      folds = new IntValue(1);
52      numberOfWorkers = new IntValue(1);
53      samplesStart = new IntValue(0);
54      samplesEnd = new IntValue(0);
55
56
57      RegisterEvents();
58    }
59
60    #region persistence and cloning
61    [StorableConstructor]
62    private CrossValidation(bool deserializing)
63      : base(deserializing) {
64    }
65    [StorableHook(HookType.AfterDeserialization)]
66    private void AfterDeserialization() {
67      RegisterEvents();
68    }
69
70    public override IDeepCloneable Clone(Cloner cloner) {
71      if (ExecutionState == ExecutionState.Started) throw new InvalidOperationException(string.Format("Clone not allowed in execution state \"{0}\".", ExecutionState));
72      CrossValidation clone = new CrossValidation(false);
73      cloner.RegisterClonedObject(this, clone);
74      clone.name = name;
75      clone.description = description;
76      clone.executionState = executionState;
77      clone.executionTime = executionTime;
78      clone.runs = (RunCollection)cloner.Clone(runs);
79      clone.algorithm = (IAlgorithm)cloner.Clone(algorithm);
80      clone.clonedAlgorithms = (ItemCollection<IAlgorithm>)cloner.Clone(clonedAlgorithms);
81      clone.folds = (IntValue)cloner.Clone(folds);
82      clone.numberOfWorkers = (IntValue)cloner.Clone(numberOfWorkers);
83      clone.samplesStart = (IntValue)cloner.Clone(samplesStart);
84      clone.samplesEnd = (IntValue)cloner.Clone(samplesEnd);
85      clone.RegisterEvents();
86      return clone;
87    }
88
89    #endregion
90
91    #region properties
92    [Storable]
93    private IAlgorithm algorithm;
94    public IAlgorithm Algorithm {
95      get { return algorithm; }
96      set {
97        if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
98          throw new InvalidOperationException("Changing the algorithm is only allowed if the CrossValidation is stopped or prepared.");
99        if (algorithm != value) {
100          if (value != null && value.Problem != null && !(value.Problem is IDataAnalysisProblem))
101            throw new ArgumentException("Only algorithms with a DataAnalysisProblem could be used for the cross validation.");
102          if (algorithm != null) DeregisterAlgorithmEvents();
103          algorithm = value;
104
105          if (algorithm != null) {
106            RegisterAlgorithmEvents();
107            algorithm.Prepare(true);
108          }
109          OnAlgorithmChanged();
110          Prepare();
111        }
112      }
113    }
114
115    [Storable]
116    private IDataAnalysisProblem cachedProblem;
117    public IDataAnalysisProblem Problem {
118      get {
119        if (algorithm == null)
120          return null;
121        return (IDataAnalysisProblem)algorithm.Problem;
122      }
123      set {
124        if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
125          throw new InvalidOperationException("Changing the problem is only allowed if the CrossValidation is stopped or prepared.");
126        if (algorithm == null) throw new ArgumentNullException("Could not set a problem before an algorithm was set.");
127        algorithm.Problem = value;
128        cachedProblem = value;
129      }
130    }
131
132    [Storable]
133    private ItemCollection<IAlgorithm> clonedAlgorithms;
134    private ReadOnlyItemCollection<IAlgorithm> readOnlyClonedAlgorithms;
135    public IItemCollection<IAlgorithm> ClonedAlgorithms {
136      get {
137        if (readOnlyClonedAlgorithms == null) readOnlyClonedAlgorithms = clonedAlgorithms.AsReadOnly();
138        return readOnlyClonedAlgorithms;
139      }
140    }
141
142    [Storable]
143    private IntValue folds;
144    public IntValue Folds {
145      get { return folds; }
146    }
147    [Storable]
148    private IntValue samplesStart;
149    public IntValue SamplesStart {
150      get { return samplesStart; }
151    }
152    [Storable]
153    private IntValue samplesEnd;
154    public IntValue SamplesEnd {
155      get { return samplesEnd; }
156    }
157    [Storable]
158    private IntValue numberOfWorkers;
159    public IntValue NumberOfWorkers {
160      get { return numberOfWorkers; }
161    }
162
163    [Storable]
164    private RunCollection runs;
165    public RunCollection Runs {
166      get { return runs; }
167    }
168    [Storable]
169    private ExecutionState executionState;
170    public ExecutionState ExecutionState {
171      get { return executionState; }
172      private set {
173        if (executionState != value) {
174          executionState = value;
175          OnExecutionStateChanged();
176          OnItemImageChanged();
177        }
178      }
179    }
180    public override Image ItemImage {
181      get {
182        if (ExecutionState == ExecutionState.Prepared) return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutablePrepared;
183        else if (ExecutionState == ExecutionState.Started) return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutableStarted;
184        else if (ExecutionState == ExecutionState.Paused) return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutablePaused;
185        else if (ExecutionState == ExecutionState.Stopped) return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutableStopped;
186        else return HeuristicLab.Common.Resources.VS2008ImageLibrary.Event;
187      }
188    }
189
190    [Storable]
191    private TimeSpan executionTime;
192    public TimeSpan ExecutionTime {
193      get {
194        if (ExecutionState != ExecutionState.Stopped)
195          return executionTime + TimeSpan.FromMilliseconds(clonedAlgorithms.Select(x => x.ExecutionTime.TotalMilliseconds).Sum());
196        else
197          return executionTime;
198      }
199      private set {
200        executionTime = value;
201        OnExecutionTimeChanged();
202      }
203    }
204    #endregion
205
206    public void Prepare() {
207      if ((ExecutionState != ExecutionState.Prepared) && (ExecutionState != ExecutionState.Paused) && (ExecutionState != ExecutionState.Stopped))
208        throw new InvalidOperationException(string.Format("Prepare not allowed in execution state \"{0}\".", ExecutionState));
209      clonedAlgorithms.Clear();
210      if (Algorithm != null) {
211        Algorithm.Prepare();
212        if (Algorithm.ExecutionState == ExecutionState.Prepared) OnPrepared();
213      }
214    }
215    public void Prepare(bool clearRuns) {
216      if (clearRuns) runs.Clear();
217      Prepare();
218    }
219
220    private bool startPending;
221    public void Start() {
222      if ((ExecutionState != ExecutionState.Prepared) && (ExecutionState != ExecutionState.Paused))
223        throw new InvalidOperationException(string.Format("Start not allowed in execution state \"{0}\".", ExecutionState));
224
225      if (Algorithm != null && !startPending) {
226        startPending = true;
227        //create cloned algorithms
228        if (clonedAlgorithms.Count == 0) {
229          for (int i = 0; i < Folds.Value; i++) {
230            IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone();
231            clonedAlgorithm.Name = algorithm.Name + " Fold " + i;
232            clonedAlgorithms.Add(clonedAlgorithm);
233          }
234        }
235
236        //start prepared or paused cloned algorithms
237        int startedAlgorithms = 0;
238        foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
239          if (startedAlgorithms < NumberOfWorkers.Value) {
240            if (clonedAlgorithm.ExecutionState == ExecutionState.Prepared ||
241                clonedAlgorithm.ExecutionState == ExecutionState.Paused) {
242              clonedAlgorithm.Start();
243              startedAlgorithms++;
244            }
245          }
246        }
247        OnStarted();
248      }
249    }
250
251    private bool pausePending;
252    public void Pause() {
253      if (ExecutionState != ExecutionState.Started)
254        throw new InvalidOperationException(string.Format("Pause not allowed in execution state \"{0}\".", ExecutionState));
255      if (!pausePending) {
256        pausePending = true;
257        if (!startPending) PauseAllClonedAlgorithms();
258      }
259    }
260    private void PauseAllClonedAlgorithms() {
261      foreach (IAlgorithm clonedAlgorithm in ClonedAlgorithms) {
262        if (clonedAlgorithm.ExecutionState == ExecutionState.Started)
263          clonedAlgorithm.Pause();
264      }
265    }
266
267    private bool stopPending;
268    public void Stop() {
269      if ((ExecutionState != ExecutionState.Started) && (ExecutionState != ExecutionState.Paused))
270        throw new InvalidOperationException(string.Format("Stop not allowed in execution state \"{0}\".",
271                                                          ExecutionState));
272      if (!stopPending) {
273        stopPending = true;
274        if (!startPending) StopAllClonedAlgorithms();
275      }
276    }
277    private void StopAllClonedAlgorithms() {
278      foreach (IAlgorithm clonedAlgorithm in ClonedAlgorithms) {
279        if (clonedAlgorithm.ExecutionState == ExecutionState.Started ||
280            clonedAlgorithm.ExecutionState == ExecutionState.Paused)
281          clonedAlgorithm.Stop();
282      }
283    }
284
285    #region events
286    private void RegisterEvents() {
287      Folds.ValueChanged += new EventHandler(Folds_ValueChanged);
288      NumberOfWorkers.ValueChanged += new EventHandler(NumberOfWorkers_ValueChanged);
289      RegisterClonedAlgorithmsEvents();
290      RegisterRunsEvents();
291    }
292    private void Folds_ValueChanged(object sender, EventArgs e) {
293      if (ExecutionState != ExecutionState.Prepared)
294        throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared.");
295    }
296    private void NumberOfWorkers_ValueChanged(object sender, EventArgs e) {
297      if (ExecutionState == ExecutionState.Started) {
298        int workers = numberOfWorkers.Value;
299        int runningWorkers = clonedAlgorithms.Count(alg => alg.ExecutionState == ExecutionState.Started);
300
301        foreach (IAlgorithm algorithm in clonedAlgorithms) {
302          if (algorithm.ExecutionState == ExecutionState.Prepared ||
303              algorithm.ExecutionState == ExecutionState.Paused) {
304            if (runningWorkers < workers) {
305              algorithm.Start();
306              runningWorkers++;
307            }
308          } else if (algorithm.ExecutionState == ExecutionState.Started) {
309            if (runningWorkers > workers) {
310              algorithm.Pause();
311              runningWorkers--;
312            }
313          }
314        }
315      }
316    }
317
318    #region template algorithms events
319    public event EventHandler AlgorithmChanged;
320    private void OnAlgorithmChanged() {
321      EventHandler handler = AlgorithmChanged;
322      if (handler != null) handler(this, EventArgs.Empty);
323      OnProblemChanged();
324      if (Problem == null) ExecutionState = ExecutionState.Stopped;
325    }
326    private void RegisterAlgorithmEvents() {
327      algorithm.ProblemChanged += new EventHandler(Algorithm_ProblemChanged);
328      algorithm.ExecutionStateChanged += new EventHandler(Algorithm_ExecutionStateChanged);
329    }
330    private void DeregisterAlgorithmEvents() {
331      algorithm.ProblemChanged -= new EventHandler(Algorithm_ProblemChanged);
332      algorithm.ExecutionStateChanged -= new EventHandler(Algorithm_ExecutionStateChanged);
333    }
334    private void Algorithm_ProblemChanged(object sender, EventArgs e) {
335      if (algorithm.Problem != null && !(algorithm.Problem is IDataAnalysisProblem)) {
336        algorithm.Problem = cachedProblem;
337        throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");
338      }
339      cachedProblem = (IDataAnalysisProblem)algorithm.Problem;
340      SamplesStart.Value = 0;
341      if (algorithm.Problem != null)
342        SamplesEnd.Value = Problem.DataAnalysisProblemData.Dataset.Rows;
343      else
344        SamplesEnd.Value = 0;
345      OnProblemChanged();
346    }
347    public event EventHandler ProblemChanged;
348    private void OnProblemChanged() {
349      EventHandler handler = ProblemChanged;
350      if (handler != null) handler(this, EventArgs.Empty);
351    }
352
353    private void Algorithm_ExecutionStateChanged(object sender, EventArgs e) {
354      switch (Algorithm.ExecutionState) {
355        case ExecutionState.Prepared: OnPrepared();
356          break;
357        case ExecutionState.Started: throw new InvalidOperationException("Algorithm template can not be started.");
358        case ExecutionState.Paused: throw new InvalidOperationException("Algorithm template can not be paused.");
359        case ExecutionState.Stopped: OnStopped();
360          break;
361      }
362    }
363    #endregion
364
365    #region clonedAlgorithms events
366    private void RegisterClonedAlgorithmsEvents() {
367      clonedAlgorithms.ItemsAdded += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
368      clonedAlgorithms.ItemsRemoved += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
369      clonedAlgorithms.CollectionReset += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
370      foreach (IAlgorithm algorithm in clonedAlgorithms)
371        RegisterClonedAlgorithmEvents(algorithm);
372    }
373    private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
374      foreach (IAlgorithm algorithm in e.Items)
375        RegisterClonedAlgorithmEvents(algorithm);
376    }
377    private void ClonedAlgorithms_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
378      foreach (IAlgorithm algorithm in e.Items)
379        DeregisterClonedAlgorithmEvents(algorithm);
380    }
381    private void ClonedAlgorithms_CollectionReset(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
382      foreach (IAlgorithm algorithm in e.OldItems)
383        DeregisterClonedAlgorithmEvents(algorithm);
384      foreach (IAlgorithm algorithm in e.Items)
385        RegisterClonedAlgorithmEvents(algorithm);
386    }
387    private void RegisterClonedAlgorithmEvents(IAlgorithm algorithm) {
388      algorithm.ExceptionOccurred += new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
389      algorithm.ExecutionTimeChanged += new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
390      algorithm.Started += new EventHandler(ClonedAlgorithm_Started);
391      algorithm.Paused += new EventHandler(ClonedAlgorithm_Paused);
392      algorithm.Stopped += new EventHandler(ClonedAlgorithm_Stopped);
393    }
394    private void DeregisterClonedAlgorithmEvents(IAlgorithm algorithm) {
395      algorithm.ExceptionOccurred -= new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
396      algorithm.ExecutionTimeChanged -= new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
397      algorithm.Started -= new EventHandler(ClonedAlgorithm_Started);
398      algorithm.Paused -= new EventHandler(ClonedAlgorithm_Paused);
399      algorithm.Stopped -= new EventHandler(ClonedAlgorithm_Stopped);
400    }
401    private void ClonedAlgorithm_ExceptionOccurred(object sender, EventArgs<Exception> e) {
402      OnExceptionOccurred(e.Value);
403    }
404    private void ClonedAlgorithm_ExecutionTimeChanged(object sender, EventArgs e) {
405      OnExecutionTimeChanged();
406    }
407
408    private readonly object locker = new object();
409    private void ClonedAlgorithm_Started(object sender, EventArgs e) {
410      lock (locker) {
411        if (startPending) {
412          int startedAlgorithms = clonedAlgorithms.Count(alg => alg.ExecutionState == ExecutionState.Started);
413          if (startedAlgorithms == NumberOfWorkers.Value ||
414             clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Prepared))
415            startPending = false;
416
417          if (pausePending) PauseAllClonedAlgorithms();
418          if (stopPending) StopAllClonedAlgorithms();
419        }
420      }
421    }
422
423    private void ClonedAlgorithm_Paused(object sender, EventArgs e) {
424      lock (locker) {
425        if (pausePending && clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Started))
426          OnPaused();
427      }
428    }
429
430    private void ClonedAlgorithm_Stopped(object sender, EventArgs e) {
431      lock (locker) {
432        if (!stopPending && ExecutionState == ExecutionState.Started) {
433          IAlgorithm preparedAlgorithm = clonedAlgorithms.Where(alg => alg.ExecutionState == ExecutionState.Prepared ||
434                                                                       alg.ExecutionState == ExecutionState.Paused).FirstOrDefault();
435          if (preparedAlgorithm != null) preparedAlgorithm.Start();
436        }
437        if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped))
438          OnStopped();
439        else if (stopPending && clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped))
440          OnStopped();
441      }
442    }
443    #endregion
444
445    #region run events
446    private void RegisterRunsEvents() {
447      runs.CollectionReset += new CollectionItemsChangedEventHandler<IRun>(Runs_CollectionReset);
448      runs.ItemsAdded += new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsAdded);
449      runs.ItemsRemoved += new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsRemoved);
450    }
451    private void Runs_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRun> e) {
452      foreach (IRun run in e.OldItems) {
453        IItem item;
454        run.Results.TryGetValue("Execution Time", out item);
455        TimeSpanValue executionTime = item as TimeSpanValue;
456        if (executionTime != null) ExecutionTime -= executionTime.Value;
457      }
458      foreach (IRun run in e.Items) {
459        IItem item;
460        run.Results.TryGetValue("Execution Time", out item);
461        TimeSpanValue executionTime = item as TimeSpanValue;
462        if (executionTime != null) ExecutionTime += executionTime.Value;
463      }
464    }
465    private void Runs_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRun> e) {
466      foreach (IRun run in e.Items) {
467        IItem item;
468        run.Results.TryGetValue("Execution Time", out item);
469        TimeSpanValue executionTime = item as TimeSpanValue;
470        if (executionTime != null) ExecutionTime += executionTime.Value;
471      }
472    }
473    private void Runs_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRun> e) {
474      foreach (IRun run in e.Items) {
475        IItem item;
476        run.Results.TryGetValue("Execution Time", out item);
477        TimeSpanValue executionTime = item as TimeSpanValue;
478        if (executionTime != null) ExecutionTime -= executionTime.Value;
479      }
480    }
481    #endregion
482    #endregion
483
484    #region event firing
485    public event EventHandler ExecutionStateChanged;
486    private void OnExecutionStateChanged() {
487      EventHandler handler = ExecutionStateChanged;
488      if (handler != null) handler(this, EventArgs.Empty);
489    }
490    public event EventHandler ExecutionTimeChanged;
491    private void OnExecutionTimeChanged() {
492      EventHandler handler = ExecutionTimeChanged;
493      if (handler != null) handler(this, EventArgs.Empty);
494    }
495    public event EventHandler Prepared;
496    private void OnPrepared() {
497      ExecutionState = ExecutionState.Prepared;
498      EventHandler handler = Prepared;
499      if (handler != null) handler(this, EventArgs.Empty);
500    }
501    public event EventHandler Started;
502    private void OnStarted() {
503      startPending = false;
504      ExecutionState = ExecutionState.Started;
505      EventHandler handler = Started;
506      if (handler != null) handler(this, EventArgs.Empty);
507    }
508    public event EventHandler Paused;
509    private void OnPaused() {
510      pausePending = false;
511      ExecutionState = ExecutionState.Paused;
512      EventHandler handler = Paused;
513      if (handler != null) handler(this, EventArgs.Empty);
514    }
515    public event EventHandler Stopped;
516    private void OnStopped() {
517      stopPending = false;
518      ExecutionState = ExecutionState.Stopped;
519      //TODO create run;
520      EventHandler handler = Stopped;
521      if (handler != null) handler(this, EventArgs.Empty);
522    }
523    public event EventHandler<EventArgs<Exception>> ExceptionOccurred;
524    private void OnExceptionOccurred(Exception exception) {
525      EventHandler<EventArgs<Exception>> handler = ExceptionOccurred;
526      if (handler != null) handler(this, new EventArgs<Exception>(exception));
527    }
528    #endregion
529  }
530}
Note: See TracBrowser for help on using the repository browser.