Free cookie consent management tool by TermsFeed Policy Generator

Changeset 8151


Ignore:
Timestamp:
06/28/12 15:55:16 (12 years ago)
Author:
gkronber
Message:

#1720: preparation for estimated values caching in regression ensemble solution

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r8139 r8151  
    3737  [Creatable("Data Analysis - Ensembles")]
    3838  public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
     39    private readonly Dictionary<int, double> trainingEstimatedValuesCache = new Dictionary<int, double>();
     40    private readonly Dictionary<int, double> testEstimatedValuesCache = new Dictionary<int, double>();
     41    private readonly Dictionary<int, double> estimatedValuesCache = new Dictionary<int, double>();
     42
    3943    public new IRegressionEnsembleModel Model {
    4044      get { return (IRegressionEnsembleModel)base.Model; }
     
    152156    #region Evaluation
    153157    public override IEnumerable<double> EstimatedTrainingValues {
    154       get {
    155         var rows = ProblemData.TrainingIndices;
    156         var estimatedValuesEnumerators = (from model in Model.Models
    157                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    158                                          .ToList();
    159         var rowsEnumerator = rows.GetEnumerator();
    160         // aggregate to make sure that MoveNext is called for all enumerators
    161         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    162           int currentRow = rowsEnumerator.Current;
    163 
    164           var selectedEnumerators = from pair in estimatedValuesEnumerators
    165                                     where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    166                                     select pair.EstimatedValuesEnumerator;
    167           yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    168         }
    169       }
     158      get { return GetEstimatedValues(ProblemData.TrainingIndices, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)); }
    170159    }
    171160
    172161    public override IEnumerable<double> EstimatedTestValues {
    173       get {
    174         var rows = ProblemData.TestIndices;
    175         var estimatedValuesEnumerators = (from model in Model.Models
    176                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    177                                          .ToList();
    178         var rowsEnumerator = ProblemData.TestIndices.GetEnumerator();
    179         // aggregate to make sure that MoveNext is called for all enumerators
    180         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    181           int currentRow = rowsEnumerator.Current;
    182 
    183           var selectedEnumerators = from pair in estimatedValuesEnumerators
    184                                     where RowIsTestForModel(currentRow, pair.Model)
    185                                     select pair.EstimatedValuesEnumerator;
    186 
    187           yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    188         }
     162      get { return GetEstimatedValues(ProblemData.TestIndices, RowIsTestForModel); }
     163    }
     164
     165    private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
     166      var estimatedValuesEnumerators = (from model in Model.Models
     167                                        select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
     168                                       .ToList();
     169      var rowsEnumerator = rows.GetEnumerator();
     170      // aggregate to make sure that MoveNext is called for all enumerators
     171      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     172        int currentRow = rowsEnumerator.Current;
     173
     174        var selectedEnumerators = from pair in estimatedValuesEnumerators
     175                                  where modelSelectionPredicate(currentRow, pair.Model)
     176                                  select pair.EstimatedValuesEnumerator;
     177
     178        yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    189179      }
    190180    }
Note: See TracChangeset for help on using the changeset viewer.