Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs @ 13834

Last change on this file since 13834 was 13715, checked in by mkommend, 9 years ago

#2590: Implemented a weighted average to ease weight manipulation in regression ensembles and corrected locking in the weights view.

File size: 8.0 KB
RevLine 
[5662]1#region License Information
2/* HeuristicLab
[12012]3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5662]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[13697]22using System;
[5662]23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents regression solutions that contain an ensemble of multiple regression models
  /// </summary>
  /// <remarks>
  /// Models and their weights are kept in two parallel lists; the weight at index i
  /// belongs to the model at index i. All mutating members keep both lists the same length
  /// and raise <see cref="Changed"/> afterwards.
  /// </remarks>
  [StorableClass]
  [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
  public sealed class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {

    private List<IRegressionModel> models;
    /// <summary>
    /// Gets the models of the ensemble. Every access returns a fresh copy of the internal list.
    /// </summary>
    public IEnumerable<IRegressionModel> Models {
      get { return new List<IRegressionModel>(models); }
    }

    // Persistence accessor; stored under the name "Models".
    [Storable(Name = "Models")]
    private IEnumerable<IRegressionModel> StorableModels {
      get { return models; }
      set { models = value.ToList(); }
    }

    private List<double> modelWeights;
    /// <summary>
    /// Gets the model weights in the same order as <see cref="Models"/>.
    /// Unlike <see cref="Models"/>, this returns the internal list itself, not a copy.
    /// </summary>
    public IEnumerable<double> ModelWeights {
      get { return modelWeights; }
    }

    // Persistence accessor; stored under the name "ModelWeights".
    [Storable(Name = "ModelWeights")]
    private IEnumerable<double> StorableModelWeights {
      get { return modelWeights; }
      set { modelWeights = value.ToList(); }
    }

    [Storable]
    private bool averageModelEstimates = true;
    /// <summary>
    /// Gets or sets whether the summed weighted estimates are divided by the sum of the weights
    /// (weighted average) or returned as a plain weighted sum.
    /// <see cref="Changed"/> is raised only when the value actually changes.
    /// </summary>
    public bool AverageModelEstimates {
      get { return averageModelEstimates; }
      set {
        if (averageModelEstimates != value) {
          averageModelEstimates = value;
          OnChanged();
        }
      }
    }

    #region backwards compatibility 3.3.5
    // Write-only accessor for the legacy storage name "models".
    [Storable(Name = "models", AllowOneWay = true)]
    private List<IRegressionModel> OldStorableModels {
      set { models = value; }
    }
    #endregion

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility 3.3.14
      #region Backwards compatible code, remove with 3.4
      // No weights were restored: fall back to a weight of 1.0 for every model.
      if (modelWeights == null || !modelWeights.Any())
        modelWeights = new List<double>(models.Select(m => 1.0));
      #endregion
    }

    [StorableConstructor]
    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
      : base(original, cloner) {
      this.models = original.Models.Select(cloner.Clone).ToList();
      this.modelWeights = new List<double>(original.ModelWeights);
      this.averageModelEstimates = original.averageModelEstimates;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionEnsembleModel(this, cloner);
    }

    /// <summary>
    /// Creates an ensemble without any models.
    /// </summary>
    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }

    /// <summary>
    /// Creates an ensemble of the given models, each with a weight of 1.0.
    /// </summary>
    /// <param name="models">The models of the ensemble.</param>
    /// <exception cref="ArgumentNullException"><paramref name="models"/> is null.</exception>
    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
      : base() {
      if (models == null) throw new ArgumentNullException("models");
      // Enumerate the sequence only once; the weights are derived from the materialized count.
      var modelList = models.ToList();
      Initialize(modelList, Enumerable.Repeat(1.0, modelList.Count).ToList());
    }

    /// <summary>
    /// Creates an ensemble of the given models with the given weights.
    /// </summary>
    /// <param name="models">The models of the ensemble.</param>
    /// <param name="modelWeights">One weight per model, in the same order as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentException">The number of weights differs from the number of models.</exception>
    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
      : base() {
      if (models == null) throw new ArgumentNullException("models");
      if (modelWeights == null) throw new ArgumentNullException("modelWeights");
      Initialize(models.ToList(), modelWeights.ToList());
    }

    // Shared constructor logic: validates that both lists line up and takes ownership of them.
    private void Initialize(List<IRegressionModel> modelList, List<double> weightList) {
      if (modelList.Count != weightList.Count)
        throw new ArgumentException("The number of model weights must match the number of models.");

      this.name = ItemName;
      this.description = ItemDescription;

      this.models = modelList;
      this.modelWeights = weightList;
    }

    /// <summary>
    /// Adds a model with a weight of 1.0.
    /// </summary>
    public void Add(IRegressionModel model) {
      Add(model, 1.0);
    }

    /// <summary>
    /// Adds a model with the given weight and raises <see cref="Changed"/>.
    /// </summary>
    public void Add(IRegressionModel model, double weight) {
      models.Add(model);
      modelWeights.Add(weight);
      OnChanged();
    }

    /// <summary>
    /// Adds several models, each with a weight of 1.0.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="models"/> is null.</exception>
    public void AddRange(IEnumerable<IRegressionModel> models) {
      if (models == null) throw new ArgumentNullException("models");
      // Enumerate the sequence only once; the weights are derived from the materialized count.
      var newModels = models.ToList();
      AddRange(newModels, Enumerable.Repeat(1.0, newModels.Count));
    }

    /// <summary>
    /// Adds several models with their weights and raises <see cref="Changed"/> once.
    /// </summary>
    /// <param name="models">The models to add.</param>
    /// <param name="weights">One weight per model, in the same order as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentException">The number of weights differs from the number of models.</exception>
    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
      if (models == null) throw new ArgumentNullException("models");
      if (weights == null) throw new ArgumentNullException("weights");

      // Materialize and validate before touching any state, so a bad argument
      // cannot leave the two parallel lists with different lengths.
      var newModels = models.ToList();
      var newWeights = weights.ToList();
      if (newModels.Count != newWeights.Count)
        throw new ArgumentException("The number of weights must match the number of models.");

      this.models.AddRange(newModels);
      modelWeights.AddRange(newWeights);
      OnChanged();
    }

    /// <summary>
    /// Removes the first occurrence of the model together with its weight and raises
    /// <see cref="Changed"/>. Does nothing if the model is not part of the ensemble.
    /// </summary>
    public void Remove(IRegressionModel model) {
      var index = models.IndexOf(model);
      if (index < 0) return; // IndexOf reports "not found" as -1; RemoveAt(-1) would throw
      models.RemoveAt(index);
      modelWeights.RemoveAt(index);
      OnChanged();
    }

    /// <summary>
    /// Removes the given models together with their weights and raises <see cref="Changed"/> once.
    /// Models that are not part of the ensemble are skipped.
    /// </summary>
    public void RemoveRange(IEnumerable<IRegressionModel> models) {
      // Snapshot the argument so that a lazy sequence cannot be affected by the removals below.
      foreach (var model in models.ToList()) {
        var index = this.models.IndexOf(model);
        if (index < 0) continue;
        this.models.RemoveAt(index);
        modelWeights.RemoveAt(index);
      }
      OnChanged();
    }

    /// <summary>
    /// Returns the weight stored for the first occurrence of the model.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">The model is not part of the ensemble.</exception>
    public double GetModelWeight(IRegressionModel model) {
      var index = models.IndexOf(model);
      return modelWeights[index];
    }

    /// <summary>
    /// Sets the weight of the first occurrence of the model and raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">The model is not part of the ensemble.</exception>
    public void SetModelWeight(IRegressionModel model, double weight) {
      var index = models.IndexOf(model);
      modelWeights[index] = weight;
      OnChanged();
    }

    #region evaluation
    /// <summary>
    /// Lazily returns, for each row, the vector of weighted estimates (weight * estimate) of all
    /// models, in model order. The sequence ends as soon as any model runs out of estimates.
    /// </summary>
    /// <remarks>
    /// Every yielded vector is an independent snapshot, so the result may be buffered safely.
    /// An ensemble without models yields one empty vector per row.
    /// </remarks>
    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
      if (models.Count == 0) {
        // Enumerable.All is true for an empty list, so the loop below would never end.
        foreach (var row in rows)
          yield return Enumerable.Empty<double>();
        yield break;
      }

      var estimatedValuesEnumerators = new List<IEnumerator<double>>(models.Count);
      try {
        // Pair model and weight by index; a lookup by instance would pick the wrong
        // weight when the same model instance is contained more than once.
        for (int m = 0; m < models.Count; m++) {
          double weight = modelWeights[m];
          estimatedValuesEnumerators.Add(
            models[m].GetEstimatedValues(dataset, rows).Select(e => weight * e).GetEnumerator());
        }

        while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Materialize the row: a lazy query would read whatever the shared enumerators
          // point at when it is finally consumed.
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      } finally {
        foreach (var enumerator in estimatedValuesEnumerators)
          enumerator.Dispose();
      }
    }

    /// <summary>
    /// Returns one ensemble estimate per row: the sum of the weighted model estimates, divided by
    /// the sum of all weights when <see cref="AverageModelEstimates"/> is set.
    /// A row without any model estimates sums to NaN.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double weightsSum = modelWeights.Sum();
      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();

      if (AverageModelEstimates)
        return summedEstimates.Select(v => v / weightsSum);
      else
        return summedEstimates;
    }

    /// <summary>
    /// Returns one ensemble estimate per row, using only the models for which
    /// <paramref name="modelSelectionPredicate"/> returns true for that row.
    /// </summary>
    /// <param name="dataset">The dataset passed on to the models.</param>
    /// <param name="rows">The rows to estimate.</param>
    /// <param name="modelSelectionPredicate">Called with (row, model) to decide whether the model contributes to that row.</param>
    /// <remarks>
    /// When <see cref="AverageModelEstimates"/> is set, the sum is divided by the sum of the
    /// weights of the selected models only; if no model is selected this is 0.0 / 0.0 (NaN).
    /// </remarks>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
      using (var estimatedValuesEnumerator = GetEstimatedValueVectors(dataset, rows).GetEnumerator())
      using (var rowsEnumerator = rows.GetEnumerator()) {
        while (rowsEnumerator.MoveNext() && estimatedValuesEnumerator.MoveNext()) {
          int currentRow = rowsEnumerator.Current;
          double weightsSum = 0.0;
          double filteredEstimatesSum = 0.0;

          // The vector holds one weighted estimate per model, in model order.
          int m = 0;
          foreach (var weightedEstimate in estimatedValuesEnumerator.Current) {
            if (modelSelectionPredicate(currentRow, models[m])) {
              filteredEstimatesSum += weightedEstimate;
              weightsSum += modelWeights[m];
            }
            m++;
          }

          if (AverageModelEstimates)
            yield return filteredEstimatesSum / weightsSum;
          else
            yield return filteredEstimatesSum;
        }
      }
    }

    #endregion

    /// <summary>
    /// Raised after the models, the weights or <see cref="AverageModelEstimates"/> have changed.
    /// </summary>
    public event EventHandler Changed;
    private void OnChanged() {
      // Copy to a local so the null check and the invocation see the same delegate.
      var handler = Changed;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }


    /// <summary>
    /// Creates an ensemble solution for this model, wrapping the given problem data in a
    /// <see cref="RegressionEnsembleProblemData"/>.
    /// </summary>
    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
    }
    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return CreateRegressionSolution(problemData);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.