- Timestamp:
- 10/06/10 15:56:09 (14 years ago)
- Location:
- branches/HeuristicLab.Classification/HeuristicLab.Algorithms.DataAnalysis/3.3
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Classification/HeuristicLab.Algorithms.DataAnalysis/3.3/CrossValidation.cs
r4542 r4561 21 21 22 22 using System; 23 using System.Collections.Generic; 23 24 using System.Drawing; 24 25 using System.Linq; … … 35 36 [Creatable("Data Analysis")] 36 37 [StorableClass] 37 public sealed class CrossValidation : NamedItem, IOptimizer{38 public sealed class CrossValidation : ParameterizedNamedItem, IAlgorithm { 38 39 public CrossValidation() 39 40 : base() { … … 44 45 executionTime = TimeSpan.Zero; 45 46 runs = new RunCollection(); 47 runsCounter = 0; 46 48 47 49 algorithm = null; 48 50 clonedAlgorithms = new ItemCollection<IAlgorithm>(); 49 51 readOnlyClonedAlgorithms = null; 50 51 folds = new IntValue(1); 52 results = new ResultCollection(); 53 54 folds = new IntValue(2); 52 55 numberOfWorkers = new IntValue(1); 53 56 samplesStart = new IntValue(0); 54 57 samplesEnd = new IntValue(0); 55 58 storeAlgorithmInEachRun = false; 56 59 57 60 RegisterEvents(); … … 69 72 70 73 public override IDeepCloneable Clone(Cloner cloner) { 71 if (ExecutionState == ExecutionState.Started) throw new InvalidOperationException(string.Format("Clone not allowed in execution state \"{0}\".", ExecutionState)); 72 CrossValidation clone = new CrossValidation(false); 73 cloner.RegisterClonedObject(this, clone); 74 clone.name = name; 75 clone.description = description; 74 CrossValidation clone = (CrossValidation)base.Clone(cloner); 75 clone.DeregisterEvents(); 76 76 clone.executionState = executionState; 77 77 clone.executionTime = executionTime; 78 clone.storeAlgorithmInEachRun = storeAlgorithmInEachRun; 78 79 clone.runs = (RunCollection)cloner.Clone(runs); 80 clone.runsCounter = runsCounter; 79 81 clone.algorithm = (IAlgorithm)cloner.Clone(algorithm); 80 82 clone.clonedAlgorithms = (ItemCollection<IAlgorithm>)cloner.Clone(clonedAlgorithms); … … 102 104 if (algorithm != null) DeregisterAlgorithmEvents(); 103 105 algorithm = value; 106 Parameters.Clear(); 104 107 105 108 if (algorithm != null) { 109 algorithm.StoreAlgorithmInEachRun = StoreAlgorithmInEachRun; 106 110 RegisterAlgorithmEvents(); 107 111 algorithm.Prepare(true); 112 Parameters.AddRange(algorithm.Parameters); 108 113 } 109 114 OnAlgorithmChanged(); 115 if (algorithm != null) OnProblemChanged(); 110 116 Prepare(); 111 117 } 112 118 } 113 119 } 120 114 121 115 122 [Storable] … … 130 137 } 131 138 139 IProblem IAlgorithm.Problem { 140 get { return Problem; } 141 set { 142 if (value != null && !ProblemType.IsInstanceOfType(value)) 143 throw new ArgumentException("Only DataAnalysisProblems could be used for the cross validation."); 144 Problem = (IDataAnalysisProblem)value; 145 } 146 } 147 public Type ProblemType { 148 get { return typeof(IDataAnalysisProblem); } 149 } 150 132 151 [Storable] 133 152 private ItemCollection<IAlgorithm> clonedAlgorithms; … … 141 160 142 161 [Storable] 162 private ResultCollection results; 163 public ResultCollection Results { 164 get { return results; } 165 } 166 167 [Storable] 143 168 private IntValue folds; 144 169 public IntValue Folds { … … 162 187 163 188 [Storable] 189 private bool storeAlgorithmInEachRun; 190 public bool StoreAlgorithmInEachRun { 191 get { return storeAlgorithmInEachRun; } 192 set { 193 if (storeAlgorithmInEachRun != value) { 194 storeAlgorithmInEachRun = value; 195 OnStoreAlgorithmInEachRunChanged(); 196 } 197 } 198 } 199 200 [Storable] 201 private int runsCounter; 202 [Storable] 164 203 private RunCollection runs; 165 204 public RunCollection Runs { 166 205 get { return runs; } 167 206 } 207 168 208 [Storable] 169 209 private ExecutionState executionState; … … 192 232 public TimeSpan ExecutionTime { 193 233 get { 194 if (ExecutionState != ExecutionState.Stopped )234 if (ExecutionState != ExecutionState.Stopped && ExecutionState != ExecutionState.Prepared) 195 235 return executionTime + TimeSpan.FromMilliseconds(clonedAlgorithms.Select(x => x.ExecutionTime.TotalMilliseconds).Sum()); 196 236 else … … 227 267 //create cloned algorithms 228 268 if (clonedAlgorithms.Count == 0) { 269 int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value; 229 270 for (int i = 0; i < Folds.Value; i++) { 230 271 IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone(); 231 272 clonedAlgorithm.Name = algorithm.Name + " Fold " + i; 273 IDataAnalysisProblem problem = clonedAlgorithm.Problem as IDataAnalysisProblem; 274 problem.DataAnalysisProblemData.TestSamplesEnd.Value = (i + 1) == Folds.Value ? SamplesEnd.Value : (i + 1) * testSamplesCount + SamplesStart.Value; 275 problem.DataAnalysisProblemData.TestSamplesStart.Value = (i * testSamplesCount) + SamplesStart.Value; 232 276 clonedAlgorithms.Add(clonedAlgorithm); 233 277 } … … 283 327 } 284 328 329 #region collect parameters and results 330 public override void CollectParameterValues(IDictionary<string, IItem> values) { 331 values.Add("Algorithm Name", new StringValue(Name)); 332 values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName())); 333 values.Add("Folds", new IntValue(Folds.Value)); 334 335 if (algorithm != null) { 336 values.Add("CrossValidation Algorithm Name", new StringValue(Algorithm.Name)); 337 values.Add("CrossValidation Algorithm Type", new StringValue(Algorithm.GetType().GetPrettyName())); 338 base.CollectParameterValues(values); 339 } 340 if (Problem != null) { 341 values.Add("Problem Name", new StringValue(Problem.Name)); 342 values.Add("Problem Type", new StringValue(Problem.GetType().GetPrettyName())); 343 Problem.CollectParameterValues(values); 344 } 345 } 346 347 public void CollectResultValues(IDictionary<string, IItem> results) { 348 Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>(); 349 IEnumerable<IRun> runs = ClonedAlgorithms.Select(alg => alg.Runs.FirstOrDefault()).Where(run => run != null); 350 IEnumerable<KeyValuePair<string, IItem>> resultCollections = runs.Where(x => x != null).SelectMany(x => x.Results).ToList(); 351 352 foreach (IResult result in ExtractAndAggregateResults<IntValue>(resultCollections)) 353 results.Add(result.Name, result.Value); 354 foreach (IResult result in ExtractAndAggregateResults<DoubleValue>(resultCollections)) 355 results.Add(result.Name, result.Value); 356 foreach (IResult result in ExtractAndAggregateResults<PercentValue>(resultCollections)) 357 results.Add(result.Name, result.Value); 358 359 results.Add("Execution Time", new TimeSpanValue(TimeSpan.FromMilliseconds(clonedAlgorithms.Select(x => x.ExecutionTime.TotalMilliseconds).Sum()))); 360 results.Add("CrossValidation Folds", new RunCollection(runs)); 361 } 362 363 private static IEnumerable<IResult> ExtractAndAggregateResults<T>(IEnumerable<KeyValuePair<string, IItem>> results) 364 where T : class, IItem, new() { 365 Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>(); 366 foreach (var resultValue in results.Where(r => r.Value.GetType() == typeof(T))) { 367 if (!resultValues.ContainsKey(resultValue.Key)) 368 resultValues[resultValue.Key] = new List<double>(); 369 resultValues[resultValue.Key].Add(ConvertToDouble(resultValue.Value)); 370 } 371 372 DoubleValue doubleValue; 373 if (typeof(T) == typeof(PercentValue)) 374 doubleValue = new PercentValue(); 375 else if (typeof(T) == typeof(DoubleValue)) 376 doubleValue = new DoubleValue(); 377 else if (typeof(T) == typeof(IntValue)) 378 doubleValue = new DoubleValue(); 379 else 380 throw new NotSupportedException(); 381 382 List<IResult> aggregatedResults = new List<IResult>(); 383 foreach (KeyValuePair<string, List<double>> resultValue in resultValues) { 384 doubleValue.Value = resultValue.Value.Average(); 385 aggregatedResults.Add(new Result(resultValue.Key, (IItem)doubleValue.Clone())); 386 doubleValue.Value = resultValue.Value.StandardDeviation(); 387 aggregatedResults.Add(new Result(resultValue.Key + " StdDev", (IItem)doubleValue.Clone())); 388 } 389 return aggregatedResults; 390 } 391 392 private static double ConvertToDouble(IItem item) { 393 if (item is DoubleValue) return ((DoubleValue)item).Value; 394 else if (item is IntValue) return ((IntValue)item).Value; 395 else throw new NotSupportedException("Could not convert any item type to double"); 396 } 397 398 #endregion 399 285 400 #region events 286 401 private void RegisterEvents() { 287 402 Folds.ValueChanged += new EventHandler(Folds_ValueChanged); 288 NumberOfWorkers.ValueChanged += new EventHandler(NumberOfWorkers_ValueChanged); 403 SamplesStart.ValueChanged += new EventHandler(SamplesStart_ValueChanged); 404 SamplesEnd.ValueChanged += new EventHandler(SamplesEnd_ValueChanged); 289 405 RegisterClonedAlgorithmsEvents(); 290 406 RegisterRunsEvents(); 407 } 408 private void DeregisterEvents() { 409 Folds.ValueChanged -= new EventHandler(Folds_ValueChanged); 410 SamplesStart.ValueChanged -= new EventHandler(SamplesStart_ValueChanged); 411 SamplesEnd.ValueChanged -= new EventHandler(SamplesEnd_ValueChanged); 412 DeregisterClonedAlgorithmsEvents(); 413 DeregisterRunsEvents(); 414 291 415 } 292 416 private void Folds_ValueChanged(object sender, EventArgs e) { … … 294 418 throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared."); 295 419 } 296 private void NumberOfWorkers_ValueChanged(object sender, EventArgs e) { 297 if (ExecutionState == ExecutionState.Started) { 298 int workers = numberOfWorkers.Value; 299 int runningWorkers = clonedAlgorithms.Count(alg => alg.ExecutionState == ExecutionState.Started); 300 301 foreach (IAlgorithm algorithm in clonedAlgorithms) { 302 if (algorithm.ExecutionState == ExecutionState.Prepared || 303 algorithm.ExecutionState == ExecutionState.Paused) { 304 if (runningWorkers < workers) { 305 algorithm.Start(); 306 runningWorkers++; 307 } 308 } else if (algorithm.ExecutionState == ExecutionState.Started) { 309 if (runningWorkers > workers) { 310 algorithm.Pause(); 311 runningWorkers--; 312 } 313 } 314 } 315 } 420 private void SamplesStart_ValueChanged(object sender, EventArgs e) { 421 if (Problem != null) Problem.DataAnalysisProblemData.TrainingSamplesStart.Value = SamplesStart.Value; 422 } 423 private void SamplesEnd_ValueChanged(object sender, EventArgs e) { 424 if (Problem != null) Problem.DataAnalysisProblemData.TrainingSamplesEnd.Value = SamplesEnd.Value; 316 425 } 317 426 … … 338 447 } 339 448 cachedProblem = (IDataAnalysisProblem)algorithm.Problem; 449 OnProblemChanged(); 450 } 451 public event EventHandler ProblemChanged; 452 private void OnProblemChanged() { 453 EventHandler handler = ProblemChanged; 454 if (handler != null) handler(this, EventArgs.Empty); 455 340 456 SamplesStart.Value = 0; 341 if ( algorithm.Problem != null)457 if (Problem != null) 342 458 SamplesEnd.Value = Problem.DataAnalysisProblemData.Dataset.Rows; 343 459 else 344 460 SamplesEnd.Value = 0; 345 OnProblemChanged();346 }347 public event EventHandler ProblemChanged;348 private void OnProblemChanged() {349 EventHandler handler = ProblemChanged;350 if (handler != null) handler(this, EventArgs.Empty);351 461 } 352 462 … … 371 481 RegisterClonedAlgorithmEvents(algorithm); 372 482 } 483 private void DeregisterClonedAlgorithmsEvents() { 484 clonedAlgorithms.ItemsAdded -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded); 485 clonedAlgorithms.ItemsRemoved -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved); 486 clonedAlgorithms.CollectionReset -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset); 487 foreach (IAlgorithm algorithm in clonedAlgorithms) 488 DeregisterClonedAlgorithmEvents(algorithm); 489 } 373 490 private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) { 374 491 foreach (IAlgorithm algorithm in e.Items) … … 435 552 if (preparedAlgorithm != null) preparedAlgorithm.Start(); 436 553 } 437 if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped)) 438 OnStopped(); 439 else if (stopPending && clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped)) 440 OnStopped(); 554 if (ExecutionState != ExecutionState.Stopped) { 555 if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped)) 556 OnStopped(); 557 else if (stopPending && 558 clonedAlgorithms.All( 559 alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped)) 560 OnStopped(); 561 } 441 562 } 442 563 } … … 449 570 runs.ItemsRemoved += new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsRemoved); 450 571 } 572 private void DeregisterRunsEvents() { 573 runs.CollectionReset -= new CollectionItemsChangedEventHandler<IRun>(Runs_CollectionReset); 574 runs.ItemsAdded -= new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsAdded); 575 runs.ItemsRemoved -= new CollectionItemsChangedEventHandler<IRun>(Runs_ItemsRemoved); 576 } 451 577 private void Runs_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRun> e) { 452 578 foreach (IRun run in e.OldItems) { … … 454 580 run.Results.TryGetValue("Execution Time", out item); 455 581 TimeSpanValue executionTime = item as TimeSpanValue; 456 if (executionTime != null) ExecutionTime -=executionTime.Value;582 if (executionTime != null) ExecutionTime = this.executionTime - executionTime.Value; 457 583 } 458 584 foreach (IRun run in e.Items) { … … 460 586 run.Results.TryGetValue("Execution Time", out item); 461 587 TimeSpanValue executionTime = item as TimeSpanValue; 462 if (executionTime != null) ExecutionTime += executionTime.Value; 463 } 588 if (executionTime != null) ExecutionTime = this.executionTime + executionTime.Value; 589 } 590 runsCounter = Runs.Count; 464 591 } 465 592 private void Runs_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRun> e) { … … 468 595 run.Results.TryGetValue("Execution Time", out item); 469 596 TimeSpanValue executionTime = item as TimeSpanValue; 470 if (executionTime != null) ExecutionTime +=executionTime.Value;597 if (executionTime != null) this.executionTime = this.executionTime + executionTime.Value; 471 598 } 472 599 } … … 476 603 run.Results.TryGetValue("Execution Time", out item); 477 604 TimeSpanValue executionTime = item as TimeSpanValue; 478 if (executionTime != null) ExecutionTime -=executionTime.Value;605 if (executionTime != null) ExecutionTime = this.executionTime - executionTime.Value; 479 606 } 480 607 } … … 516 643 private void OnStopped() { 517 644 stopPending = false; 645 runsCounter++; 646 runs.Add(new Run(string.Format("{0} Run {1}", Name, runsCounter), this)); 518 647 ExecutionState = ExecutionState.Stopped; 519 //TODO create run;520 648 EventHandler handler = Stopped; 521 649 if (handler != null) handler(this, EventArgs.Empty); … … 526 654 if (handler != null) handler(this, new EventArgs<Exception>(exception)); 527 655 } 656 public event EventHandler StoreAlgorithmInEachRunChanged; 657 private void OnStoreAlgorithmInEachRunChanged() { 658 EventHandler handler = StoreAlgorithmInEachRunChanged; 659 if (handler != null) handler(this, EventArgs.Empty); 660 } 528 661 #endregion 529 662 } -
branches/HeuristicLab.Classification/HeuristicLab.Algorithms.DataAnalysis/3.3/HeuristicLab.Algorithms.DataAnalysis-3.3.csproj
r4536 r4561 148 148 <Reference Include="System.Data" /> 149 149 <Reference Include="System.Drawing" /> 150 <Reference Include="System.Windows.Forms" /> 150 151 <Reference Include="System.Xml" /> 151 152 </ItemGroup>
Note: See TracChangeset
for help on using the changeset viewer.