Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/14/16 17:16:12 (8 years ago)
Author:
mkommend
Message:

#2590: Extracted estimated values calculation from RegressionEnsembleSolution to the according model.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r12816 r13697  
    169169        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
    170170        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    171         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     171        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
    172172
    173173        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    184184        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
    185185        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    186         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     186        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator();
    187187
    188188        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    193193      }
    194194    }
    195 
    196     private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
    197       var estimatedValuesEnumerators = (from model in Model.Models
    198                                         select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    199                                        .ToList();
    200       var rowsEnumerator = rows.GetEnumerator();
    201       // aggregate to make sure that MoveNext is called for all enumerators
    202       while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    203         int currentRow = rowsEnumerator.Current;
    204 
    205         var selectedEnumerators = from pair in estimatedValuesEnumerators
    206                                   where modelSelectionPredicate(currentRow, pair.Model)
    207                                   select pair.EstimatedValuesEnumerator;
    208 
    209         yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    210       }
    211     }
    212 
    213195    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
    214196      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
    215197              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
    216198    }
    217 
    218199    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
    219200      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     
    224205      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
    225206      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    226       var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate)
    227                               select AggregateEstimatedValues(xs))
    228                              .GetEnumerator();
     207      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
    229208
    230209      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    235214    }
    236215
    237     public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    238       if (!Model.Models.Any()) yield break;
    239       var estimatedValuesEnumerators = (from model in Model.Models
    240                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    241                                        .ToList();
    242 
    243       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    244         yield return from enumerator in estimatedValuesEnumerators
    245                      select enumerator.Current;
    246       }
    247     }
    248 
    249     private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    250       return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     216    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) {
     217      return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows);
    251218    }
    252219    #endregion
Note: See TracChangeset for help on using the changeset viewer.