Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Classification/HeuristicLab.Algorithms.DataAnalysis/3.3/CrossValidation.cs @ 4536

Last change on this file since 4536 was 4536, checked in by mkommend, 13 years ago

Added provisional version of the cross validation (ticket #1199).

File size: 14.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Drawing;
24using System.Linq;
25using HeuristicLab.Collections;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Provisional cross validation wrapper around a data analysis algorithm.
  /// On <see cref="Start"/> the wrapped algorithm is cloned once per fold and up to
  /// <see cref="NumberOfWorkers"/> of the clones are started.
  /// </summary>
  /// <remarks>
  /// The implementation is incomplete (see the TODO below): the handlers reacting to
  /// finished runs, paused clones and stopped clones still throw
  /// <see cref="NotImplementedException"/>, and <see cref="Pause"/> / <see cref="Stop"/>
  /// only change the execution state of the wrapper itself.
  /// </remarks>
  [Item("Cross Validation", "Cross Validation wrapper for data analysis algorithms.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class CrossValidation : NamedItem, IOptimizer {
    //TODO cloning, persistence, creation of clonedAlgs, execution logic, ...

    /// <summary>
    /// Creates a stopped cross validation without an algorithm, with one fold,
    /// one worker and an empty sample range.
    /// </summary>
    public CrossValidation()
      : base() {
      name = ItemName;
      description = ItemDescription;

      executionState = ExecutionState.Stopped;
      executionTime = TimeSpan.Zero;
      // Goes through the property so that the run collection events get wired up.
      Runs = new RunCollection();

      algorithm = null;
      clonedAlgorithms = new ItemCollection<IAlgorithm>();
      RegisterClonedAlgorithmEvents();
      readOnlyClonedAlgorithms = null;

      folds = new IntValue(1);
      samplesStart = new IntValue(0);
      samplesEnd = new IntValue(0);
      numberOfWorkers = new IntValue(1);
    }

    #region properties
    private IAlgorithm algorithm;
    /// <summary>
    /// The algorithm that is cloned for every fold. May be set to <c>null</c>.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The new algorithm has a problem that is not an <see cref="IDataAnalysisProblem"/>.
    /// </exception>
    public IAlgorithm Algorithm {
      get { return algorithm; }
      set {
        if (algorithm == value) return;
        if (value != null && value.Problem != null && !(value.Problem is IDataAnalysisProblem))
          throw new ArgumentException("Only algorithms with a DataAnalysisProblem could be used for the cross validation.");

        if (algorithm != null) DeregisterAlgorithmEvents();
        algorithm = value;
        // Prepare and subscribe only when there is an algorithm; assigning null must not throw.
        if (algorithm != null) {
          algorithm.Prepare(true);
          RegisterAlgorithmEvents();
        }
        // Keep the fallback problem in sync with the new algorithm so that
        // Algorithm_ProblemChanged never restores a problem of the previous algorithm.
        cachedProblem = algorithm != null ? algorithm.Problem as IDataAnalysisProblem : null;
        OnAlgorithmChanged();
      }
    }

    // Last valid problem seen; restored by Algorithm_ProblemChanged when an invalid problem is assigned.
    private IDataAnalysisProblem cachedProblem;
    /// <summary>
    /// The data analysis problem of the wrapped algorithm, or <c>null</c> if no algorithm is set.
    /// </summary>
    /// <exception cref="ArgumentNullException">The setter is used while no algorithm is set.</exception>
    public IDataAnalysisProblem Problem {
      get {
        if (algorithm == null)
          return null;
        return (IDataAnalysisProblem)algorithm.Problem;
      }
      set {
        // Two-argument overload: the single-string constructor would treat the text as a parameter name.
        if (algorithm == null)
          throw new ArgumentNullException("Algorithm", "Could not set a problem before an algorithm was set.");
        algorithm.Problem = value;
        cachedProblem = value;
      }
    }

    private ItemCollection<IAlgorithm> clonedAlgorithms;
    private ReadOnlyItemCollection<IAlgorithm> readOnlyClonedAlgorithms;
    /// <summary>
    /// Read-only view of the per-fold clones; the wrapper is created on first access.
    /// </summary>
    public IItemCollection<IAlgorithm> ClonedAlgorithms {
      get {
        if (readOnlyClonedAlgorithms == null) readOnlyClonedAlgorithms = clonedAlgorithms.AsReadOnly();
        return readOnlyClonedAlgorithms;
      }
    }

    private IntValue folds;
    /// <summary>Number of clones created by <see cref="Start"/>.</summary>
    public IntValue Folds {
      get { return folds; }
    }
    private IntValue samplesStart;
    /// <summary>Start of the sample range; reset to 0 whenever the algorithm's problem changes.</summary>
    public IntValue SamplesStart {
      get { return samplesStart; }
    }
    private IntValue samplesEnd;
    /// <summary>
    /// End of the sample range; set to the dataset's row count (or 0 without a problem)
    /// whenever the algorithm's problem changes.
    /// </summary>
    public IntValue SamplesEnd {
      get { return samplesEnd; }
    }
    private IntValue numberOfWorkers;
    /// <summary>Upper bound on the number of clones started by <see cref="Start"/>.</summary>
    public IntValue NumberOfWorkers {
      get { return numberOfWorkers; }
    }

    private RunCollection runs;
    /// <summary>
    /// Runs collected by this cross validation. Changes to the collection adjust
    /// the stored execution time.
    /// </summary>
    /// <exception cref="ArgumentNullException">The collection is set to <c>null</c>.</exception>
    public RunCollection Runs {
      get { return runs; }
      private set {
        if (value == null) throw new ArgumentNullException();
        if (runs == value) return;
        if (runs != null) DeregisterRunsEvents();
        runs = value;
        RegisterRunsEvents();
      }
    }

    private ExecutionState executionState;
    /// <summary>
    /// Current execution state; a change raises <see cref="ExecutionStateChanged"/>
    /// and the item image changed notification.
    /// </summary>
    public ExecutionState ExecutionState {
      get { return executionState; }
      private set {
        if (executionState == value) return;
        executionState = value;
        OnExecutionStateChanged();
        OnItemImageChanged();
      }
    }

    /// <summary>Image matching the current execution state.</summary>
    public override Image ItemImage {
      get {
        switch (ExecutionState) {
          case ExecutionState.Prepared: return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutablePrepared;
          case ExecutionState.Started: return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutableStarted;
          case ExecutionState.Paused: return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutablePaused;
          case ExecutionState.Stopped: return HeuristicLab.Common.Resources.VS2008ImageLibrary.ExecutableStopped;
          default: return HeuristicLab.Common.Resources.VS2008ImageLibrary.Event;
        }
      }
    }

    // Stored base time; the running time of the clones is added on top in the getter.
    private TimeSpan executionTime;
    /// <summary>
    /// The stored execution time; while the state is not <c>Stopped</c> the current
    /// execution times of all clones are added on top.
    /// </summary>
    public TimeSpan ExecutionTime {
      get {
        if (ExecutionState == ExecutionState.Stopped) return executionTime;
        double cloneMilliseconds = clonedAlgorithms.Sum(x => x.ExecutionTime.TotalMilliseconds);
        return executionTime + TimeSpan.FromMilliseconds(cloneMilliseconds);
      }
      private set {
        executionTime = value;
        OnExecutionTimeChanged();
      }
    }
    #endregion

    /// <summary>Discards all clones and switches to the <c>Prepared</c> state.</summary>
    public void Prepare() {
      clonedAlgorithms.Clear();
      OnPrepared();
    }

    /// <summary>
    /// Same as <see cref="Prepare()"/>, optionally clearing the collected runs first.
    /// </summary>
    /// <param name="clearRuns">If <c>true</c>, <see cref="Runs"/> is cleared before preparing.</param>
    public void Prepare(bool clearRuns) {
      if (clearRuns) runs.Clear();
      Prepare();
    }

    /// <summary>
    /// Clones the algorithm once per fold, starts the first
    /// <see cref="NumberOfWorkers"/> clones and switches to the <c>Started</c> state.
    /// </summary>
    /// <exception cref="InvalidOperationException">No algorithm has been set.</exception>
    public void Start() {
      if (algorithm == null)
        throw new InvalidOperationException("Could not start the cross validation before an algorithm was set.");

      // NOTE(review): clones are appended on every call; calling Start() again without
      // Prepare() adds another set of clones. Left unchanged, the execution logic is still a TODO.
      for (int i = 0; i < Folds.Value; i++) {
        IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone();
        clonedAlgorithm.Name = algorithm.Name + " Fold " + i;
        clonedAlgorithms.Add(clonedAlgorithm);
      }

      // Snapshot the workers first so the collection is enumerated exactly once.
      foreach (IAlgorithm worker in clonedAlgorithms.Take(NumberOfWorkers.Value).ToList())
        worker.Start();
      OnStarted();
    }

    /// <summary>Switches to the <c>Paused</c> state; the clones are not touched.</summary>
    public void Pause() {
      OnPaused();
    }

    /// <summary>Switches to the <c>Stopped</c> state; the clones are not touched.</summary>
    public void Stop() {
      OnStopped();
    }

    #region events
    /// <summary>Raised after <see cref="Algorithm"/> was replaced.</summary>
    public event EventHandler AlgorithmChanged;
    // A new algorithm implies a (potentially) new problem, so ProblemChanged is raised as well.
    private void OnAlgorithmChanged() {
      EventHandler handler = AlgorithmChanged;
      if (handler != null) handler(this, EventArgs.Empty);
      OnProblemChanged();
    }
    private void RegisterAlgorithmEvents() {
      algorithm.ProblemChanged += new EventHandler(Algorithm_ProblemChanged);
    }
    private void DeregisterAlgorithmEvents() {
      algorithm.ProblemChanged -= new EventHandler(Algorithm_ProblemChanged);
    }
    // Rejects problems that are not data analysis problems (restoring the cached one
    // before throwing) and resets the sample range to the whole dataset.
    private void Algorithm_ProblemChanged(object sender, EventArgs e) {
      if (algorithm.Problem != null && !(algorithm.Problem is IDataAnalysisProblem)) {
        algorithm.Problem = cachedProblem;
        throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");
      }
      cachedProblem = (IDataAnalysisProblem)algorithm.Problem;
      SamplesStart.Value = 0;
      if (algorithm.Problem != null)
        SamplesEnd.Value = Problem.DataAnalysisProblemData.Dataset.Rows;
      else
        SamplesEnd.Value = 0;
      OnProblemChanged();
    }
    /// <summary>Raised when the algorithm or the algorithm's problem changed.</summary>
    public event EventHandler ProblemChanged;
    private void OnProblemChanged() {
      EventHandler handler = ProblemChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }

    // Subscribes to the clone collection so that per-clone handlers follow its content.
    private void RegisterClonedAlgorithmEvents() {
      clonedAlgorithms.ItemsAdded += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
      clonedAlgorithms.ItemsRemoved += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
      clonedAlgorithms.CollectionReset += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
    }
    private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
      foreach (IAlgorithm added in e.Items)
        RegisterClonedAlgorithmEvents(added);
    }
    private void ClonedAlgorithms_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
      foreach (IAlgorithm removed in e.Items)
        DeregisterClonedAlgorithmEvents(removed);
    }
    private void ClonedAlgorithms_CollectionReset(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
      // Deregister first: an item contained in both lists must end up subscribed.
      foreach (IAlgorithm removed in e.OldItems)
        DeregisterClonedAlgorithmEvents(removed);
      foreach (IAlgorithm added in e.Items)
        RegisterClonedAlgorithmEvents(added);
    }
    private void RegisterClonedAlgorithmEvents(IAlgorithm algorithm) {
      algorithm.ExceptionOccurred += new EventHandler<EventArgs<Exception>>(Algorithm_ExceptionOccurred);
      algorithm.ExecutionTimeChanged += new EventHandler(Algorithm_ExecutionTimeChanged);
      algorithm.Runs.ItemsAdded += new CollectionItemsChangedEventHandler<IRun>(Algorithm_Runs_Added);
      algorithm.Paused += new EventHandler(Algorithm_Paused);
      algorithm.Stopped += new EventHandler(Algorithm_Stopped);
    }
    // Exact mirror of RegisterClonedAlgorithmEvents(IAlgorithm): every subscription is removed again.
    private void DeregisterClonedAlgorithmEvents(IAlgorithm algorithm) {
      algorithm.ExceptionOccurred -= new EventHandler<EventArgs<Exception>>(Algorithm_ExceptionOccurred);
      algorithm.ExecutionTimeChanged -= new EventHandler(Algorithm_ExecutionTimeChanged);
      algorithm.Runs.ItemsAdded -= new CollectionItemsChangedEventHandler<IRun>(Algorithm_Runs_Added);
      algorithm.Paused -= new EventHandler(Algorithm_Paused);
      algorithm.Stopped -= new EventHandler(Algorithm_Stopped);
    }
    // Forwards exceptions of a clone as exceptions of the cross validation.
    private void Algorithm_ExceptionOccurred(object sender, EventArgs<Exception> e) {
      OnExceptionOccurred(e.Value);
    }
    // The reported execution time includes the clones' time, so their changes are forwarded.
    private void Algorithm_ExecutionTimeChanged(object sender, EventArgs e) {
      OnExecutionTimeChanged();
    }
    // Placeholder: not implemented yet.
    private void Algorithm_Runs_Added(object sender, CollectionItemsChangedEventArgs<IRun> e) {
      throw new NotImplementedException("TODO added finished run to actual results if the cross validation is running");
    }
    // Placeholder: not implemented yet.
    private void Algorithm_Paused(object sender, EventArgs e) {
      throw new NotImplementedException("TODO pause the cross validation if an algorithm no algorithm is running.");
    }
    // Placeholder: not implemented yet.
    private void Algorithm_Stopped(object sender, EventArgs e) {
      throw new NotImplementedException("TODO stop the cross validation if all algorithms are stopped. and start remaining prepared algs");
    }

    /// <summary>Raised after <see cref="ExecutionState"/> changed.</summary>
    public event EventHandler ExecutionStateChanged;
    private void OnExecutionStateChanged() {
      EventHandler handler = ExecutionStateChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    /// <summary>Raised when the stored execution time or the execution time of a clone changed.</summary>
    public event EventHandler ExecutionTimeChanged;
    private void OnExecutionTimeChanged() {
      EventHandler handler = ExecutionTimeChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    /// <summary>Raised after the state was switched to <c>Prepared</c>.</summary>
    public event EventHandler Prepared;
    private void OnPrepared() {
      ExecutionState = ExecutionState.Prepared;
      EventHandler handler = Prepared;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    /// <summary>Raised after the state was switched to <c>Started</c>.</summary>
    public event EventHandler Started;
    private void OnStarted() {
      ExecutionState = ExecutionState.Started;
      EventHandler handler = Started;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    /// <summary>Raised after the state was switched to <c>Paused</c>.</summary>
    public event EventHandler Paused;
    private void OnPaused() {
      ExecutionState = ExecutionState.Paused;
      EventHandler handler = Paused;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    /// <summary>Raised after the state was switched to <c>Stopped</c>.</summary>
    public event EventHandler Stopped;
    private void OnStopped() {
      ExecutionState = ExecutionState.Stopped;
      EventHandler handler = Stopped;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    /// <summary>Raised when one of the clones reported an exception.</summary>
    public event EventHandler<EventArgs<Exception>> ExceptionOccurred;
    private void OnExceptionOccurred(Exception exception) {
      EventHandler<EventArgs<Exception>> handler = ExceptionOccurred;
      if (handler != null) handler(this, new EventArgs<Exception>(exception));
    }

    private void RegisterRunsEvents() {
      runs.CollectionReset += new CollectionItemsChangedEventHandler<IRun>(Runs_CollectionReset);
      runs.ItemsAdded += new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsAdded);
      runs.ItemsRemoved += new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsRemoved);
    }
    private void DeregisterRunsEvents() {
      runs.CollectionReset -= new CollectionItemsChangedEventHandler<IRun>(Runs_CollectionReset);
      runs.ItemsAdded -= new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsAdded);
      runs.ItemsRemoved -= new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsRemoved);
    }
    // Old runs no longer count towards the execution time, the new ones do.
    private void Runs_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRun> e) {
      foreach (IRun run in e.OldItems)
        AdjustExecutionTime(run, false);
      foreach (IRun run in e.Items)
        AdjustExecutionTime(run, true);
    }
    private void Runs_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRun> e) {
      foreach (IRun run in e.Items)
        AdjustExecutionTime(run, true);
    }
    private void Runs_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRun> e) {
      foreach (IRun run in e.Items)
        AdjustExecutionTime(run, false);
    }
    // Adds (or subtracts) the "Execution Time" result of a run to the stored execution time.
    // Runs without such a result, or with a result that is not a TimeSpanValue, are ignored.
    private void AdjustExecutionTime(IRun run, bool add) {
      IItem item;
      run.Results.TryGetValue("Execution Time", out item);
      TimeSpanValue runTime = item as TimeSpanValue;
      if (runTime == null) return;
      // Read the backing field, not the property: while not stopped the property getter also
      // contains the clones' running time, which must not be folded into the stored value.
      ExecutionTime = add ? executionTime + runTime.Value : executionTime - runTime.Value;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.