Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/03/12 16:46:35 (12 years ago)
Author:
gkronber
Message:

#1847: merged r8084:8205 from trunk into GP move operators branch

Location:
branches/GP-MoveOperators
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/GP-MoveOperators

  • branches/GP-MoveOperators/HeuristicLab.Problems.DataAnalysis

  • branches/GP-MoveOperators/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r7832 r8206  
    3737  [Creatable("Data Analysis - Ensembles")]
    3838  public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
     39    private readonly Dictionary<int, double> trainingEvaluationCache = new Dictionary<int, double>();
     40    private readonly Dictionary<int, double> testEvaluationCache = new Dictionary<int, double>();
     41
    3942    public new IRegressionEnsembleModel Model {
    4043      get { return (IRegressionEnsembleModel)base.Model; }
     
    5255
    5356    [Storable]
    54     private Dictionary<IRegressionModel, IntRange> trainingPartitions;
     57    private readonly Dictionary<IRegressionModel, IntRange> trainingPartitions;
    5558    [Storable]
    56     private Dictionary<IRegressionModel, IntRange> testPartitions;
     59    private readonly Dictionary<IRegressionModel, IntRange> testPartitions;
    5760
    5861    [StorableConstructor]
     
    8689      }
    8790
     91      trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count());
     92      testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count());
     93
    8894      regressionSolutions = cloner.Clone(original.regressionSolutions);
    8995      RegisterRegressionSolutionsEventHandler();
     
    133139      }
    134140
     141      trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count());
     142      testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count());
     143
    135144      RegisterRegressionSolutionsEventHandler();
    136145      regressionSolutions.AddRange(solutions);
     
    153162    public override IEnumerable<double> EstimatedTrainingValues {
    154163      get {
    155         var rows = ProblemData.TrainingIndizes;
    156         var estimatedValuesEnumerators = (from model in Model.Models
    157                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    158                                          .ToList();
    159         var rowsEnumerator = rows.GetEnumerator();
    160         // aggregate to make sure that MoveNext is called for all enumerators
    161         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    162           int currentRow = rowsEnumerator.Current;
    163 
    164           var selectedEnumerators = from pair in estimatedValuesEnumerators
    165                                     where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    166                                     select pair.EstimatedValuesEnumerator;
    167           yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
     164        var rows = ProblemData.TrainingIndices;
     165        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
     166        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     167        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     168
     169        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     170          trainingEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    168171        }
     172
     173        return rows.Select(row => trainingEvaluationCache[row]);
    169174      }
    170175    }
     
    172177    public override IEnumerable<double> EstimatedTestValues {
    173178      get {
    174         var rows = ProblemData.TestIndizes;
    175         var estimatedValuesEnumerators = (from model in Model.Models
    176                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    177                                          .ToList();
    178         var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
    179         // aggregate to make sure that MoveNext is called for all enumerators
    180         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    181           int currentRow = rowsEnumerator.Current;
    182 
    183           var selectedEnumerators = from pair in estimatedValuesEnumerators
    184                                     where RowIsTestForModel(currentRow, pair.Model)
    185                                     select pair.EstimatedValuesEnumerator;
    186 
    187           yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
     179        var rows = ProblemData.TestIndices;
     180        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
     181        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     182        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     183
     184        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     185          testEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    188186        }
     187
     188        return rows.Select(row => testEvaluationCache[row]);
     189      }
     190    }
     191
     192    private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
     193      var estimatedValuesEnumerators = (from model in Model.Models
     194                                        select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
     195                                       .ToList();
     196      var rowsEnumerator = rows.GetEnumerator();
     197      // aggregate to make sure that MoveNext is called for all enumerators
     198      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     199        int currentRow = rowsEnumerator.Current;
     200
     201        var selectedEnumerators = from pair in estimatedValuesEnumerators
     202                                  where modelSelectionPredicate(currentRow, pair.Model)
     203                                  select pair.EstimatedValuesEnumerator;
     204
     205        yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    189206      }
    190207    }
     
    201218
    202219    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    203       return from xs in GetEstimatedValueVectors(ProblemData.Dataset, rows)
    204              select AggregateEstimatedValues(xs);
     220      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     221      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     222      var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate)
     223                              select AggregateEstimatedValues(xs))
     224                             .GetEnumerator();
     225
     226      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     227        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     228      }
     229
     230      return rows.Select(row => evaluationCache[row]);
    205231    }
    206232
     
    230256
    231257    protected override void OnProblemDataChanged() {
     258      trainingEvaluationCache.Clear();
     259      testEvaluationCache.Clear();
     260      evaluationCache.Clear();
    232261      IRegressionProblemData problemData = new RegressionProblemData(ProblemData.Dataset,
    233262                                                                     ProblemData.AllowedInputVariables,
     
    258287    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    259288      regressionSolutions.AddRange(solutions);
     289
     290      trainingEvaluationCache.Clear();
     291      testEvaluationCache.Clear();
     292      evaluationCache.Clear();
    260293    }
    261294    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    262295      regressionSolutions.RemoveRange(solutions);
     296
     297      trainingEvaluationCache.Clear();
     298      testEvaluationCache.Clear();
     299      evaluationCache.Clear();
    263300    }
    264301
     
    282319      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    283320      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     321
     322      trainingEvaluationCache.Clear();
     323      testEvaluationCache.Clear();
     324      evaluationCache.Clear();
    284325    }
    285326
     
    289330      trainingPartitions.Remove(solution.Model);
    290331      testPartitions.Remove(solution.Model);
     332
     333      trainingEvaluationCache.Clear();
     334      testEvaluationCache.Clear();
     335      evaluationCache.Clear();
    291336    }
    292337  }
Note: See TracChangeset for help on using the changeset viewer.