Changeset 15150 for stable/HeuristicLab.Algorithms.DataAnalysis
- Timestamp:
- 07/06/17 11:39:20 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
stable/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs
r15149 r15150 33 33 using HeuristicLab.Problems.DataAnalysis; 34 34 using HeuristicLab.Problems.DataAnalysis.Symbolic; 35 using HeuristicLab.Random; 35 36 36 37 namespace HeuristicLab.Algorithms.DataAnalysis { … … 39 40 [StorableClass] 40 41 public sealed class CrossValidation : ParameterizedNamedItem, IAlgorithm, IStorableContent { 42 [Storable] 43 private int seed; 44 41 45 public CrossValidation() 42 46 : base() { … … 56 60 samplesStart = new IntValue(0); 57 61 samplesEnd = new IntValue(0); 62 shuffleSamples = new BoolValue(false); 58 63 storeAlgorithmInEachRun = false; 59 64 … … 71 76 [StorableHook(HookType.AfterDeserialization)] 72 77 private void AfterDeserialization() { 78 // BackwardsCompatibility3.3 79 #region Backwards compatible code, remove with 3.4 80 if (shuffleSamples == null) shuffleSamples = new BoolValue(false); 81 #endregion 82 73 83 RegisterEvents(); 74 84 if (Algorithm != null) RegisterAlgorithmEvents(); … … 89 99 samplesStart = cloner.Clone(original.samplesStart); 90 100 samplesEnd = cloner.Clone(original.samplesEnd); 101 shuffleSamples = cloner.Clone(original.shuffleSamples); 102 seed = original.seed; 103 91 104 RegisterEvents(); 92 105 if (Algorithm != null) RegisterAlgorithmEvents(); … … 170 183 get { return results; } 171 184 } 172 185 [Storable] 186 private BoolValue shuffleSamples; 187 public BoolValue ShuffleSamples { 188 get { return shuffleSamples; } 189 } 173 190 [Storable] 174 191 private IntValue folds; … … 270 287 throw new InvalidOperationException(string.Format("Start not allowed in execution state \"{0}\".", ExecutionState)); 271 288 289 seed = new FastRandom().NextInt(); 290 272 291 if (Algorithm != null) { 273 292 //create cloned algorithms 274 293 if (clonedAlgorithms.Count == 0) { 275 294 int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value; 276 295 IDataset shuffledDataset = null; 277 296 for (int i = 0; i < Folds.Value; i++) { 278 IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone(); 297 var cloner = new Cloner(); 298 if (ShuffleSamples.Value) { 299 var random = new FastRandom(seed); 300 var dataAnalysisProblem = (IDataAnalysisProblem)algorithm.Problem; 301 var dataset = (Dataset)dataAnalysisProblem.ProblemData.Dataset; 302 shuffledDataset = shuffledDataset ?? dataset.Shuffle(random); 303 cloner.RegisterClonedObject(dataset, shuffledDataset); 304 } 305 IAlgorithm clonedAlgorithm = cloner.Clone(Algorithm); 279 306 clonedAlgorithm.Name = algorithm.Name + " Fold " + i; 280 307 IDataAnalysisProblem problem = clonedAlgorithm.Problem as IDataAnalysisProblem; … … 422 449 // clone manually to correctly clone references between cloned root objects 423 450 Cloner cloner = new Cloner(); 451 if (ShuffleSamples.Value) { 452 var dataset = (Dataset)Problem.ProblemData.Dataset; 453 var random = new FastRandom(seed); 454 var shuffledDataset = dataset.Shuffle(random); 455 cloner.RegisterClonedObject(dataset, shuffledDataset); 456 } 424 457 var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData); 425 458 // set partitions of problem data clone correctly … … 453 486 // at least one algorithm (GBT with logistic regression loss) produces a classification solution even though the original problem is a regression problem. 454 487 var targetVariable = solutions.Value.First().ProblemData.TargetVariable; 455 var problemDataClone = new ClassificationProblemData(Problem.ProblemData.Dataset, 456 Problem.ProblemData.AllowedInputVariables, targetVariable); 488 var dataset = (Dataset)Problem.ProblemData.Dataset; 489 if (ShuffleSamples.Value) { 490 var random = new FastRandom(seed); 491 dataset = dataset.Shuffle(random); 492 } 493 var problemDataClone = new ClassificationProblemData(dataset, Problem.ProblemData.AllowedInputVariables, targetVariable); 457 494 // set partitions of problem data clone correctly 458 495 problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value; … … 537 574 algorithm.ProblemChanged += new EventHandler(Algorithm_ProblemChanged); 538 575 algorithm.ExecutionStateChanged += new EventHandler(Algorithm_ExecutionStateChanged); 539 if (Problem != null) Problem.Reset += new EventHandler(Problem_Reset); 576 if (Problem != null) { 577 Problem.Reset += new EventHandler(Problem_Reset); 578 } 540 579 } 541 580 private void DeregisterAlgorithmEvents() { 542 581 algorithm.ProblemChanged -= new EventHandler(Algorithm_ProblemChanged); 543 582 algorithm.ExecutionStateChanged -= new EventHandler(Algorithm_ExecutionStateChanged); 544 if (Problem != null) Problem.Reset -= new EventHandler(Problem_Reset); 583 if (Problem != null) { 584 Problem.Reset -= new EventHandler(Problem_Reset); 585 } 545 586 } 546 587 private void Algorithm_ProblemChanged(object sender, EventArgs e) { … … 560 601 ConfigureProblem(); 561 602 } 562 563 603 private void Problem_Reset(object sender, EventArgs e) { 564 604 ConfigureProblem(); 565 605 } 566 567 606 private void ConfigureProblem() { 568 607 SamplesStart.Value = 0; … … 590 629 private void Algorithm_ExecutionStateChanged(object sender, EventArgs e) { 591 630 switch (Algorithm.ExecutionState) { 592 case ExecutionState.Prepared: OnPrepared(); 631 case ExecutionState.Prepared: 632 OnPrepared(); 593 633 break; 594 634 case ExecutionState.Started: throw new InvalidOperationException("Algorithm template can not be started."); 595 635 case ExecutionState.Paused: throw new InvalidOperationException("Algorithm template can not be paused."); 596 case ExecutionState.Stopped: OnStopped(); 636 case ExecutionState.Stopped: 637 OnStopped(); 597 638 break; 598 639 } … … 724 765 AggregateResultValues(collectedResults); 725 766 results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray()); 767 clonedAlgorithms.Clear(); 726 768 runsCounter++; 727 769 runs.Add(new Run(string.Format("{0} Run {1}", Name, runsCounter), this));
Note: See TracChangeset
for help on using the changeset viewer.