#region License Information
/* HeuristicLab
* Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see .
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
namespace HeuristicLab.Problems.DataAnalysis {
///
/// Represents regression solutions that contain an ensemble of multiple regression models
///
[StorableClass]
[Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
public sealed class RegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel {
public override IEnumerable VariablesUsedForPrediction {
get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
}
private List models;
public IEnumerable Models {
get { return new List(models); }
}
[Storable(Name = "Models")]
private IEnumerable StorableModels {
get { return models; }
set { models = value.ToList(); }
}
private List modelWeights;
public IEnumerable ModelWeights {
get { return modelWeights; }
}
[Storable(Name = "ModelWeights")]
private IEnumerable StorableModelWeights {
get { return modelWeights; }
set { modelWeights = value.ToList(); }
}
[Storable]
private bool averageModelEstimates = true;
public bool AverageModelEstimates {
get { return averageModelEstimates; }
set {
if (averageModelEstimates != value) {
averageModelEstimates = value;
OnChanged();
}
}
}
#region backwards compatiblity 3.3.5
[Storable(Name = "models", AllowOneWay = true)]
private List OldStorableModels {
set { models = value; }
}
#endregion
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
// BackwardsCompatibility 3.3.14
#region Backwards compatible code, remove with 3.4
if (modelWeights == null || !modelWeights.Any())
modelWeights = new List(models.Select(m => 1.0));
#endregion
}
[StorableConstructor]
private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
: base(original, cloner) {
this.models = original.Models.Select(cloner.Clone).ToList();
this.modelWeights = new List(original.ModelWeights);
this.averageModelEstimates = original.averageModelEstimates;
}
public override IDeepCloneable Clone(Cloner cloner) {
return new RegressionEnsembleModel(this, cloner);
}
public RegressionEnsembleModel() : this(Enumerable.Empty()) { }
public RegressionEnsembleModel(IEnumerable models) : this(models, models.Select(m => 1.0)) { }
public RegressionEnsembleModel(IEnumerable models, IEnumerable modelWeights)
: base(string.Empty) {
this.name = ItemName;
this.description = ItemDescription;
this.models = new List(models);
this.modelWeights = new List(modelWeights);
if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable;
}
public void Add(IRegressionModel model) {
if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
Add(model, 1.0);
}
public void Add(IRegressionModel model, double weight) {
if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
models.Add(model);
modelWeights.Add(weight);
OnChanged();
}
public void AddRange(IEnumerable models) {
AddRange(models, models.Select(m => 1.0));
}
public void AddRange(IEnumerable models, IEnumerable weights) {
if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = models.First().TargetVariable;
this.models.AddRange(models);
modelWeights.AddRange(weights);
OnChanged();
}
public void Remove(IRegressionModel model) {
var index = models.IndexOf(model);
models.RemoveAt(index);
modelWeights.RemoveAt(index);
if (!models.Any()) TargetVariable = string.Empty;
OnChanged();
}
public void RemoveRange(IEnumerable models) {
foreach (var model in models) {
var index = this.models.IndexOf(model);
this.models.RemoveAt(index);
modelWeights.RemoveAt(index);
}
if (!models.Any()) TargetVariable = string.Empty;
OnChanged();
}
public double GetModelWeight(IRegressionModel model) {
var index = models.IndexOf(model);
return modelWeights[index];
}
public void SetModelWeight(IRegressionModel model, double weight) {
var index = models.IndexOf(model);
modelWeights[index] = weight;
OnChanged();
}
#region evaluation
public IEnumerable> GetEstimatedValueVectors(IDataset dataset, IEnumerable rows) {
var estimatedValuesEnumerators = (from model in models
let weight = GetModelWeight(model)
select model.GetEstimatedValues(dataset, rows).Select(e => weight * e)
.GetEnumerator()).ToList();
while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
yield return from enumerator in estimatedValuesEnumerators
select enumerator.Current;
}
}
public override IEnumerable GetEstimatedValues(IDataset dataset, IEnumerable rows) {
double weightsSum = modelWeights.Sum();
var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();
if (AverageModelEstimates)
return summedEstimates.Select(v => v / weightsSum);
else
return summedEstimates;
}
public IEnumerable GetEstimatedValues(IDataset dataset, IEnumerable rows, Func modelSelectionPredicate) {
var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
var rowsEnumerator = rows.GetEnumerator();
while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator();
int currentRow = rowsEnumerator.Current;
double weightsSum = 0.0;
double filteredEstimatesSum = 0.0;
for (int m = 0; m < models.Count; m++) {
estimatedValueEnumerator.MoveNext();
var model = models[m];
if (!modelSelectionPredicate(currentRow, model)) continue;
filteredEstimatesSum += estimatedValueEnumerator.Current;
weightsSum += modelWeights[m];
}
if (AverageModelEstimates)
yield return filteredEstimatesSum / weightsSum;
else
yield return filteredEstimatesSum;
}
}
#endregion
public event EventHandler Changed;
private void OnChanged() {
var handler = Changed;
if (handler != null)
handler(this, EventArgs.Empty);
}
public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
}
}
}