Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/08/12 14:04:17 (12 years ago)
Author:
mkommend
Message:

#1081: Intermediate commit of trunk updates - interpreter changes must be redone.

Location:
branches/HeuristicLab.TimeSeries
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.TimeSeries

    • Property svn:ignore
      •  

        old new  
        2020bin
        2121protoc.exe
         22_ReSharper.HeuristicLab.TimeSeries-3.3
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis

  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r7268 r8430  
    3737  [Creatable("Data Analysis - Ensembles")]
    3838  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     39    private readonly Dictionary<int, double> trainingEvaluationCache = new Dictionary<int, double>();
     40    private readonly Dictionary<int, double> testEvaluationCache = new Dictionary<int, double>();
     41
    3942    public new IClassificationEnsembleModel Model {
    4043      get { return (IClassificationEnsembleModel)base.Model; }
     
    8588      }
    8689
     90      trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count());
     91      testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count());
     92
    8793      classificationSolutions = cloner.Clone(original.classificationSolutions);
    8894      RegisterClassificationSolutionsEventHandler();
     
    128134      }
    129135
     136      trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count());
     137      testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count());
     138
    130139      RegisterClassificationSolutionsEventHandler();
    131140      classificationSolutions.AddRange(solutions);
     
    148157    public override IEnumerable<double> EstimatedTrainingClassValues {
    149158      get {
    150         var rows = ProblemData.TrainingIndizes;
    151         var estimatedValuesEnumerators = (from model in Model.Models
    152                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    153                                          .ToList();
    154         var rowsEnumerator = rows.GetEnumerator();
    155         // aggregate to make sure that MoveNext is called for all enumerators
    156         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    157           int currentRow = rowsEnumerator.Current;
    158 
    159           var selectedEnumerators = from pair in estimatedValuesEnumerators
    160                                     where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    161                                     select pair.EstimatedValuesEnumerator;
    162           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     159        var rows = ProblemData.TrainingIndices;
     160        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
     161        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     162        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     163
     164        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     165          trainingEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    163166        }
     167
     168        return rows.Select(row => trainingEvaluationCache[row]);
    164169      }
    165170    }
     
    167172    public override IEnumerable<double> EstimatedTestClassValues {
    168173      get {
    169         var rows = ProblemData.TestIndizes;
    170         var estimatedValuesEnumerators = (from model in Model.Models
    171                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    172                                          .ToList();
    173         var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
    174         // aggregate to make sure that MoveNext is called for all enumerators
    175         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    176           int currentRow = rowsEnumerator.Current;
    177 
    178           var selectedEnumerators = from pair in estimatedValuesEnumerators
    179                                     where RowIsTestForModel(currentRow, pair.Model)
    180                                     select pair.EstimatedValuesEnumerator;
    181 
    182           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     174        var rows = ProblemData.TestIndices;
     175        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
     176        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     177        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     178
     179        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     180          testEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    183181        }
     182
     183        return rows.Select(row => testEvaluationCache[row]);
     184      }
     185    }
     186
     187    private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IClassificationModel, bool> modelSelectionPredicate) {
     188      var estimatedValuesEnumerators = (from model in Model.Models
     189                                        select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     190                                       .ToList();
     191      var rowsEnumerator = rows.GetEnumerator();
     192      // aggregate to make sure that MoveNext is called for all enumerators
     193      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     194        int currentRow = rowsEnumerator.Current;
     195
     196        var selectedEnumerators = from pair in estimatedValuesEnumerators
     197                                  where modelSelectionPredicate(currentRow, pair.Model)
     198                                  select pair.EstimatedValuesEnumerator;
     199
     200        yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
    184201      }
    185202    }
     
    196213
    197214    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    198       return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
    199              select AggregateEstimatedClassValues(xs);
     215      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     216      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     217      var valuesEnumerator = (from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rowsToEvaluate)
     218                              select AggregateEstimatedClassValues(xs))
     219                             .GetEnumerator();
     220
     221      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     222        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     223      }
     224
     225      return rows.Select(row => evaluationCache[row]);
    200226    }
    201227
     
    223249
    224250    protected override void OnProblemDataChanged() {
     251      trainingEvaluationCache.Clear();
     252      testEvaluationCache.Clear();
     253      evaluationCache.Clear();
     254
    225255      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
    226256                                                                     ProblemData.AllowedInputVariables,
     
    251281    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
    252282      classificationSolutions.AddRange(solutions);
     283
     284      trainingEvaluationCache.Clear();
     285      testEvaluationCache.Clear();
     286      evaluationCache.Clear();
    253287    }
    254288    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
    255289      classificationSolutions.RemoveRange(solutions);
     290
     291      trainingEvaluationCache.Clear();
     292      testEvaluationCache.Clear();
     293      evaluationCache.Clear();
    256294    }
    257295
     
    275313      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    276314      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     315
     316      trainingEvaluationCache.Clear();
     317      testEvaluationCache.Clear();
     318      evaluationCache.Clear();
    277319    }
    278320
     
    282324      trainingPartitions.Remove(solution.Model);
    283325      testPartitions.Remove(solution.Model);
     326
     327      trainingEvaluationCache.Clear();
     328      testEvaluationCache.Clear();
     329      evaluationCache.Clear();
    284330    }
    285331  }
Note: See TracChangeset for help on using the changeset viewer.