Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/14/16 18:20:44 (9 years ago)
Author:
mkommend
Message:

#2590: Added flag to decide whether estimated values are averaged or summed up when evaluating a RegressionEnsembleModel.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r13698 r13700  
    3333  [StorableClass]
    3434  [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
    35   public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
     35  public sealed class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
    3636
    3737    private List<IRegressionModel> models;
     
    4646    }
    4747
     48    [Storable]
     49    private bool averageModelEstimates = true;
     50    public bool AverageModelEstimates {
     51      get { return averageModelEstimates; }
     52      set {
     53        if (averageModelEstimates != value) {
     54          averageModelEstimates = value;
     55          OnAverageModelEstimatesChanged();
     56        }
     57      }
     58    }
     59
    4860    #region backwards compatiblity 3.3.5
    4961    [Storable(Name = "models", AllowOneWay = true)]
     
    5466
    5567    [StorableConstructor]
    56     protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
    57     protected RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
     68    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     69    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
    5870      : base(original, cloner) {
    59       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
     71      this.models = original.Models.Select(cloner.Clone).ToList();
     72      this.averageModelEstimates = original.averageModelEstimates;
     73    }
     74    public override IDeepCloneable Clone(Cloner cloner) {
     75      return new RegressionEnsembleModel(this, cloner);
    6076    }
    6177
     
    6884    }
    6985
    70     public override IDeepCloneable Clone(Cloner cloner) {
    71       return new RegressionEnsembleModel(this, cloner);
    72     }
    73 
    7486    #region IRegressionEnsembleModel Members
    7587    public void Add(IRegressionModel model) {
    7688      models.Add(model);
    7789    }
     90
    7891    public void Remove(IRegressionModel model) {
    7992      models.Remove(model);
     
    99112        int currentRow = rowsEnumerator.Current;
    100113
    101         var filteredEstimates = models.Zip(estimatedValuesEnumerators.Current,
    102           (m, e) => new { Model = m, EstimatedValue = e }).Where(f => modelSelectionPredicate(currentRow, f.Model));
     114        var filteredEstimates = models.Zip(estimatedValuesEnumerators.Current, (m, e) => new { Model = m, EstimatedValue = e })
     115                                      .Where(f => modelSelectionPredicate(currentRow, f.Model))
     116                                      .Select(f => f.EstimatedValue).DefaultIfEmpty(double.NaN);
    103117
    104         yield return filteredEstimates.Select(f => f.EstimatedValue).DefaultIfEmpty(double.NaN).Average();
     118        yield return AggregateEstimatedValues(filteredEstimates);
    105119      }
     120    }
     121
     122    private double AggregateEstimatedValues(IEnumerable<double> estimatedValuesVector) {
     123      if (AverageModelEstimates)
     124        return estimatedValuesVector.Average();
     125      else
     126        return estimatedValuesVector.Sum();
     127    }
     128
     129    public event EventHandler AverageModelEstimatesChanged;
     130    private void OnAverageModelEstimatesChanged() {
     131      var handler = AverageModelEstimatesChanged;
     132      if (handler != null)
     133        handler(this, EventArgs.Empty);
    106134    }
    107135    #endregion
     
    110138    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    111139      foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
    112         yield return estimatedValuesVector.Average();
     140        yield return AggregateEstimatedValues(estimatedValuesVector);
    113141      }
    114142    }
Note: See TracChangeset for help on using the changeset viewer.