Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/17/16 17:48:36 (9 years ago)
Author:
mkommend
Message:

#2590: Implemented a weighted average to ease weights manipulation in regression ensembles and corrected locking in the weights view.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r13705 r13715  
    109109    }
    110110
    111     #region IRegressionEnsembleModel Members
    112111    public void Add(IRegressionModel model) {
    113112      Add(model, 1.0);
     
    153152    }
    154153
     154    #region evaluation
    155155    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    156156      var estimatedValuesEnumerators = (from model in models
     
    165165    }
    166166
     167    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     168      double weightsSum = modelWeights.Sum();
     169      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
     170                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();
     171
     172      if (AverageModelEstimates)
     173        return summedEstimates.Select(v => v / weightsSum);
     174      else
     175        return summedEstimates;
     176
     177    }
     178
    167179    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
    168180      var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
    169181      var rowsEnumerator = rows.GetEnumerator();
    170182
    171       // aggregate to make sure that MoveNext is called for all enumerators
    172183      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
     184        var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator();
    173185        int currentRow = rowsEnumerator.Current;
    174 
    175         var filteredEstimates = models.Zip(estimatedValuesEnumerators.Current, (m, e) => new { Model = m, EstimatedValue = e })
    176                                       .Where(f => modelSelectionPredicate(currentRow, f.Model))
    177                                       .Select(f => f.EstimatedValue).DefaultIfEmpty(double.NaN);
    178 
    179         yield return AggregateEstimatedValues(filteredEstimates);
    180       }
    181     }
    182 
    183     private double AggregateEstimatedValues(IEnumerable<double> estimatedValuesVector) {
    184       if (AverageModelEstimates)
    185         return estimatedValuesVector.Average();
    186       else
    187         return estimatedValuesVector.Sum();
    188     }
     186        double weightsSum = 0.0;
     187        double filteredEstimatesSum = 0.0;
     188
     189        for (int m = 0; m < models.Count; m++) {
     190          estimatedValueEnumerator.MoveNext();
     191          var model = models[m];
     192          if (!modelSelectionPredicate(currentRow, model)) continue;
     193
     194          filteredEstimatesSum += estimatedValueEnumerator.Current;
     195          weightsSum += modelWeights[m];
     196        }
     197
     198        if (AverageModelEstimates)
     199          yield return filteredEstimatesSum / weightsSum;
     200        else
     201          yield return filteredEstimatesSum;
     202      }
     203    }
     204
     205    #endregion
    189206
    190207    public event EventHandler Changed;
     
    194211        handler(this, EventArgs.Empty);
    195212    }
    196     #endregion
    197 
    198     #region IRegressionModel Members
    199     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    200       foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
    201         yield return AggregateEstimatedValues(estimatedValuesVector.DefaultIfEmpty(double.NaN));
    202       }
    203     }
     213
    204214
    205215    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     
    209219      return CreateRegressionSolution(problemData);
    210220    }
    211     #endregion
    212221  }
    213222}
Note: See TracChangeset for help on using the changeset viewer.