Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/14/16 17:16:12 (9 years ago)
Author:
mkommend
Message:

#2590: Extracted estimated values calculation from RegressionEnsembleSolution to the according model.

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r12509 r13697  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    9192    }
    9293
     94    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
     95      var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
     96      var rowsEnumerator = rows.GetEnumerator();
     97
     98      // aggregate to make sure that MoveNext is called for all enumerators
     99      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
     100        int currentRow = rowsEnumerator.Current;
     101
     102        var filteredEstimates = models.Zip(estimatedValuesEnumerators.Current,
     103          (m, e) => new { Model = m, EstimatedValue = e }).Where(f => modelSelectionPredicate(currentRow, f.Model));
     104
     105        yield return filteredEstimates.Select(f => f.EstimatedValue).DefaultIfEmpty(double.NaN).Average();
     106      }
     107    }
     108
    93109    #endregion
    94110
    95111    #region IRegressionModel Members
    96 
    97112    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    98113      foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
     
    107122      return CreateRegressionSolution(problemData);
    108123    }
    109 
    110124    #endregion
    111125  }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r12816 r13697  
    169169        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
    170170        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    171         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     171        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
    172172
    173173        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    184184        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
    185185        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    186         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     186        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator();
    187187
    188188        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    193193      }
    194194    }
    195 
    196     private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
    197       var estimatedValuesEnumerators = (from model in Model.Models
    198                                         select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    199                                        .ToList();
    200       var rowsEnumerator = rows.GetEnumerator();
    201       // aggregate to make sure that MoveNext is called for all enumerators
    202       while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    203         int currentRow = rowsEnumerator.Current;
    204 
    205         var selectedEnumerators = from pair in estimatedValuesEnumerators
    206                                   where modelSelectionPredicate(currentRow, pair.Model)
    207                                   select pair.EstimatedValuesEnumerator;
    208 
    209         yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    210       }
    211     }
    212 
    213195    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
    214196      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
    215197              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
    216198    }
    217 
    218199    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
    219200      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     
    224205      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
    225206      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    226       var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate)
    227                               select AggregateEstimatedValues(xs))
    228                              .GetEnumerator();
     207      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
    229208
    230209      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    235214    }
    236215
    237     public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    238       if (!Model.Models.Any()) yield break;
    239       var estimatedValuesEnumerators = (from model in Model.Models
    240                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    241                                        .ToList();
    242 
    243       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    244         yield return from enumerator in estimatedValuesEnumerators
    245                      select enumerator.Current;
    246       }
    247     }
    248 
    249     private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    250       return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     216    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) {
     217      return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows);
    251218    }
    252219    #endregion
Note: See TracChangeset for help on using the changeset viewer.