Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2719_HeuristicLab.DatastreamAnalysis/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RatedRegressionEnsembleModel.cs @ 17399

Last change on this file since 17399 was 14710, checked in by jzenisek, 8 years ago

#2719 implemented ensemble model rating by introducing the new type RatedEnsembleModel; introduced performance indicator calculation in results;

File size: 9.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// A regression model that combines the (weighted) estimates of multiple regression models.
  /// Every contained model has an associated weight; the model list and the weight list are kept
  /// index-aligned. Additionally a quality threshold and a confidence threshold can be attached to
  /// the ensemble; within this class they are only stored, cloned and persisted.
  /// </summary>
  [StorableClass]
  [Item("Rated Regression Ensemble Model", "A regression model that contains an ensemble of multiple regression models")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisEnsembles, Priority = 100)]
  public sealed class RatedRegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel {
    /// <summary>
    /// The distinct, alphabetically ordered names of all variables used by any contained model.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
    }

    private List<IRegressionModel> models;
    /// <summary>
    /// A snapshot (copy) of the contained models; modifying the returned list does not affect the ensemble.
    /// </summary>
    public IEnumerable<IRegressionModel> Models {
      get { return new List<IRegressionModel>(models); }
    }

    [Storable(Name = "Models")]
    private IEnumerable<IRegressionModel> StorableModels {
      get { return models; }
      set { models = value.ToList(); }
    }

    private List<double> modelWeights;
    /// <summary>
    /// The weights of the contained models, in the same order as <see cref="Models"/>.
    /// </summary>
    public IEnumerable<double> ModelWeights {
      get { return modelWeights; }
    }

    [Storable(Name = "ModelWeights")]
    private IEnumerable<double> StorableModelWeights {
      get { return modelWeights; }
      set { modelWeights = value.ToList(); }
    }

    private DoubleRange qualityThreshold;
    /// <summary>
    /// Quality threshold range attached to this ensemble. It is not evaluated inside this class.
    /// </summary>
    public DoubleRange QualityThreshold {
      get { return qualityThreshold; }
      set { qualityThreshold = value; }
    }
    [Storable(Name = "QualityThreshold")]
    private DoubleRange StorableQualityThreshold {
      get { return qualityThreshold; }
      set { qualityThreshold = value; }
    }

    private DoubleRange confidenceThreshold;
    /// <summary>
    /// Confidence threshold range attached to this ensemble. It is not evaluated inside this class.
    /// </summary>
    public DoubleRange ConfidenceThreshold {
      get { return confidenceThreshold; }
      set { confidenceThreshold = value; }
    }
    // The storable name must be unique within the class. It was previously "QualityThreshold",
    // i.e. identical to the name used by StorableQualityThreshold above.
    // NOTE(review): files persisted with the duplicated name (if that ever succeeded) will not
    // restore the confidence threshold from this member - confirm whether a migration is needed.
    [Storable(Name = "ConfidenceThreshold")]
    private DoubleRange StorableConfidenceThreshold {
      get { return confidenceThreshold; }
      set { confidenceThreshold = value; }
    }

    [Storable]
    private bool averageModelEstimates = true;
    /// <summary>
    /// If true, the summed weighted estimates are divided by the sum of the weights;
    /// otherwise the plain weighted sum is returned. Raises <see cref="Changed"/> when the value changes.
    /// </summary>
    public bool AverageModelEstimates {
      get { return averageModelEstimates; }
      set {
        if (averageModelEstimates != value) {
          averageModelEstimates = value;
          OnChanged();
        }
      }
    }

    #region backwards compatibility 3.3.5
    [Storable(Name = "models", AllowOneWay = true)]
    private List<IRegressionModel> OldStorableModels {
      set { models = value; }
    }
    #endregion

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility 3.3.14
      #region Backwards compatible code, remove with 3.4
      // Older files contain no weights: fall back to a weight of 1.0 for every model.
      if (modelWeights == null || !modelWeights.Any())
        modelWeights = new List<double>(models.Select(m => 1.0));
      #endregion
    }

    [StorableConstructor]
    private RatedRegressionEnsembleModel(bool deserializing) : base(deserializing) { }
    private RatedRegressionEnsembleModel(RatedRegressionEnsembleModel original, Cloner cloner)
      : base(original, cloner) {
      this.models = original.models.Select(cloner.Clone).ToList();
      this.modelWeights = new List<double>(original.modelWeights);
      this.qualityThreshold = cloner.Clone(original.qualityThreshold);
      this.confidenceThreshold = cloner.Clone(original.confidenceThreshold);
      this.averageModelEstimates = original.averageModelEstimates;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RatedRegressionEnsembleModel(this, cloner);
    }

    /// <summary>Creates an empty ensemble.</summary>
    public RatedRegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }

    /// <summary>Creates an ensemble of the given models, each with a weight of 1.0.</summary>
    public RatedRegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { }

    /// <summary>
    /// Creates an ensemble of the given models with the given weights.
    /// The target variable is taken from the first model, if there is one.
    /// </summary>
    /// <param name="models">The models to combine.</param>
    /// <param name="modelWeights">One weight per model, in the same order as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentException">The number of weights differs from the number of models.</exception>
    public RatedRegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
      : base(string.Empty) {
      this.name = ItemName;
      this.description = ItemDescription;

      this.models = new List<IRegressionModel>(models);
      this.modelWeights = new List<double>(modelWeights);

      // Both lists are accessed by a shared index everywhere else, so they must have the same length.
      if (this.models.Count != this.modelWeights.Count)
        throw new ArgumentException("The number of model weights must match the number of models.", "modelWeights");

      if (this.models.Count > 0) this.TargetVariable = this.models[0].TargetVariable;
    }

    /// <summary>Adds a model with a weight of 1.0.</summary>
    public void Add(IRegressionModel model) {
      Add(model, 1.0);
    }

    /// <summary>
    /// Adds a model with the given weight and raises <see cref="Changed"/>.
    /// If no target variable is set yet, it is taken from the added model.
    /// </summary>
    public void Add(IRegressionModel model, double weight) {
      if (model == null) throw new ArgumentNullException("model");
      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;

      models.Add(model);
      modelWeights.Add(weight);
      OnChanged();
    }

    /// <summary>Adds several models, each with a weight of 1.0.</summary>
    public void AddRange(IEnumerable<IRegressionModel> models) {
      if (models == null) throw new ArgumentNullException("models");
      // Materialize once so a lazy sequence is not evaluated repeatedly.
      var newModels = models.ToList();
      AddRange(newModels, Enumerable.Repeat(1.0, newModels.Count));
    }

    /// <summary>
    /// Adds several models with their weights and raises <see cref="Changed"/>.
    /// Adding an empty sequence is a no-op.
    /// </summary>
    /// <param name="models">The models to add.</param>
    /// <param name="weights">One weight per model, in the same order as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentException">The number of weights differs from the number of models.</exception>
    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
      if (models == null) throw new ArgumentNullException("models");
      if (weights == null) throw new ArgumentNullException("weights");

      // Materialize before touching any state so that the ensemble is never left half-updated.
      var newModels = models.ToList();
      var newWeights = weights.ToList();
      if (newModels.Count != newWeights.Count)
        throw new ArgumentException("The number of weights must match the number of models.", "weights");
      if (newModels.Count == 0) return;

      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = newModels[0].TargetVariable;

      this.models.AddRange(newModels);
      modelWeights.AddRange(newWeights);
      OnChanged();
    }

    /// <summary>
    /// Removes the model together with its weight and raises <see cref="Changed"/>.
    /// A model that is not part of the ensemble is ignored.
    /// The target variable is cleared when the last model is removed.
    /// </summary>
    public void Remove(IRegressionModel model) {
      if (!RemoveModelAndWeight(model)) return;

      if (models.Count == 0) TargetVariable = string.Empty;
      OnChanged();
    }

    /// <summary>
    /// Removes all given models together with their weights and raises <see cref="Changed"/> once.
    /// Models that are not part of the ensemble are ignored.
    /// The target variable is cleared when the ensemble becomes empty.
    /// </summary>
    public void RemoveRange(IEnumerable<IRegressionModel> models) {
      if (models == null) throw new ArgumentNullException("models");

      bool removedAny = false;
      // Copy first: the argument might be a lazy query over this ensemble's own model list.
      foreach (var model in models.ToList()) {
        if (RemoveModelAndWeight(model)) removedAny = true;
      }
      if (!removedAny) return;

      // Check the ensemble's own list (this.models), not the method argument.
      if (this.models.Count == 0) TargetVariable = string.Empty;
      OnChanged();
    }

    // Removes the first occurrence of the model and the weight at the same index; returns false if the model is unknown.
    private bool RemoveModelAndWeight(IRegressionModel model) {
      var index = models.IndexOf(model);
      if (index < 0) return false;

      models.RemoveAt(index);
      modelWeights.RemoveAt(index);
      return true;
    }

    /// <summary>Returns the weight of the first occurrence of the given model.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The model is not part of the ensemble.</exception>
    public double GetModelWeight(IRegressionModel model) {
      var index = models.IndexOf(model);
      return modelWeights[index];
    }

    /// <summary>Sets the weight of the first occurrence of the given model and raises <see cref="Changed"/>.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The model is not part of the ensemble.</exception>
    public void SetModelWeight(IRegressionModel model, double weight) {
      var index = models.IndexOf(model);
      modelWeights[index] = weight;
      OnChanged();
    }

    #region evaluation
    /// <summary>
    /// Returns one vector per row containing the weighted estimate (weight * estimate) of every
    /// model, in model order. Enumeration stops as soon as any model runs out of estimates.
    /// For an empty ensemble one empty vector per row is returned.
    /// </summary>
    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
      if (models.Count == 0) {
        // Enumerable.All is vacuously true for an empty list, so the loop below would never
        // terminate without this special case.
        foreach (var row in rows)
          yield return Enumerable.Empty<double>();
        yield break;
      }

      var estimatedValuesEnumerators = new List<IEnumerator<double>>(models.Count);
      try {
        for (int i = 0; i < models.Count; i++) {
          // Take the weight by position; a lookup via IndexOf would return the weight of the
          // first occurrence for a model that is contained more than once.
          double weight = modelWeights[i];
          estimatedValuesEnumerators.Add(
            models[i].GetEstimatedValues(dataset, rows).Select(e => weight * e).GetEnumerator());
        }

        while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Materialize the row: the enumerators advance (and are finally disposed) afterwards,
          // so a lazily evaluated vector would show wrong values to callers that buffer rows.
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      } finally {
        foreach (var enumerator in estimatedValuesEnumerators)
          enumerator.Dispose();
      }
    }

    /// <summary>
    /// Returns the ensemble estimate for every row: the sum of the weighted model estimates,
    /// divided by the sum of all weights if <see cref="AverageModelEstimates"/> is set.
    /// Rows without any model estimate yield <see cref="double.NaN"/>.
    /// </summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double weightsSum = modelWeights.Sum();
      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();

      if (AverageModelEstimates)
        return summedEstimates.Select(v => v / weightsSum);
      else
        return summedEstimates;
    }

    /// <summary>
    /// Returns the ensemble estimate for every row, considering only the models for which
    /// <paramref name="modelSelectionPredicate"/> returns true for that row. When averaging,
    /// the sum is divided by the weights of the selected models only (NaN if none is selected).
    /// </summary>
    /// <param name="dataset">The dataset to evaluate.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <param name="modelSelectionPredicate">Decides per (row, model) whether the model contributes to the estimate.</param>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
      using (var rowsEnumerator = rows.GetEnumerator())
      using (var estimatedValuesEnumerator = GetEstimatedValueVectors(dataset, rows).GetEnumerator()) {
        while (rowsEnumerator.MoveNext() && estimatedValuesEnumerator.MoveNext()) {
          int currentRow = rowsEnumerator.Current;
          double weightsSum = 0.0;
          double filteredEstimatesSum = 0.0;

          // The vector holds one weighted estimate per model, in model order.
          int m = 0;
          foreach (var weightedEstimate in estimatedValuesEnumerator.Current) {
            if (modelSelectionPredicate(currentRow, models[m])) {
              filteredEstimatesSum += weightedEstimate;
              weightsSum += modelWeights[m];
            }
            m++;
          }

          if (AverageModelEstimates)
            yield return filteredEstimatesSum / weightsSum;
          else
            yield return filteredEstimatesSum;
        }
      }
    }

    #endregion

    /// <summary>
    /// Raised when models or weights are added, removed or changed, or when
    /// <see cref="AverageModelEstimates"/> changes.
    /// </summary>
    public event EventHandler Changed;
    private void OnChanged() {
      var handler = Changed;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }


    /// <summary>Creates a regression ensemble solution for this model and the given problem data.</summary>
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.