Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/28/12 16:25:58 (12 years ago)
Author:
gkronber
Message:

#1720: Implemented caching of estimated class values in ClassificationEnsembleSolution.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r8139 r8153  
    3737  [Creatable("Data Analysis - Ensembles")]
    3838  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     39    private readonly Dictionary<int, double> trainingEstimatedValuesCache = new Dictionary<int, double>();
     40    private readonly Dictionary<int, double> testEstimatedValuesCache = new Dictionary<int, double>();
     41    private readonly Dictionary<int, double> estimatedValuesCache = new Dictionary<int, double>();
     42
    3943    public new IClassificationEnsembleModel Model {
    4044      get { return (IClassificationEnsembleModel)base.Model; }
     
    149153      get {
    150154        var rows = ProblemData.TrainingIndices;
    151         var estimatedValuesEnumerators = (from model in Model.Models
    152                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    153                                          .ToList();
    154         var rowsEnumerator = rows.GetEnumerator();
    155         // aggregate to make sure that MoveNext is called for all enumerators
    156         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    157           int currentRow = rowsEnumerator.Current;
    158 
    159           var selectedEnumerators = from pair in estimatedValuesEnumerators
    160                                     where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    161                                     select pair.EstimatedValuesEnumerator;
    162           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     155        var rowsToEvaluate = rows.Except(trainingEstimatedValuesCache.Keys);
     156        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     157        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     158
     159        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     160          trainingEstimatedValuesCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    163161        }
     162
     163        return rows.Select(row => trainingEstimatedValuesCache[row]);
    164164      }
    165165    }
     
    168168      get {
    169169        var rows = ProblemData.TestIndices;
    170         var estimatedValuesEnumerators = (from model in Model.Models
    171                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    172                                          .ToList();
    173         var rowsEnumerator = ProblemData.TestIndices.GetEnumerator();
    174         // aggregate to make sure that MoveNext is called for all enumerators
    175         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    176           int currentRow = rowsEnumerator.Current;
    177 
    178           var selectedEnumerators = from pair in estimatedValuesEnumerators
    179                                     where RowIsTestForModel(currentRow, pair.Model)
    180                                     select pair.EstimatedValuesEnumerator;
    181 
    182           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     170        var rowsToEvaluate = rows.Except(testEstimatedValuesCache.Keys);
     171        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     172        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     173
     174        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     175          testEstimatedValuesCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    183176        }
     177
     178        return rows.Select(row => testEstimatedValuesCache[row]);
     179      }
     180    }
     181
     182    private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IClassificationModel, bool> modelSelectionPredicate) {
     183      var estimatedValuesEnumerators = (from model in Model.Models
     184                                        select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     185                                       .ToList();
     186      var rowsEnumerator = rows.GetEnumerator();
     187      // aggregate to make sure that MoveNext is called for all enumerators
     188      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     189        int currentRow = rowsEnumerator.Current;
     190
     191        var selectedEnumerators = from pair in estimatedValuesEnumerators
     192                                  where modelSelectionPredicate(currentRow, pair.Model)
     193                                  select pair.EstimatedValuesEnumerator;
     194
     195        yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
    184196      }
    185197    }
     
    196208
    197209    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    198       return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
    199              select AggregateEstimatedClassValues(xs);
     210      var rowsToEvaluate = rows.Except(estimatedValuesCache.Keys);
     211      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     212      var valuesEnumerator = (from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rowsToEvaluate)
     213                              select AggregateEstimatedClassValues(xs))
     214                             .GetEnumerator();
     215
     216      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     217        estimatedValuesCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     218      }
     219
     220      return rows.Select(row => estimatedValuesCache[row]);
    200221    }
    201222
     
    223244
    224245    protected override void OnProblemDataChanged() {
     246      trainingEstimatedValuesCache.Clear();
     247      testEstimatedValuesCache.Clear();
     248      estimatedValuesCache.Clear();
     249
    225250      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
    226251                                                                     ProblemData.AllowedInputVariables,
     
    251276    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
    252277      classificationSolutions.AddRange(solutions);
     278
     279      trainingEstimatedValuesCache.Clear();
     280      testEstimatedValuesCache.Clear();
     281      estimatedValuesCache.Clear();
    253282    }
    254283    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
    255284      classificationSolutions.RemoveRange(solutions);
     285
     286      trainingEstimatedValuesCache.Clear();
     287      testEstimatedValuesCache.Clear();
     288      estimatedValuesCache.Clear();
    256289    }
    257290
     
    275308      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    276309      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     310
     311      trainingEstimatedValuesCache.Clear();
     312      testEstimatedValuesCache.Clear();
     313      estimatedValuesCache.Clear();
    277314    }
    278315
     
    282319      trainingPartitions.Remove(solution.Model);
    283320      testPartitions.Remove(solution.Model);
     321
     322      trainingEstimatedValuesCache.Clear();
     323      testEstimatedValuesCache.Clear();
     324      estimatedValuesCache.Clear();
    284325    }
    285326  }
Note: See TracChangeset for help on using the changeset viewer.