Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs @ 13701

Last change on this file since 13701 was 13701, checked in by mkommend, 9 years ago

#2590: Fixed minor bug in RegressionEnsembleSolution when no models are present and updated view.

File size: 5.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// A regression model that combines the estimates of multiple contained regression models,
  /// either by averaging them or by summing them up (see <see cref="AverageModelEstimates"/>).
  /// </summary>
  [StorableClass]
  [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
  public sealed class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {

    private List<IRegressionModel> models;

    /// <summary>
    /// Gets a copy of the list of contained models. Modifying the returned collection
    /// does not change the ensemble; use <see cref="Add"/> and <see cref="Remove"/> instead.
    /// </summary>
    public IEnumerable<IRegressionModel> Models {
      get { return new List<IRegressionModel>(models); }
    }

    // Storable accessor for the model list (stored under the name "Models").
    // The setter copies the incoming sequence into a fresh list.
    [Storable(Name = "Models")]
    private IEnumerable<IRegressionModel> StorableModels {
      get { return models; }
      set { models = value.ToList(); }
    }

    [Storable]
    private bool averageModelEstimates = true;

    /// <summary>
    /// Gets or sets whether the estimates of the individual models are averaged (true, the default)
    /// or summed up (false). <see cref="AverageModelEstimatesChanged"/> is raised only when the
    /// value actually changes.
    /// </summary>
    public bool AverageModelEstimates {
      get { return averageModelEstimates; }
      set {
        if (averageModelEstimates != value) {
          averageModelEstimates = value;
          OnAverageModelEstimatesChanged();
        }
      }
    }

    #region backwards compatibility 3.3.5
    // Write-only storable hook for the old name "models"; it assigns the list directly.
    [Storable(Name = "models", AllowOneWay = true)]
    private List<IRegressionModel> OldStorableModels {
      set { models = value; }
    }
    #endregion

    [StorableConstructor]
    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }

    // Cloning constructor: clones every contained model through the given cloner.
    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
      : base(original, cloner) {
      // Read the backing list directly; going through the Models property would
      // only create an additional throw-away copy of the list.
      this.models = original.models.Select(cloner.Clone).ToList();
      this.averageModelEstimates = original.averageModelEstimates;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionEnsembleModel(this, cloner);
    }

    /// <summary>
    /// Creates an ensemble that contains no models.
    /// </summary>
    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }

    /// <summary>
    /// Creates an ensemble that contains the given models.
    /// </summary>
    /// <param name="models">The initial models; the sequence is copied into an internal list.</param>
    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      this.models = new List<IRegressionModel>(models);
    }

    #region IRegressionEnsembleModel Members
    /// <summary>
    /// Appends a model to the ensemble.
    /// </summary>
    public void Add(IRegressionModel model) {
      models.Add(model);
    }

    /// <summary>
    /// Removes the first occurrence of the given model from the ensemble, if present.
    /// </summary>
    public void Remove(IRegressionModel model) {
      models.Remove(model);
    }

    /// <summary>
    /// Lazily yields, for each row, the vector of estimates of all contained models
    /// (in the order of the models). Enumeration stops as soon as any model runs out of estimates.
    /// </summary>
    /// <param name="dataset">The dataset passed on to every model.</param>
    /// <param name="rows">The rows passed on to every model.</param>
    /// <returns>
    /// One vector of estimates per row. If the ensemble contains no models,
    /// one empty vector is returned for each row.
    /// </returns>
    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
      if (models.Count == 0) {
        // Enumerable.All is vacuously true for an empty list, so the lock-step loop below
        // would never terminate without models. Emit exactly one empty vector per row instead.
        foreach (var row in rows)
          yield return Enumerable.Empty<double>();
        yield break;
      }

      var estimatedValuesEnumerators = new List<IEnumerator<double>>(models.Count);
      try {
        foreach (var model in models)
          estimatedValuesEnumerators.Add(model.GetEstimatedValues(dataset, rows).GetEnumerator());

        // advance all model enumerators in lock-step
        while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Snapshot the current values; a lazy query over the enumerators would
          // report later values if a caller buffers the vectors before reading them.
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      }
      finally {
        // also runs when the consumer stops enumerating early
        foreach (var enumerator in estimatedValuesEnumerators)
          enumerator.Dispose();
      }
    }

    /// <summary>
    /// Lazily yields one aggregated estimate per row, using only the models for which
    /// <paramref name="modelSelectionPredicate"/> returns true for that row.
    /// </summary>
    /// <param name="dataset">The dataset passed on to every model.</param>
    /// <param name="rows">The rows to estimate.</param>
    /// <param name="modelSelectionPredicate">Invoked with (row, model); decides whether the model contributes to the row.</param>
    /// <returns>The aggregated estimates; double.NaN for rows for which no model is selected.</returns>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
      using (var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator())
      using (var rowsEnumerator = rows.GetEnumerator()) {
        // aggregate to make sure that MoveNext is called for all enumerators
        // (non-short-circuiting '&' is intentional)
        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
          int currentRow = rowsEnumerator.Current;

          var filteredEstimates = models.Zip(estimatedValuesEnumerators.Current, (m, e) => new { Model = m, EstimatedValue = e })
                                        .Where(f => modelSelectionPredicate(currentRow, f.Model))
                                        .Select(f => f.EstimatedValue).DefaultIfEmpty(double.NaN);

          yield return AggregateEstimatedValues(filteredEstimates);
        }
      }
    }

    // Combines the estimates of a single row: the average if AverageModelEstimates is set, otherwise the sum.
    // Callers pass a non-empty sequence (DefaultIfEmpty) because Average throws on an empty one.
    private double AggregateEstimatedValues(IEnumerable<double> estimatedValuesVector) {
      if (AverageModelEstimates)
        return estimatedValuesVector.Average();
      else
        return estimatedValuesVector.Sum();
    }

    /// <summary>
    /// Raised after the value of <see cref="AverageModelEstimates"/> has changed.
    /// </summary>
    public event EventHandler AverageModelEstimatesChanged;
    private void OnAverageModelEstimatesChanged() {
      // copy to a local so the null check and the invocation see the same delegate
      var handler = AverageModelEstimatesChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
    #endregion

    #region IRegressionModel Members
    /// <summary>
    /// Lazily yields one aggregated estimate per row using all contained models.
    /// </summary>
    /// <param name="dataset">The dataset passed on to every model.</param>
    /// <param name="rows">The rows to estimate.</param>
    /// <returns>The aggregated estimates; double.NaN for every row if the ensemble contains no models.</returns>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
        yield return AggregateEstimatedValues(estimatedValuesVector.DefaultIfEmpty(double.NaN));
      }
    }

    /// <summary>
    /// Creates an ensemble solution for this model, wrapping the given problem data
    /// in a <see cref="RegressionEnsembleProblemData"/>.
    /// </summary>
    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
    }
    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return CreateRegressionSolution(problemData);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.