Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/22/15 16:20:26 (9 years ago)
Author:
gkronber
Message:

#2434 added support for a partition variable to CrossValidation

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs

    r12504 r12800  
    5252      results = new ResultCollection();
    5353
    54       folds = new IntValue(2);
     54      folds = 2;
    5555      numberOfWorkers = new IntValue(1);
    5656      samplesStart = new IntValue(0);
     
    7373      RegisterEvents();
    7474      if (Algorithm != null) RegisterAlgorithmEvents();
     75
     76      // BackwardsCompatibility3.4
     77      #region Backwards compatible code, remove with 3.6
     78      if (persistenceCompatibilityFolds != null) folds = persistenceCompatibilityFolds.Value;
     79      #endregion
    7580    }
    7681
     
    8590      results = cloner.Clone(original.results);
    8691
    87       folds = cloner.Clone(original.folds);
     92      folds = original.folds;
    8893      numberOfWorkers = cloner.Clone(original.numberOfWorkers);
    8994      samplesStart = cloner.Clone(original.samplesStart);
     
    171176    }
    172177
    173     [Storable]
    174     private IntValue folds;
    175     public IntValue Folds {
     178    // BackwardsCompatibility3.4
     179    #region Backwards compatible code, remove with 3.6
     180    [Storable(Name = "folds")]
     181    private IntValue persistenceCompatibilityFolds;
     182    #endregion
     183
     184    [Storable(Name = "foldsNew")]
     185    private int folds;
     186    public int Folds {
    176187      get { return folds; }
    177     }
     188      set {
     189        if (value != folds) {
     190          folds = value;
     191          partitionVariable = NoPartitionVariable; // setting folds updates the partition variable
     192          OnFoldsChanged();
     193        }
     194      }
     195    }
     196
     197    [Storable]
     198    // folds are either specified explicitly or by a varaible from the dataset
     199    public const string NoPartitionVariable = "<none>";
     200    private string partitionVariable = NoPartitionVariable;
     201    public string PartitionVariable {
     202      get { return partitionVariable; }
     203      set {
     204        if (value != partitionVariable) {
     205          partitionVariable = value;
     206          UpdateNumberOfFolds();
     207        }
     208      }
     209    }
     210
    178211    [Storable]
    179212    private IntValue samplesStart;
     
    244277      }
    245278    }
     279
    246280    #endregion
    247281
     
    273307        //create cloned algorithms
    274308        if (clonedAlgorithms.Count == 0) {
    275           int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value;
    276 
    277           for (int i = 0; i < Folds.Value; i++) {
     309
     310          for (int i = 0; i < folds; i++) {
    278311            IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone();
    279312            clonedAlgorithm.Name = algorithm.Name + " Fold " + i;
     
    281314            ISymbolicDataAnalysisProblem symbolicProblem = problem as ISymbolicDataAnalysisProblem;
    282315
    283             int testStart = (i * testSamplesCount) + SamplesStart.Value;
    284             int testEnd = (i + 1) == Folds.Value ? SamplesEnd.Value : (i + 1) * testSamplesCount + SamplesStart.Value;
     316            int trainingStart, trainingEnd, testStart, testEnd;
     317            // assumes that partitions are subset of subsequent rows
     318            GetTrainingAndTestPartitions(i, out trainingStart, out trainingEnd, out testStart, out testEnd);
    285319
    286320            problem.ProblemData.TrainingPartition.Start = SamplesStart.Value;
     
    306340        int startedAlgorithms = 0;
    307341        foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
    308           if (startedAlgorithms < NumberOfWorkers.Value) {
     342          if (startedAlgorithms < numberOfWorkers.Value) {
    309343            if (clonedAlgorithm.ExecutionState == ExecutionState.Prepared ||
    310344                clonedAlgorithm.ExecutionState == ExecutionState.Paused) {
     
    324358        }
    325359        OnStarted();
     360      }
     361    }
     362
     363    private void GetTrainingAndTestPartitions(int fold, out int trainingStart, out int trainingEnd, out int testStart, out int testEnd) {
     364      if (partitionVariable == NoPartitionVariable) {
     365        // uniform split
     366        int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / folds;
     367        trainingStart = SamplesStart.Value;
     368        trainingEnd = SamplesEnd.Value;
     369        testStart = (fold * testSamplesCount) + SamplesStart.Value;
     370        testEnd = (fold + 1) == folds ? SamplesEnd.Value : (fold + 1) * testSamplesCount + SamplesStart.Value;
     371      } else {
     372        // group rowIdx by partition
     373        var partition = Problem.ProblemData.Dataset.GetReadOnlyDoubleValues(partitionVariable);
     374        var g = Enumerable.Range(SamplesStart.Value, SamplesEnd.Value - SamplesStart.Value).GroupBy(r => partition[r]).OrderBy(r => r.Key).ToList();
     375        trainingStart = SamplesStart.Value;
     376        trainingEnd = SamplesEnd.Value;
     377        testStart = g[fold].Min();
     378        testEnd = g[fold].Max() + 1;
    326379      }
    327380    }
     
    365418      values.Add("Algorithm Name", new StringValue(Name));
    366419      values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName()));
    367       values.Add("Folds", new IntValue(Folds.Value));
     420      values.Add("Folds", new IntValue(folds));
    368421
    369422      if (algorithm != null) {
     
    516569    #region events
    517570    private void RegisterEvents() {
    518       Folds.ValueChanged += new EventHandler(Folds_ValueChanged);
     571      SamplesStart.ValueChanged += (s, e) => UpdateNumberOfFolds();
     572      SamplesEnd.ValueChanged += (s, e) => UpdateNumberOfFolds();
     573
    519574      RegisterClonedAlgorithmsEvents();
    520575    }
    521     private void Folds_ValueChanged(object sender, EventArgs e) {
    522       if (ExecutionState != ExecutionState.Prepared)
    523         throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared.");
     576
     577    private void UpdateNumberOfFolds() {
     578      if (ExecutionState == ExecutionState.Paused || executionState == ExecutionState.Started)
     579        throw new InvalidOperationException("Can not change number of folds if crossvalidation is paused or started.");
     580
     581      if (partitionVariable != NoPartitionVariable) {
     582        // number of folds is the number of distinct values of the partition variable in the range [SamplesStart..SamplesEnd[
     583        var ds = Problem.ProblemData.Dataset;
     584        var partitionValues = ds.GetDoubleValues(PartitionVariable, Enumerable.Range(SamplesStart.Value, SamplesEnd.Value - SamplesStart.Value)).Distinct().Count();
     585        folds = partitionValues;
     586        OnFoldsChanged();
     587      }
    524588    }
    525589
     
    555619    public event EventHandler ProblemChanged;
    556620    private void OnProblemChanged() {
     621      ConfigureProblem();
    557622      EventHandler handler = ProblemChanged;
    558623      if (handler != null) handler(this, EventArgs.Empty);
    559       ConfigureProblem();
    560624    }
    561625
     
    565629
    566630    private void ConfigureProblem() {
     631      folds = 2;
     632      partitionVariable = NoPartitionVariable;
    567633      SamplesStart.Value = 0;
    568634      if (Problem != null) {
     
    687753
    688754    #region event firing
     755
     756    public event EventHandler FoldsChanged;
     757    private void OnFoldsChanged() {
     758      EventHandler handler = FoldsChanged;
     759      if (handler != null) handler(this, EventArgs.Empty);
     760    }
    689761    public event EventHandler ExecutionStateChanged;
    690762    private void OnExecutionStateChanged() {
Note: See TracChangeset for help on using the changeset viewer.