Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.DatastreamAnalysis/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs @ 14491

Last change on this file since 14491 was 14491, checked in by jzenisek, 7 years ago

#2719 adapted RegressionEnsembleModel so that it can now be created directly (it is now a creatable item)

File size: 8.6 KB
RevLine 
[5662]1#region License Information
2/* HeuristicLab
[14185]3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5662]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[13697]22using System;
[5662]23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28
29namespace HeuristicLab.Problems.DataAnalysis {
30  /// <summary>
31  /// Represents regression solutions that contain an ensemble of multiple regression models
32  /// </summary>
33  [StorableClass]
[14491]34  [Item("Regression Ensemble Model", "A regression model that contains an ensemble of multiple regression models")]
35  [Creatable(CreatableAttribute.Categories.DataAnalysisEnsembles, Priority = 100)]
[13941]36  public sealed class RegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel {
37    public override IEnumerable<string> VariablesUsedForPrediction {
[13921]38      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
39    }
[5662]40
41    private List<IRegressionModel> models;
42    public IEnumerable<IRegressionModel> Models {
43      get { return new List<IRegressionModel>(models); }
44    }
[6603]45
46    [Storable(Name = "Models")]
47    private IEnumerable<IRegressionModel> StorableModels {
48      get { return models; }
49      set { models = value.ToList(); }
50    }
51
[13704]52    private List<double> modelWeights;
53    public IEnumerable<double> ModelWeights {
54      get { return modelWeights; }
55    }
56
57    [Storable(Name = "ModelWeights")]
58    private IEnumerable<double> StorableModelWeights {
59      get { return modelWeights; }
60      set { modelWeights = value.ToList(); }
61    }
62
[13700]63    [Storable]
64    private bool averageModelEstimates = true;
65    public bool AverageModelEstimates {
66      get { return averageModelEstimates; }
67      set {
68        if (averageModelEstimates != value) {
69          averageModelEstimates = value;
[13704]70          OnChanged();
[13700]71        }
72      }
73    }
74
[6603]75    #region backwards compatiblity 3.3.5
76    [Storable(Name = "models", AllowOneWay = true)]
77    private List<IRegressionModel> OldStorableModels {
78      set { models = value; }
79    }
80    #endregion
81
[13704]82    [StorableHook(HookType.AfterDeserialization)]
83    private void AfterDeserialization() {
84      // BackwardsCompatibility 3.3.14
85      #region Backwards compatible code, remove with 3.4
86      if (modelWeights == null || !modelWeights.Any())
87        modelWeights = new List<double>(models.Select(m => 1.0));
88      #endregion
89    }
90
[5662]91    [StorableConstructor]
[13700]92    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
93    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
[5662]94      : base(original, cloner) {
[13700]95      this.models = original.Models.Select(cloner.Clone).ToList();
[13704]96      this.modelWeights = new List<double>(original.ModelWeights);
[13700]97      this.averageModelEstimates = original.averageModelEstimates;
[5662]98    }
[13700]99    public override IDeepCloneable Clone(Cloner cloner) {
100      return new RegressionEnsembleModel(this, cloner);
101    }
[6666]102
103    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
[13704]104    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { }
105    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
[13941]106      : base(string.Empty) {
[5662]107      this.name = ItemName;
108      this.description = ItemDescription;
[13704]109
[5662]110      this.models = new List<IRegressionModel>(models);
[13704]111      this.modelWeights = new List<double>(modelWeights);
[13941]112
113      if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable;
[5662]114    }
115
[6520]116    public void Add(IRegressionModel model) {
[13941]117      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
[13704]118      Add(model, 1.0);
119    }
120    public void Add(IRegressionModel model, double weight) {
[13941]121      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
122
[6520]123      models.Add(model);
[13704]124      modelWeights.Add(weight);
125      OnChanged();
[6520]126    }
[13700]127
[13704]128    public void AddRange(IEnumerable<IRegressionModel> models) {
129      AddRange(models, models.Select(m => 1.0));
130    }
131    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
[13941]132      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = models.First().TargetVariable;
133
[13704]134      this.models.AddRange(models);
135      modelWeights.AddRange(weights);
136      OnChanged();
137    }
138
[6612]139    public void Remove(IRegressionModel model) {
[13704]140      var index = models.IndexOf(model);
141      models.RemoveAt(index);
142      modelWeights.RemoveAt(index);
[13941]143
144      if (!models.Any()) TargetVariable = string.Empty;
[13704]145      OnChanged();
[6612]146    }
[13704]147    public void RemoveRange(IEnumerable<IRegressionModel> models) {
148      foreach (var model in models) {
149        var index = this.models.IndexOf(model);
150        this.models.RemoveAt(index);
151        modelWeights.RemoveAt(index);
152      }
[13941]153
154      if (!models.Any()) TargetVariable = string.Empty;
[13704]155      OnChanged();
156    }
[6520]157
[13704]158    public double GetModelWeight(IRegressionModel model) {
159      var index = models.IndexOf(model);
160      return modelWeights[index];
161    }
162    public void SetModelWeight(IRegressionModel model, double weight) {
163      var index = models.IndexOf(model);
164      modelWeights[index] = weight;
165      OnChanged();
166    }
167
[13715]168    #region evaluation
[12509]169    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
[5662]170      var estimatedValuesEnumerators = (from model in models
[13705]171                                        let weight = GetModelWeight(model)
172                                        select model.GetEstimatedValues(dataset, rows).Select(e => weight * e)
173                                        .GetEnumerator()).ToList();
[5662]174
175      while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
176        yield return from enumerator in estimatedValuesEnumerators
177                     select enumerator.Current;
178      }
179    }
180
[13941]181    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
[13715]182      double weightsSum = modelWeights.Sum();
183      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
184                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();
185
186      if (AverageModelEstimates)
187        return summedEstimates.Select(v => v / weightsSum);
188      else
189        return summedEstimates;
190
191    }
192
[13697]193    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
194      var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
195      var rowsEnumerator = rows.GetEnumerator();
196
197      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
[13715]198        var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator();
[13697]199        int currentRow = rowsEnumerator.Current;
[13715]200        double weightsSum = 0.0;
201        double filteredEstimatesSum = 0.0;
[13697]202
[13715]203        for (int m = 0; m < models.Count; m++) {
204          estimatedValueEnumerator.MoveNext();
205          var model = models[m];
206          if (!modelSelectionPredicate(currentRow, model)) continue;
[13697]207
[13715]208          filteredEstimatesSum += estimatedValueEnumerator.Current;
209          weightsSum += modelWeights[m];
210        }
211
212        if (AverageModelEstimates)
213          yield return filteredEstimatesSum / weightsSum;
214        else
215          yield return filteredEstimatesSum;
[13697]216      }
217    }
[13700]218
[13715]219    #endregion
[13700]220
[13704]221    public event EventHandler Changed;
222    private void OnChanged() {
223      var handler = Changed;
[13700]224      if (handler != null)
225        handler(this, EventArgs.Empty);
226    }
[5662]227
228
[13941]229    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
[13698]230      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
[6603]231    }
[5662]232  }
233}
Note: See TracBrowser for help on using the repository browser.