#region License Information
/* HeuristicLab
* Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Collections;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
namespace HeuristicLab.Problems.DataAnalysis {
/// <summary>
/// Represents classification solutions that contain an ensemble of multiple classification models.
/// </summary>
/// <remarks>
/// Every member model is wrapped in its own <see cref="IClassificationSolution"/> whose problem data
/// carries the training and test partition of that model. Ensemble estimates are produced by
/// <see cref="WeightCalculator"/> from the currently checked member solutions only.
/// </remarks>
[StorableClass]
[Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
[Creatable("Data Analysis - Ensembles")]
public sealed class ClassificationEnsembleSolution : ClassificationSolutionBase, IClassificationEnsembleSolution {
  /// <summary>The ensemble model that holds the models of all member solutions.</summary>
  public new IClassificationEnsembleModel Model {
    get { return (IClassificationEnsembleModel)base.Model; }
  }

  /// <summary>The ensemble problem data (the base problem data, cast to the ensemble type).</summary>
  public new ClassificationEnsembleProblemData ProblemData {
    get { return (ClassificationEnsembleProblemData)base.ProblemData; }
    set { base.ProblemData = value; }
  }

  // Not persisted: rebuilt from Model.Models and the stored partitions in AfterDeserialization.
  private readonly CheckedItemCollection<IClassificationSolution> classificationSolutions;
  /// <summary>
  /// The member solutions of the ensemble. Only checked solutions contribute to the aggregated estimates.
  /// </summary>
  public ICheckedItemCollection<IClassificationSolution> ClassificationSolutions {
    get { return classificationSolutions; }
  }

  // Not persisted: AfterDeserialization falls back to a majority vote when it is missing.
  private IClassificationEnsembleSolutionWeightCalculator weightCalculator;
  /// <summary>
  /// Strategy that weights the checked member solutions and aggregates their estimates.
  /// Assigning <c>null</c> is ignored. Assigning a calculator recalculates the results unless
  /// the problem data is empty.
  /// </summary>
  public IClassificationEnsembleSolutionWeightCalculator WeightCalculator {
    get { return weightCalculator; }
    set {
      if (value == null) return;
      weightCalculator = value;
      if (!ProblemData.IsEmpty) {
        RecalculateResults();
      }
    }
  }

  // Training/test partition of each member model, keyed by the model instance. These are persisted
  // so that the member solutions can be recreated after deserialization.
  [Storable]
  private Dictionary<IClassificationModel, IntRange> trainingPartitions;
  [Storable]
  private Dictionary<IClassificationModel, IntRange> testPartitions;

  [StorableConstructor]
  private ClassificationEnsembleSolution(bool deserializing)
    : base(deserializing) {
    classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
  }

  /// <summary>
  /// Restores the non-persisted state: a default weight calculator and one member solution per
  /// stored model, each using that model's stored partitions.
  /// </summary>
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    // The weight calculator is not [Storable]. Without a default, every evaluation after loading
    // would dereference null.
    if (weightCalculator == null) weightCalculator = new MajorityVoteWeightCalculator();

    foreach (var model in Model.Models) {
      IntRange trainingPartition = trainingPartitions[model];
      IntRange testPartition = testPartitions[model];
      IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
      problemData.TrainingPartition.Start = trainingPartition.Start;
      problemData.TrainingPartition.End = trainingPartition.End;
      problemData.TestPartition.Start = testPartition.Start;
      problemData.TestPartition.End = testPartition.End;
      classificationSolutions.Add(model.CreateClassificationSolution(problemData));
    }
    // Register only after the members are restored. Otherwise the ItemsAdded handler would try to
    // add each model to Model a second time and throw.
    RegisterClassificationSolutionsEventHandler();
  }

  private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    : base(original, cloner) {
    trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    testPartitions = new Dictionary<IClassificationModel, IntRange>();
    foreach (var pair in original.trainingPartitions) {
      trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
    }
    foreach (var pair in original.testPartitions) {
      testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
    }

    weightCalculator = cloner.Clone(original.weightCalculator);
    classificationSolutions = cloner.Clone(original.classificationSolutions);
    RegisterClassificationSolutionsEventHandler();
  }

  /// <summary>Creates an empty ensemble on empty problem data that uses a majority vote.</summary>
  public ClassificationEnsembleSolution()
    : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
    trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    testPartitions = new Dictionary<IClassificationModel, IntRange>();
    classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
    weightCalculator = new MajorityVoteWeightCalculator();
    RegisterClassificationSolutionsEventHandler();
  }

  /// <summary>Creates an ensemble without members for the given problem data.</summary>
  public ClassificationEnsembleSolution(IClassificationProblemData problemData) :
    this(Enumerable.Empty<IClassificationModel>(), problemData) { }

  /// <summary>
  /// Creates an ensemble from <paramref name="models"/>. Every model uses copies of the training
  /// and test partition of <paramref name="problemData"/>.
  /// </summary>
  public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
    : this(models, problemData,
           models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
           models.Select(m => (IntRange)problemData.TestPartition.Clone())
    ) { }

  /// <summary>
  /// Creates an ensemble in which the i-th model is paired with the i-th training and i-th test partition.
  /// </summary>
  /// <exception cref="ArgumentException">
  /// The three sequences do not contain the same number of elements.
  /// </exception>
  public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
    this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
    this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>();

    List<IClassificationSolution> solutions = new List<IClassificationSolution>();
    using (var modelEnumerator = models.GetEnumerator())
    using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
    using (var testPartitionEnumerator = testPartitions.GetEnumerator()) {
      while (true) {
        // Advance all three enumerators on every step so that a length mismatch shows up as
        // disagreeing MoveNext results, including when one sequence is just one element longer.
        bool hasModel = modelEnumerator.MoveNext();
        bool hasTrainingPartition = trainingPartitionEnumerator.MoveNext();
        bool hasTestPartition = testPartitionEnumerator.MoveNext();
        if (!hasModel && !hasTrainingPartition && !hasTestPartition) break;
        if (!(hasModel && hasTrainingPartition && hasTestPartition))
          throw new ArgumentException("The number of models, training partitions and test partitions must be equal.");

        IClassificationProblemData p = (IClassificationProblemData)ProblemData.Clone();
        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
        p.TestPartition.End = testPartitionEnumerator.Current.End;
        solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
      }
    }

    // The weight calculator must exist before AddRange: the ItemsAdded handler recalculates results.
    weightCalculator = new MajorityVoteWeightCalculator();
    RegisterClassificationSolutionsEventHandler();
    classificationSolutions.AddRange(solutions);
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new ClassificationEnsembleSolution(this, cloner);
  }

  // Keeps Model and the partition dictionaries in sync with the member collection and recalculates
  // results whenever membership or check state changes.
  private void RegisterClassificationSolutionsEventHandler() {
    classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded);
    classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
    classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
    classificationSolutions.CheckedItemsChanged += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CheckedItemsChanged);
  }

  #region Evaluation
  /// <summary>Aggregated ensemble estimates for every row of the dataset.</summary>
  public override IEnumerable<double> EstimatedClassValues {
    get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
  }

  /// <summary>Aggregated ensemble estimates for the training rows, using the calculator's training delegate.</summary>
  public override IEnumerable<double> EstimatedTrainingClassValues {
    get {
      return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
        ProblemData.Dataset,
        ProblemData.TrainingIndices,
        weightCalculator.GetTrainingClassDelegate());
    }
  }

  /// <summary>Aggregated ensemble estimates for the test rows, using the calculator's test delegate.</summary>
  public override IEnumerable<double> EstimatedTestClassValues {
    get {
      return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
        ProblemData.Dataset,
        ProblemData.TestIndices,
        weightCalculator.GetTestClassDelegate());
    }
  }

  /// <summary>Aggregated ensemble estimates for <paramref name="rows"/>, using the calculator's all-rows delegate.</summary>
  public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
      ProblemData.Dataset,
      rows,
      weightCalculator.GetAllClassDelegate());
  }

  /// <summary>
  /// Returns, for each row, the vector of class values estimated by the models of the checked member
  /// solutions (one entry per checked solution). The sequence is empty when no solution is checked.
  /// It stops as soon as any model produces no further estimate.
  /// </summary>
  public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    List<IClassificationModel> models = classificationSolutions.CheckedItems.Select(sol => sol.Model).ToList();
    if (models.Count == 0) yield break;

    List<IEnumerator<double>> estimatedValuesEnumerators = models
      .Select(model => model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
      .ToList();
    try {
      while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
        // Materialize now. A lazy projection would read Current only when the caller enumerates
        // the vector, which may happen after the enumerators have already moved on.
        yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
      }
    } finally {
      foreach (var enumerator in estimatedValuesEnumerators) enumerator.Dispose();
    }
  }
  #endregion

  /// <summary>
  /// Propagates new problem data to the members. Nested ensemble solutions receive the ensemble
  /// problem data itself. All other members share one new ClassificationProblemData built from the
  /// dataset, allowed input variables, target variable and the ensemble's partitions. The stored
  /// per-model partitions are then reset to the ensemble's partitions.
  /// </summary>
  protected override void OnProblemDataChanged() {
    // NOTE(review): only dataset, inputs and target are passed to the new instance; confirm that no
    // other ProblemData settings need to be carried over.
    IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
                                                                           ProblemData.AllowedInputVariables,
                                                                           ProblemData.TargetVariable);
    problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
    problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
    problemData.TestPartition.Start = ProblemData.TestPartition.Start;
    problemData.TestPartition.End = ProblemData.TestPartition.End;

    foreach (var solution in ClassificationSolutions) {
      if (solution is ClassificationEnsembleSolution)
        solution.ProblemData = ProblemData;
      else
        solution.ProblemData = problemData;
    }
    foreach (var trainingPartition in trainingPartitions.Values) {
      trainingPartition.Start = ProblemData.TrainingPartition.Start;
      trainingPartition.End = ProblemData.TrainingPartition.End;
    }
    foreach (var testPartition in testPartitions.Values) {
      testPartition.Start = ProblemData.TestPartition.Start;
      testPartition.End = ProblemData.TestPartition.End;
    }

    base.OnProblemDataChanged();
  }

  /// <summary>Adds member solutions; the collection events update the model and the results.</summary>
  public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
    classificationSolutions.AddRange(solutions);
  }

  /// <summary>Removes member solutions; the collection events update the model and the results.</summary>
  public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
    classificationSolutions.RemoveRange(solutions);
  }

  private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
    foreach (var solution in e.Items) AddClassificationSolution(solution);
    RecalculateResults();
  }

  private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
    foreach (var solution in e.Items) RemoveClassificationSolution(solution);
    RecalculateResults();
  }

  private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
    foreach (var solution in e.OldItems) RemoveClassificationSolution(solution);
    foreach (var solution in e.Items) AddClassificationSolution(solution);
    RecalculateResults();
  }

  private void classificationSolutions_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
    RecalculateResults();
  }

  /// <summary>Renormalizes the member weights over the checked solutions, then recalculates the base results.</summary>
  protected override void RecalculateResults() {
    weightCalculator.CalculateNormalizedWeights(classificationSolutions.CheckedItems);
    base.RecalculateResults();
  }

  // Registers the solution's model with the ensemble model and remembers its partitions.
  private void AddClassificationSolution(IClassificationSolution solution) {
    if (Model.Models.Contains(solution.Model))
      throw new ArgumentException("The model of the given solution is already part of the ensemble.");
    Model.Add(solution.Model);
    trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    testPartitions[solution.Model] = solution.ProblemData.TestPartition;
  }

  // Unregisters the solution's model from the ensemble model and forgets its partitions.
  private void RemoveClassificationSolution(IClassificationSolution solution) {
    if (!Model.Models.Contains(solution.Model))
      throw new ArgumentException("The model of the given solution is not part of the ensemble.");
    Model.Remove(solution.Model);
    trainingPartitions.Remove(solution.Model);
    testPartitions.Remove(solution.Model);
  }
}
}