Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/23/17 00:52:14 (7 years ago)
Author:
abeham
Message:

#2258: merged r13329:14000 from trunk into branch

Location:
branches/Async
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/Async

  • branches/Async/HeuristicLab.Problems.DataAnalysis

  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r12509 r15280  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    3233  [StorableClass]
    3334  [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
    34   public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
     35  public sealed class RegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel {
     36    public override IEnumerable<string> VariablesUsedForPrediction {
     37      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
     38    }
    3539
    3640    private List<IRegressionModel> models;
     
    4549    }
    4650
     51    private List<double> modelWeights;
     52    public IEnumerable<double> ModelWeights {
     53      get { return modelWeights; }
     54    }
     55
     56    [Storable(Name = "ModelWeights")]
     57    private IEnumerable<double> StorableModelWeights {
     58      get { return modelWeights; }
     59      set { modelWeights = value.ToList(); }
     60    }
     61
     62    [Storable]
     63    private bool averageModelEstimates = true;
     64    public bool AverageModelEstimates {
     65      get { return averageModelEstimates; }
     66      set {
     67        if (averageModelEstimates != value) {
     68          averageModelEstimates = value;
     69          OnChanged();
     70        }
     71      }
     72    }
     73
    4774    #region backwards compatiblity 3.3.5
    4875    [Storable(Name = "models", AllowOneWay = true)]
     
    5279    #endregion
    5380
     81    [StorableHook(HookType.AfterDeserialization)]
     82    private void AfterDeserialization() {
     83      // BackwardsCompatibility 3.3.14
     84      #region Backwards compatible code, remove with 3.4
     85      if (modelWeights == null || !modelWeights.Any())
     86        modelWeights = new List<double>(models.Select(m => 1.0));
     87      #endregion
     88    }
     89
    5490    [StorableConstructor]
    55     protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
    56     protected RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
     91    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     92    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
    5793      : base(original, cloner) {
    58       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
     94      this.models = original.Models.Select(cloner.Clone).ToList();
     95      this.modelWeights = new List<double>(original.ModelWeights);
     96      this.averageModelEstimates = original.averageModelEstimates;
     97    }
     98    public override IDeepCloneable Clone(Cloner cloner) {
     99      return new RegressionEnsembleModel(this, cloner);
    59100    }
    60101
    61102    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    62     public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
    63       : base() {
     103    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { }
     104    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
     105      : base(string.Empty) {
    64106      this.name = ItemName;
    65107      this.description = ItemDescription;
     108
    66109      this.models = new List<IRegressionModel>(models);
    67     }
    68 
    69     public override IDeepCloneable Clone(Cloner cloner) {
    70       return new RegressionEnsembleModel(this, cloner);
    71     }
    72 
    73     #region IRegressionEnsembleModel Members
     110      this.modelWeights = new List<double>(modelWeights);
     111
     112      if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable;
     113    }
    74114
    75115    public void Add(IRegressionModel model) {
     116      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
     117      Add(model, 1.0);
     118    }
     119    public void Add(IRegressionModel model, double weight) {
     120      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
     121
    76122      models.Add(model);
    77     }
     123      modelWeights.Add(weight);
     124      OnChanged();
     125    }
     126
     127    public void AddRange(IEnumerable<IRegressionModel> models) {
     128      AddRange(models, models.Select(m => 1.0));
     129    }
     130    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
     131      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = models.First().TargetVariable;
     132
     133      this.models.AddRange(models);
     134      modelWeights.AddRange(weights);
     135      OnChanged();
     136    }
     137
    78138    public void Remove(IRegressionModel model) {
    79       models.Remove(model);
    80     }
    81 
     139      var index = models.IndexOf(model);
     140      models.RemoveAt(index);
     141      modelWeights.RemoveAt(index);
     142
     143      if (!models.Any()) TargetVariable = string.Empty;
     144      OnChanged();
     145    }
     146    public void RemoveRange(IEnumerable<IRegressionModel> models) {
     147      foreach (var model in models) {
     148        var index = this.models.IndexOf(model);
     149        this.models.RemoveAt(index);
     150        modelWeights.RemoveAt(index);
     151      }
     152
     153      if (!models.Any()) TargetVariable = string.Empty;
     154      OnChanged();
     155    }
     156
     157    public double GetModelWeight(IRegressionModel model) {
     158      var index = models.IndexOf(model);
     159      return modelWeights[index];
     160    }
     161    public void SetModelWeight(IRegressionModel model, double weight) {
     162      var index = models.IndexOf(model);
     163      modelWeights[index] = weight;
     164      OnChanged();
     165    }
     166
     167    #region evaluation
    82168    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    83169      var estimatedValuesEnumerators = (from model in models
    84                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    85                                        .ToList();
     170                                        let weight = GetModelWeight(model)
     171                                        select model.GetEstimatedValues(dataset, rows).Select(e => weight * e)
     172                                        .GetEnumerator()).ToList();
    86173
    87174      while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
     
    91178    }
    92179
     180    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     181      double weightsSum = modelWeights.Sum();
     182      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
     183                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();
     184
     185      if (AverageModelEstimates)
     186        return summedEstimates.Select(v => v / weightsSum);
     187      else
     188        return summedEstimates;
     189
     190    }
     191
     192    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
     193      var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
     194      var rowsEnumerator = rows.GetEnumerator();
     195
     196      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
     197        var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator();
     198        int currentRow = rowsEnumerator.Current;
     199        double weightsSum = 0.0;
     200        double filteredEstimatesSum = 0.0;
     201
     202        for (int m = 0; m < models.Count; m++) {
     203          estimatedValueEnumerator.MoveNext();
     204          var model = models[m];
     205          if (!modelSelectionPredicate(currentRow, model)) continue;
     206
     207          filteredEstimatesSum += estimatedValueEnumerator.Current;
     208          weightsSum += modelWeights[m];
     209        }
     210
     211        if (AverageModelEstimates)
     212          yield return filteredEstimatesSum / weightsSum;
     213        else
     214          yield return filteredEstimatesSum;
     215      }
     216    }
     217
    93218    #endregion
    94219
    95     #region IRegressionModel Members
    96 
    97     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    98       foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
    99         yield return estimatedValuesVector.Average();
    100       }
    101     }
    102 
    103     public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    104       return new RegressionEnsembleSolution(this.Models, new RegressionEnsembleProblemData(problemData));
    105     }
    106     IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
    107       return CreateRegressionSolution(problemData);
    108     }
    109 
    110     #endregion
     220    public event EventHandler Changed;
     221    private void OnChanged() {
     222      var handler = Changed;
     223      if (handler != null)
     224        handler(this, EventArgs.Empty);
     225    }
     226
     227
     228    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     229      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
     230    }
    111231  }
    112232}
Note: See TracChangeset for help on using the changeset viewer.