#region License Information
/* HeuristicLab
* Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using System;
using HeuristicLab.Data;
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents regression solutions that contain an ensemble of multiple regression models.
  /// </summary>
  /// <remarks>
  /// Every member model can be associated with its own training and test partition. When the
  /// ensemble estimates training (test) values, each row averages only the estimates of the models
  /// whose training (test) partition contains that row; models without a recorded partition are
  /// always used.
  /// </remarks>
  [StorableClass]
  [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")]
  // [Creatable("Data Analysis")]
  public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
    /// <summary>
    /// The ensemble model of this solution.
    /// </summary>
    public new IRegressionEnsembleModel Model {
      get { return (IRegressionEnsembleModel)base.Model; }
    }

    // Per-model training/test partitions. They may be null (the deserializing constructor does not
    // assign them); in that case every model is used for every row.
    [Storable]
    private Dictionary<IRegressionModel, IntRange> trainingPartitions;
    [Storable]
    private Dictionary<IRegressionModel, IntRange> testPartitions;

    [StorableConstructor]
    protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { }

    protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
      : base(original, cloner) {
      // Cloning the keys through the cloner maps them to the cloned member models (assuming the
      // base constructor cloned the ensemble model with this same cloner), so the partitions stay
      // attached to the models of the clone.
      trainingPartitions = ClonePartitions(original.trainingPartitions, cloner);
      testPartitions = ClonePartitions(original.testPartitions, cloner);
    }

    /// <summary>
    /// Creates an ensemble solution in which every model uses the training and test partition
    /// of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="models">The member models of the ensemble.</param>
    /// <param name="problemData">The problem data the solution is evaluated on.</param>
    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
      : base(new RegressionEnsembleModel(models), problemData) {
      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
      testPartitions = new Dictionary<IRegressionModel, IntRange>();
      foreach (var model in models) {
        trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone();
        testPartitions[model] = (IntRange)problemData.TestPartition.Clone();
      }
      // Recalculate the results now that the partitions are set.
      RecalculateResults();
    }

    /// <summary>
    /// Creates an ensemble solution with an individual training and test partition per model.
    /// </summary>
    /// <param name="models">The member models of the ensemble.</param>
    /// <param name="problemData">The problem data the solution is evaluated on.</param>
    /// <param name="trainingPartitions">The training partition of each model, in the same order as <paramref name="models"/>.</param>
    /// <param name="testPartitions">The test partition of each model, in the same order as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentException">The three sequences do not have the same length.</exception>
    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
      : base(new RegressionEnsembleModel(models), problemData) {
      this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
      this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
      using (var modelEnumerator = models.GetEnumerator())
      using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
      using (var testPartitionEnumerator = testPartitions.GetEnumerator()) {
        // Non-short-circuiting '&' / '|' so all three sequences are advanced in lock step.
        while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
          this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone();
          this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone();
        }
        if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
          throw new ArgumentException("The number of models, training partitions, and test partitions must be equal.");
        }
      }
      // Recalculate the results now that the partitions are set (consistent with the other constructor).
      RecalculateResults();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionEnsembleSolution(this, cloner);
    }

    /// <summary>
    /// Ensemble estimates for the training rows. Each row averages the estimates of the models
    /// whose training partition contains the row (NaN if there is no such model).
    /// </summary>
    public override IEnumerable<double> EstimatedTrainingValues {
      get {
        return GetAggregatedEstimatedValues(ProblemData.TrainingIndizes,
          (model, row) => IsModelSelectedForRow(trainingPartitions, model, row));
      }
    }

    /// <summary>
    /// Ensemble estimates for the test rows. Each row averages the estimates of the models whose
    /// test partition contains the row (NaN if there is no such model).
    /// </summary>
    public override IEnumerable<double> EstimatedTestValues {
      get {
        return GetAggregatedEstimatedValues(ProblemData.TestIndizes,
          (model, row) => IsModelSelectedForRow(testPartitions, model, row));
      }
    }

    /// <summary>
    /// Returns, for each of <paramref name="rows"/>, the average estimate of all member models.
    /// </summary>
    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
      return from xs in GetEstimatedValueVectors(ProblemData.Dataset, rows)
             select AggregateEstimatedValues(xs);
    }

    /// <summary>
    /// Returns, for each of <paramref name="rows"/>, the vector of estimates of all member models
    /// (in the order of <see cref="IRegressionEnsembleModel.Models"/>).
    /// </summary>
    /// <param name="dataset">The dataset the models are evaluated on.</param>
    /// <param name="rows">The rows to evaluate; enumerated once per model plus once more.</param>
    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) {
      var estimatedValuesEnumerators = (from model in Model.Models
                                        select model.GetEstimatedValues(dataset, rows).GetEnumerator())
                                       .ToList();
      // Driving the loop by the rows keeps it finite for an ensemble without models.
      var rowsEnumerator = rows.GetEnumerator();
      try {
        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Materialize the vector: a lazy projection over the shared enumerators would observe
          // later values once the caller advanced the outer sequence.
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      } finally {
        rowsEnumerator.Dispose();
        foreach (var en in estimatedValuesEnumerators) en.Dispose();
      }
    }

    // Evaluates all member models on the given rows and, for each row, aggregates the estimates of
    // the models for which includeModel(model, row) is true.
    private IEnumerable<double> GetAggregatedEstimatedValues(IEnumerable<int> rows, Func<IRegressionModel, int, bool> includeModel) {
      var estimatedValuesEnumerators = (from model in Model.Models
                                        select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
                                       .ToList();
      var rowsEnumerator = rows.GetEnumerator();
      try {
        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.All(en => en.EstimatedValuesEnumerator.MoveNext())) {
          int currentRow = rowsEnumerator.Current;
          var selectedValues = from pair in estimatedValuesEnumerators
                               where includeModel(pair.Model, currentRow)
                               select pair.EstimatedValuesEnumerator.Current;
          yield return AggregateEstimatedValues(selectedValues);
        }
      } finally {
        rowsEnumerator.Dispose();
        foreach (var pair in estimatedValuesEnumerators) pair.EstimatedValuesEnumerator.Dispose();
      }
    }

    // A model is used for a row if no partitions are stored, no partition is recorded for the model,
    // or the row lies in the half-open interval [Start, End) of the model's partition.
    private static bool IsModelSelectedForRow(Dictionary<IRegressionModel, IntRange> partitions, IRegressionModel model, int row) {
      IntRange partition;
      if (partitions == null || !partitions.TryGetValue(model, out partition)) return true;
      return partition.Start <= row && row < partition.End;
    }

    // Deep-copies a partition map; keys and values are cloned through the cloner. Returns null for null.
    private static Dictionary<IRegressionModel, IntRange> ClonePartitions(Dictionary<IRegressionModel, IntRange> partitions, Cloner cloner) {
      if (partitions == null) return null;
      var clonedPartitions = new Dictionary<IRegressionModel, IntRange>();
      foreach (var pair in partitions) {
        clonedPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
      }
      return clonedPartitions;
    }

    // Averages the given estimates; returns NaN instead of throwing when no model was selected.
    private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
    }
  }
}