- Timestamp:
- 07/22/15 16:20:26 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs
r12504 r12800 52 52 results = new ResultCollection(); 53 53 54 folds = new IntValue(2);54 folds = 2; 55 55 numberOfWorkers = new IntValue(1); 56 56 samplesStart = new IntValue(0); … … 73 73 RegisterEvents(); 74 74 if (Algorithm != null) RegisterAlgorithmEvents(); 75 76 // BackwardsCompatibility3.4 77 #region Backwards compatible code, remove with 3.6 78 if (persistenceCompatibilityFolds != null) folds = persistenceCompatibilityFolds.Value; 79 #endregion 75 80 } 76 81 … … 85 90 results = cloner.Clone(original.results); 86 91 87 folds = cloner.Clone(original.folds);92 folds = original.folds; 88 93 numberOfWorkers = cloner.Clone(original.numberOfWorkers); 89 94 samplesStart = cloner.Clone(original.samplesStart); … … 171 176 } 172 177 173 [Storable] 174 private IntValue folds; 175 public IntValue Folds { 178 // BackwardsCompatibility3.4 179 #region Backwards compatible code, remove with 3.6 180 [Storable(Name = "folds")] 181 private IntValue persistenceCompatibilityFolds; 182 #endregion 183 184 [Storable(Name = "foldsNew")] 185 private int folds; 186 public int Folds { 176 187 get { return folds; } 177 } 188 set { 189 if (value != folds) { 190 folds = value; 191 partitionVariable = NoPartitionVariable; // setting folds updates the partition variable 192 OnFoldsChanged(); 193 } 194 } 195 } 196 197 [Storable] 198 // folds are either specified explicitly or by a varaible from the dataset 199 public const string NoPartitionVariable = "<none>"; 200 private string partitionVariable = NoPartitionVariable; 201 public string PartitionVariable { 202 get { return partitionVariable; } 203 set { 204 if (value != partitionVariable) { 205 partitionVariable = value; 206 UpdateNumberOfFolds(); 207 } 208 } 209 } 210 178 211 [Storable] 179 212 private IntValue samplesStart; … … 244 277 } 245 278 } 279 246 280 #endregion 247 281 … … 273 307 //create cloned algorithms 274 308 if (clonedAlgorithms.Count == 0) { 275 int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value; 276 277 for (int i = 0; i < Folds.Value; i++) { 309 310 for (int i = 0; i < folds; i++) { 278 311 IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone(); 279 312 clonedAlgorithm.Name = algorithm.Name + " Fold " + i; … … 281 314 ISymbolicDataAnalysisProblem symbolicProblem = problem as ISymbolicDataAnalysisProblem; 282 315 283 int testStart = (i * testSamplesCount) + SamplesStart.Value; 284 int testEnd = (i + 1) == Folds.Value ? SamplesEnd.Value : (i + 1) * testSamplesCount + SamplesStart.Value; 316 int trainingStart, trainingEnd, testStart, testEnd; 317 // assumes that partitions are subset of subsequent rows 318 GetTrainingAndTestPartitions(i, out trainingStart, out trainingEnd, out testStart, out testEnd); 285 319 286 320 problem.ProblemData.TrainingPartition.Start = SamplesStart.Value; … … 306 340 int startedAlgorithms = 0; 307 341 foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) { 308 if (startedAlgorithms < NumberOfWorkers.Value) {342 if (startedAlgorithms < numberOfWorkers.Value) { 309 343 if (clonedAlgorithm.ExecutionState == ExecutionState.Prepared || 310 344 clonedAlgorithm.ExecutionState == ExecutionState.Paused) { … … 324 358 } 325 359 OnStarted(); 360 } 361 } 362 363 private void GetTrainingAndTestPartitions(int fold, out int trainingStart, out int trainingEnd, out int testStart, out int testEnd) { 364 if (partitionVariable == NoPartitionVariable) { 365 // uniform split 366 int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / folds; 367 trainingStart = SamplesStart.Value; 368 trainingEnd = SamplesEnd.Value; 369 testStart = (fold * testSamplesCount) + SamplesStart.Value; 370 testEnd = (fold + 1) == folds ? SamplesEnd.Value : (fold + 1) * testSamplesCount + SamplesStart.Value; 371 } else { 372 // group rowIdx by partition 373 var partition = Problem.ProblemData.Dataset.GetReadOnlyDoubleValues(partitionVariable); 374 var g = Enumerable.Range(SamplesStart.Value, SamplesEnd.Value - SamplesStart.Value).GroupBy(r => partition[r]).OrderBy(r => r.Key).ToList(); 375 trainingStart = SamplesStart.Value; 376 trainingEnd = SamplesEnd.Value; 377 testStart = g[fold].Min(); 378 testEnd = g[fold].Max() + 1; 326 379 } 327 380 } … … 365 418 values.Add("Algorithm Name", new StringValue(Name)); 366 419 values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName())); 367 values.Add("Folds", new IntValue( Folds.Value));420 values.Add("Folds", new IntValue(folds)); 368 421 369 422 if (algorithm != null) { … … 516 569 #region events 517 570 private void RegisterEvents() { 518 Folds.ValueChanged += new EventHandler(Folds_ValueChanged); 571 SamplesStart.ValueChanged += (s, e) => UpdateNumberOfFolds(); 572 SamplesEnd.ValueChanged += (s, e) => UpdateNumberOfFolds(); 573 519 574 RegisterClonedAlgorithmsEvents(); 520 575 } 521 private void Folds_ValueChanged(object sender, EventArgs e) { 522 if (ExecutionState != ExecutionState.Prepared) 523 throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared."); 576 577 private void UpdateNumberOfFolds() { 578 if (ExecutionState == ExecutionState.Paused || executionState == ExecutionState.Started) 579 throw new InvalidOperationException("Can not change number of folds if crossvalidation is paused or started."); 580 581 if (partitionVariable != NoPartitionVariable) { 582 // number of folds is the number of distinct values of the partition variable in the range [SamplesStart..SamplesEnd[ 583 var ds = Problem.ProblemData.Dataset; 584 var partitionValues = ds.GetDoubleValues(PartitionVariable, Enumerable.Range(SamplesStart.Value, SamplesEnd.Value - SamplesStart.Value)).Distinct().Count(); 585 folds = partitionValues; 586 OnFoldsChanged(); 587 } 524 588 } 525 589 … … 555 619 public event EventHandler ProblemChanged; 556 620 private void OnProblemChanged() { 621 ConfigureProblem(); 557 622 EventHandler handler = ProblemChanged; 558 623 if (handler != null) handler(this, EventArgs.Empty); 559 ConfigureProblem();560 624 } 561 625 … … 565 629 566 630 private void ConfigureProblem() { 631 folds = 2; 632 partitionVariable = NoPartitionVariable; 567 633 SamplesStart.Value = 0; 568 634 if (Problem != null) { … … 687 753 688 754 #region event firing 755 756 public event EventHandler FoldsChanged; 757 private void OnFoldsChanged() { 758 EventHandler handler = FoldsChanged; 759 if (handler != null) handler(this, EventArgs.Empty); 760 } 689 761 public event EventHandler ExecutionStateChanged; 690 762 private void OnExecutionStateChanged() {
Note: See TracChangeset
for help on using the changeset viewer.