#region License Information
/* HeuristicLab
* Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using HeuristicLab.Collections;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
namespace HeuristicLab.Optimization.Operators.LCS {
[StorableClass]
[Item("GAssistEnsembleSolution", "Represents a GAssist ensemble.")]
[Creatable("Data Analysis - Ensembles")]
public class GAssistEnsembleSolution : ResultCollection, IGAssistEnsembleSolution {
// Caches of aggregated ensemble estimates, keyed by dataset row index.
// trainingEvaluationCache / testEvaluationCache are filled by EstimatedTrainingNiche / EstimatedTestNiche,
// evaluationCache by GetEstimatedNiches. All three are cleared whenever the problem data or the set of
// member solutions changes.
private readonly Dictionary trainingEvaluationCache = new Dictionary();
private readonly Dictionary testEvaluationCache = new Dictionary();
private readonly Dictionary evaluationCache = new Dictionary();
// Names under which the model, the problem data and the two accuracy values are stored as results
// in this collection (see the indexer accesses in the properties below).
private const string ModelResultName = "Model";
private const string ProblemDataResultName = "ProblemData";
private const string TrainingAccuracyResultName = "Accuracy (training)";
private const string TestAccuracyResultName = "Accuracy (test)";
/// <summary>
/// Gets or sets a file name associated with this solution. It is not read or written anywhere else in this class.
/// </summary>
public string Filename { get; set; }
/// <summary>
/// Gets the image used for this item type; hides the inherited static member (note the <c>new</c> modifier).
/// </summary>
public static new Image StaticItemImage {
get { return HeuristicLab.Common.Resources.VSImageLibrary.Function; }
}
/// <summary>
/// Gets the accuracy stored in the "Accuracy (training)" result. Only this class can update it.
/// </summary>
public double TrainingAccuracy {
  get {
    var trainingAccuracy = (PercentValue)this[TrainingAccuracyResultName].Value;
    return trainingAccuracy.Value;
  }
  private set {
    var trainingAccuracy = (PercentValue)this[TrainingAccuracyResultName].Value;
    trainingAccuracy.Value = value;
  }
}
/// <summary>
/// Gets the accuracy stored in the "Accuracy (test)" result. Only this class can update it.
/// </summary>
public double TestAccuracy {
  get {
    var testAccuracy = (PercentValue)this[TestAccuracyResultName].Value;
    return testAccuracy.Value;
  }
  private set {
    var testAccuracy = (PercentValue)this[TestAccuracyResultName].Value;
    testAccuracy.Value = value;
  }
}
#region properties
/// <summary>
/// Gets the ensemble model stored in the "Model" result.
/// The setter ignores <c>null</c> and the instance that is already stored; any other value
/// replaces the stored model and triggers <see cref="OnModelChanged"/>.
/// </summary>
public IGAssistEnsembleModel Model {
  get {
    var modelResult = this[ModelResultName];
    return (IGAssistEnsembleModel)modelResult.Value;
  }
  protected set {
    var modelResult = this[ModelResultName];
    // Nothing to do when the same instance is assigned again or when null is passed.
    if (modelResult.Value == value || value == null) return;
    modelResult.Value = value;
    OnModelChanged();
  }
}
/// <summary>
/// Gets or sets the ensemble problem data stored in the "ProblemData" result.
/// The setter ignores <c>null</c> and the instance that is already stored. Otherwise it moves the
/// Changed subscription from the old to the new problem data and calls <see cref="OnProblemDataChanged"/>.
/// </summary>
public IGAssistEnsembleProblemData ProblemData {
  get {
    var problemDataResult = this[ProblemDataResultName];
    return (IGAssistEnsembleProblemData)problemDataResult.Value;
  }
  set {
    var problemDataResult = this[ProblemDataResultName];
    if (problemDataResult.Value == value || value == null) return;
    // Detach from the problem data that is about to be replaced ...
    ProblemData.Changed -= ProblemData_Changed;
    problemDataResult.Value = value;
    // ... and listen to the new one instead.
    ProblemData.Changed += ProblemData_Changed;
    OnProblemDataChanged();
  }
}
/// <summary>
/// Forwards a change notification of the current problem data to <see cref="OnProblemDataChanged"/>.
/// </summary>
private void ProblemData_Changed(object sender, EventArgs e) {
  OnProblemDataChanged();
}
#endregion
// Member solutions of the ensemble. The field carries no [Storable] attribute; AfterDeserialization
// rebuilds the collection from Model.Models and the stored partitions.
private readonly ItemCollection gassistSolutions;
/// <summary>
/// Gets the collection of member solutions. Additions, removals and resets of this collection are
/// handled by the gassistSolutions_* event handlers of this class.
/// </summary>
public IItemCollection GAssistSolutions {
get { return gassistSolutions; }
}
// Training partition recorded for each member model (keyed by the model); persisted.
[Storable]
private Dictionary trainingPartitions;
// Test partition recorded for each member model (keyed by the model); persisted.
[Storable]
private Dictionary testPartitions;
/// <summary>
/// Persistence hook: recreates one member solution per model of the ensemble model, using the
/// stored training and test partition of that model. The collection event handlers are registered
/// only afterwards, so re-adding the solutions does not run the add handlers.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  foreach (var member in Model.Models) {
    IGAssistProblemData memberData = ProblemData.GetGAssistProblemData();

    var storedTraining = trainingPartitions[member];
    memberData.TrainingPartition.Start = storedTraining.Start;
    memberData.TrainingPartition.End = storedTraining.End;

    var storedTest = testPartitions[member];
    memberData.TestPartition.Start = storedTest.Start;
    memberData.TestPartition.End = storedTest.End;

    gassistSolutions.Add(member.CreateGAssistSolution(memberData));
  }
  RegisterGAssistSolutionsEventHandler();
}
/// <summary>
/// Constructor marked with [StorableConstructor]. It only creates the empty solution collection;
/// the collection is filled in AfterDeserialization.
/// </summary>
[StorableConstructor]
protected GAssistEnsembleSolution(bool deserializing)
: base(deserializing) {
gassistSolutions = new ItemCollection();
}
/// <summary>
/// Cloning constructor: clones the per-model partitions and the member solutions through
/// <paramref name="cloner"/>, creates empty training/test caches sized from the original's
/// index counts and registers the collection event handlers on the cloned collection.
/// </summary>
protected GAssistEnsembleSolution(GAssistEnsembleSolution original, Cloner cloner)
  : base(original, cloner) {
  trainingPartitions = new Dictionary();
  foreach (var entry in original.trainingPartitions) {
    var clonedModel = cloner.Clone(entry.Key);
    trainingPartitions[clonedModel] = cloner.Clone(entry.Value);
  }

  testPartitions = new Dictionary();
  foreach (var entry in original.testPartitions) {
    var clonedModel = cloner.Clone(entry.Key);
    testPartitions[clonedModel] = cloner.Clone(entry.Value);
  }

  // The caches themselves are not copied; the clone starts with empty ones.
  int trainingRowCount = original.ProblemData.TrainingIndices.Count();
  trainingEvaluationCache = new Dictionary(trainingRowCount);
  int testRowCount = original.ProblemData.TestIndices.Count();
  testEvaluationCache = new Dictionary(testRowCount);

  gassistSolutions = cloner.Clone(original.gassistSolutions);
  RegisterGAssistSolutionsEventHandler();
}
/// <summary>
/// Creates an ensemble from the given models. Every model is paired with its own clone of the
/// training and the test partition of <paramref name="problemData"/>.
/// </summary>
// Note: the two Select projections are lazy, so the partition clones are created while the
// delegated-to constructor enumerates them; models is therefore enumerated more than once.
public GAssistEnsembleSolution(IEnumerable models, IGAssistProblemData problemData)
: this(models, problemData,
models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
models.Select(m => (IntRange)problemData.TestPartition.Clone())
) { }
/// <summary>
/// Creates an ensemble with empty partition dictionaries and an empty solution collection and
/// registers the collection event handlers.
/// </summary>
// NOTE(review): unlike the four-argument constructor, no Model/ProblemData/accuracy results are added
// here, while the Model and ProblemData accessors index this collection by those names — confirm that
// instances created this way get the results from elsewhere before the accessors are used.
public GAssistEnsembleSolution()
: base() {
trainingPartitions = new Dictionary();
testPartitions = new Dictionary();
gassistSolutions = new ItemCollection();
RegisterGAssistSolutionsEventHandler();
}
/// <summary>
/// Creates an ensemble without member models for the given problem data.
/// </summary>
public GAssistEnsembleSolution(IGAssistProblemData problemData)
: this(Enumerable.Empty(), problemData) {
}
/// <summary>
/// Creates an ensemble from the given models, each with its own training and test partition.
/// The three sequences are consumed in lock-step: the i-th model is combined with the i-th
/// training partition and the i-th test partition.
/// </summary>
/// <param name="models">The member models.</param>
/// <param name="problemData">Problem data; it is cloned for the ensemble and once more per model.</param>
/// <param name="trainingPartitions">One training partition per model.</param>
/// <param name="testPartitions">One test partition per model.</param>
/// <exception cref="ArgumentException">The three sequences do not have the same number of elements.</exception>
// NOTE(review): the initial problem data is stored directly as a result, so the ProblemData setter
// (the only place that subscribes ProblemData_Changed) is not involved here — confirm that changes of
// the initial problem data are not meant to be observed.
public GAssistEnsembleSolution(IEnumerable models, IGAssistProblemData problemData, IEnumerable trainingPartitions, IEnumerable testPartitions)
  : base() {
  Add(new Result(ModelResultName, "The data analysis model.", new GAssistEnsembleModel(Enumerable.Empty())));
  Add(new Result(ProblemDataResultName, "The data analysis problem data.", new GAssistEnsembleProblemData((IGAssistProblemData)problemData.Clone())));
  Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
  Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
  this.trainingPartitions = new Dictionary();
  this.testPartitions = new Dictionary();
  this.gassistSolutions = new ItemCollection();
  List solutions = new List();
  var modelEnumerator = models.GetEnumerator();
  var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
  var testPartitionEnumerator = testPartitions.GetEnumerator();
  // Non-short-circuit '&' on purpose: all three enumerators must advance together.
  while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    var p = (IGAssistProblemData)problemData.Clone();
    p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
    p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
    p.TestPartition.Start = testPartitionEnumerator.Current.Start;
    p.TestPartition.End = testPartitionEnumerator.Current.End;
    solutions.Add(modelEnumerator.Current.CreateGAssistSolution(p));
  }
  // If any sequence still has elements left, the lengths differ.
  if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
    throw new ArgumentException("The number of models, training partitions and test partitions must be equal.");
  }
  trainingEvaluationCache = new Dictionary(problemData.TrainingIndices.Count());
  testEvaluationCache = new Dictionary(problemData.TestIndices.Count());
  // Register the handlers before adding, so that adding the solutions also adds their models to
  // the ensemble model and records their partitions (see gassistSolutions_ItemsAdded).
  RegisterGAssistSolutionsEventHandler();
  gassistSolutions.AddRange(solutions);
}
/// <summary>
/// Creates a deep copy of this solution by means of the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new GAssistEnsembleSolution(this, cloner);
  return copy;
}
/// <summary>
/// Subscribes this instance to the added/removed/reset events of the member solution collection.
/// </summary>
private void RegisterGAssistSolutionsEventHandler() {
  gassistSolutions.ItemsAdded += gassistSolutions_ItemsAdded;
  gassistSolutions.ItemsRemoved += gassistSolutions_ItemsRemoved;
  gassistSolutions.CollectionReset += gassistSolutions_CollectionReset;
}
#region Evaluation
/// <summary>
/// Gets the ensemble estimate for every row of the dataset (row indices 0 .. Rows - 1).
/// </summary>
public IEnumerable EstimatedNiches {
  get {
    var rowCount = ProblemData.Dataset.Rows;
    var allRows = Enumerable.Range(0, rowCount);
    return GetEstimatedNiches(allRows);
  }
}
/// <summary>
/// Gets the ensemble estimate for every training row. Rows missing from the training cache are
/// evaluated first; for each row only models vote for which the row lies in their training
/// partition and not in their test partition.
/// </summary>
public IEnumerable EstimatedTrainingNiche {
  get {
    var trainingRows = ProblemData.TrainingIndices;
    var uncachedRows = trainingRows.Except(trainingEvaluationCache.Keys);
    var rowCursor = uncachedRows.GetEnumerator();
    var nicheCursor = GetEstimatedValues(
      uncachedRows,
      (row, model) => RowIsTrainingForModel(row, model) && !RowIsTestForModel(row, model)).GetEnumerator();
    while (true) {
      // Both cursors are always advanced, even when the first one is already exhausted.
      bool hasRow = rowCursor.MoveNext();
      bool hasNiche = nicheCursor.MoveNext();
      if (!hasRow || !hasNiche) break;
      trainingEvaluationCache.Add(rowCursor.Current, nicheCursor.Current);
    }
    return from row in trainingRows
           select trainingEvaluationCache[row];
  }
}
/// <summary>
/// Gets the ensemble estimate for every test row. Rows missing from the test cache are evaluated
/// first; for each row only models vote for which the row lies in their test partition.
/// </summary>
public IEnumerable EstimatedTestNiche {
  get {
    var testRows = ProblemData.TestIndices;
    var uncachedRows = testRows.Except(testEvaluationCache.Keys);
    var rowCursor = uncachedRows.GetEnumerator();
    var nicheCursor = GetEstimatedValues(uncachedRows, RowIsTestForModel).GetEnumerator();
    while (true) {
      // Both cursors are always advanced, even when the first one is already exhausted.
      bool hasRow = rowCursor.MoveNext();
      bool hasNiche = nicheCursor.MoveNext();
      if (!hasRow || !hasNiche) break;
      testEvaluationCache.Add(rowCursor.Current, nicheCursor.Current);
    }
    return from row in testRows
           select testEvaluationCache[row];
  }
}
/// <summary>
/// Returns the ensemble estimate for each of the given rows. Only rows that are not yet in the
/// evaluation cache are evaluated; their estimates are cached under the row index.
/// </summary>
/// <param name="rows">Dataset row indices to estimate.</param>
/// <returns>
/// A lazily evaluated sequence with one cached estimate per entry of <paramref name="rows"/>, in the
/// same order. If the ensemble has no member models nothing is evaluated or cached, so the cache
/// lookup fails for uncached rows when the result is enumerated.
/// </returns>
public IEnumerable GetEstimatedNiches(IEnumerable rows) {
  // Materialise the uncached rows once: exactly these rows are evaluated below, so the i-th
  // estimate belongs to the i-th entry of this list. Evaluating all of 'rows' instead would pair
  // estimates with the wrong row whenever some rows are already cached or occur more than once.
  var rowsToEvaluate = rows.Except(evaluationCache.Keys).ToList();
  var rowsEnumerator = rowsToEvaluate.GetEnumerator();
  var valuesEnumerator = (from xs in GetEstimatedNicheVectors(ProblemData.FetchInput(rowsToEvaluate))
                          select AggregateEstimatedClassValues(xs))
                         .GetEnumerator();
  // Non-short-circuit '&': both enumerators advance together.
  while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
    evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
  }
  return rows.Select(row => evaluationCache[row]);
}
/// <summary>
/// Evaluates every member model on <paramref name="input"/> and yields, per input element, the
/// estimates of all models (in the order of Model.Models).
/// </summary>
/// <param name="input">The input passed unchanged to each model's Evaluate method.</param>
/// <returns>
/// One vector of estimates per input element. Nothing is yielded when the ensemble has no models;
/// enumeration ends as soon as any model runs out of estimates.
/// </returns>
public IEnumerable> GetEstimatedNicheVectors(IEnumerable input) {
  if (!Model.Models.Any()) yield break;
  var estimatedValuesEnumerators = (from model in Model.Models
                                    select model.Evaluate(input).GetEnumerator())
                                   .ToList();
  while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    // Snapshot the current values. A lazy query over the live enumerators would report the values
    // of a later position if the caller enumerated the vector after the outer sequence moved on.
    yield return estimatedValuesEnumerators.Select(en => en.Current).ToList();
  }
}
/// <summary>
/// Lazily yields one aggregated estimate per row. Every member model is evaluated on the input
/// fetched for <paramref name="rows"/>; for each row only the estimates of those models for which
/// <paramref name="modelSelectionPredicate"/> returns true take part in the vote.
/// </summary>
/// <param name="rows">Row indices to estimate; enumerated once for FetchInput and once in lock-step with the estimates.</param>
/// <param name="modelSelectionPredicate">Called with (row, model); decides whether the model votes for that row.</param>
/// <returns>The aggregated estimate per row; the default value when no model is selected for a row.</returns>
private IEnumerable GetEstimatedValues(IEnumerable rows, Func modelSelectionPredicate) {
var input = ProblemData.FetchInput(rows);
// Keep each model together with the enumerator over its estimates so the predicate can be applied per row.
var estimatedValuesEnumerators = (from model in Model.Models
select new { Model = model, EstimatedValuesEnumerator = model.Evaluate(input).GetEnumerator() })
.ToList();
var rowsEnumerator = rows.GetEnumerator();
// Aggregate with non-short-circuit '&' to make sure that MoveNext is called for all enumerators,
// including those of models that are not selected for the current row. With no models the
// aggregate is true, so one (default) estimate is still produced per row.
while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
int currentRow = rowsEnumerator.Current;
var selectedEnumerators = from pair in estimatedValuesEnumerators
where modelSelectionPredicate(currentRow, pair.Model)
select pair.EstimatedValuesEnumerator;
// The aggregation consumes the lazy selection immediately, before any enumerator advances again.
yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
}
}
/// <summary>
/// Majority vote: groups the estimates with a GAssistNicheComparer and returns the key of the
/// largest group. On a tie the group encountered first wins (the ordering is stable); an empty
/// input yields null.
/// </summary>
private IGAssistNiche AggregateEstimatedClassValues(IEnumerable estimatedNiches) {
  var votes = estimatedNiches.GroupBy(niche => niche, new GAssistNicheComparer());
  var rankedVotes = votes.OrderByDescending(group => group.Count());
  var winner = rankedVotes.FirstOrDefault();
  if (winner == null) return null;
  return winner.Key;
}
/// <summary>
/// Recomputes TrainingAccuracy and TestAccuracy by comparing the actions fetched from the problem
/// data for the training/test rows with the ensemble estimates for the same rows.
/// </summary>
private void RecalculateResults() {
  // The inputs of the training/test rows are not fetched here: the Estimated*Niche getters fetch
  // the input they need themselves, so a separate FetchInput result would never be used.
  var estimatedTraining = EstimatedTrainingNiche;
  var estimatedTest = EstimatedTestNiche;
  var originalTrainingAction = ProblemData.FetchAction(ProblemData.TrainingIndices);
  var originalTestAction = ProblemData.FetchAction(ProblemData.TestIndices);
  TrainingAccuracy = CalculateAccuracy(originalTrainingAction, estimatedTraining);
  TestAccuracy = CalculateAccuracy(originalTestAction, estimatedTest);
}
/// <summary>
/// Computes the fraction of positions at which the estimate matches the original according to SameNiche.
/// </summary>
/// <param name="original">The expected values; its element count is the denominator.</param>
/// <param name="estimated">The estimates, compared position by position with <paramref name="original"/>.</param>
/// <returns>
/// Number of matching positions divided by the number of original elements. Positions where either
/// value is null never count as a match; an empty <paramref name="original"/> yields NaN (0 / 0).
/// </returns>
public static double CalculateAccuracy(IEnumerable original, IEnumerable estimated) {
  int hits = 0;
  double total = original.Count();
  var expectedCursor = original.GetEnumerator();
  var estimatedCursor = estimated.GetEnumerator();
  while (expectedCursor.MoveNext() && estimatedCursor.MoveNext()) {
    var expected = expectedCursor.Current;
    var actual = estimatedCursor.Current;
    if (expected == null || actual == null) continue;
    if (expected.SameNiche(actual)) hits++;
  }
  return hits / total;
}
/// <summary>
/// Tells whether <paramref name="currentRow"/> counts as a training row for <paramref name="model"/>:
/// true when no partition is recorded for the model, otherwise true iff Start &lt;= row &lt; End.
/// </summary>
private bool RowIsTrainingForModel(int currentRow, IGAssistModel model) {
  if (trainingPartitions == null || !trainingPartitions.ContainsKey(model)) return true;
  var partition = trainingPartitions[model];
  return currentRow >= partition.Start && currentRow < partition.End;
}
/// <summary>
/// Tells whether <paramref name="currentRow"/> counts as a test row for <paramref name="model"/>:
/// true when no partition is recorded for the model, otherwise true iff Start &lt;= row &lt; End.
/// </summary>
private bool RowIsTestForModel(int currentRow, IGAssistModel model) {
  if (testPartitions == null || !testPartitions.ContainsKey(model)) return true;
  var partition = testPartitions[model];
  return currentRow >= partition.Start && currentRow < partition.End;
}
#endregion
/// <summary>
/// Raised at the end of <see cref="OnProblemDataChanged"/>, after the member solutions, the recorded
/// partitions and the accuracy results have been updated.
/// </summary>
public event EventHandler ProblemDataChanged;
/// <summary>
/// Propagates the current ensemble problem data to all member solutions, resets every recorded
/// per-model partition to the ensemble's partitions, recalculates the accuracies and raises
/// <see cref="ProblemDataChanged"/>.
/// </summary>
protected void OnProblemDataChanged() {
// All cached estimates are invalid once the problem data changes.
trainingEvaluationCache.Clear();
testEvaluationCache.Clear();
evaluationCache.Clear();
// Problem data handed to the plain member solutions, carrying the partitions of the ensemble problem data.
IGAssistProblemData problemData = ProblemData.GetGAssistProblemData();
problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
problemData.TestPartition.Start = ProblemData.TestPartition.Start;
problemData.TestPartition.End = ProblemData.TestPartition.End;
// Nested ensembles receive the ensemble problem data itself; all other solutions share the one
// problemData instance created above.
foreach (var solution in GAssistSolutions) {
if (solution is GAssistEnsembleSolution)
solution.ProblemData = ProblemData;
else
solution.ProblemData = problemData;
}
// Overwrite the partitions recorded per model with the partitions of the ensemble problem data.
foreach (var trainingPartition in trainingPartitions.Values) {
trainingPartition.Start = ProblemData.TrainingPartition.Start;
trainingPartition.End = ProblemData.TrainingPartition.End;
}
foreach (var testPartition in testPartitions.Values) {
testPartition.Start = ProblemData.TestPartition.Start;
testPartition.End = ProblemData.TestPartition.End;
}
RecalculateResults();
// Copy the delegate to a local before the null check and the invocation.
var listeners = ProblemDataChanged;
if (listeners != null) listeners(this, EventArgs.Empty);
}
/// <summary>
/// Raised by <see cref="OnModelChanged"/> after the accuracy results have been recalculated.
/// </summary>
public event EventHandler ModelChanged;
/// <summary>
/// Recalculates the accuracy results and then notifies the subscribers of <see cref="ModelChanged"/>.
/// </summary>
protected virtual void OnModelChanged() {
  RecalculateResults();
  EventHandler handler = ModelChanged;
  if (handler == null) return;
  handler(this, EventArgs.Empty);
}
/// <summary>
/// Adds the given solutions to the member collection and then empties all evaluation caches.
/// </summary>
public void AddGAssistSolutions(IEnumerable solutions) {
  gassistSolutions.AddRange(solutions);
  ClearEvaluationCaches();
}
/// <summary>
/// Removes the given solutions from the member collection and then empties all evaluation caches.
/// </summary>
public void RemoveGAssistSolutions(IEnumerable solutions) {
  gassistSolutions.RemoveRange(solutions);
  ClearEvaluationCaches();
}
// Empties the training, test and general evaluation cache.
private void ClearEvaluationCaches() {
  trainingEvaluationCache.Clear();
  testEvaluationCache.Clear();
  evaluationCache.Clear();
}
/// <summary>
/// Registers every solution that was added to the collection and recalculates the accuracies.
/// </summary>
private void gassistSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs e) {
  foreach (var added in e.Items) {
    AddGAssistSolution(added);
  }
  RecalculateResults();
}
/// <summary>
/// Unregisters every solution that was removed from the collection and recalculates the accuracies.
/// </summary>
private void gassistSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs e) {
  foreach (var removed in e.Items) {
    RemoveGAssistSolution(removed);
  }
  RecalculateResults();
}
/// <summary>
/// Handles a reset of the collection: first unregisters all old items, then registers all new
/// items, and finally recalculates the accuracies.
/// </summary>
private void gassistSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs e) {
  foreach (var previous in e.OldItems) {
    RemoveGAssistSolution(previous);
  }
  foreach (var current in e.Items) {
    AddGAssistSolution(current);
  }
  RecalculateResults();
}
/// <summary>
/// Adds the model of <paramref name="solution"/> to the ensemble model, records the solution's
/// training and test partition for that model and empties all evaluation caches.
/// </summary>
/// <exception cref="ArgumentException">The model is already part of the ensemble model.</exception>
private void AddGAssistSolution(IGAssistSolution solution) {
  if (Model.Models.Contains(solution.Model))
    throw new ArgumentException("The model of the solution is already part of the ensemble.", "solution");
  Model.Add(solution.Model);
  // The partition objects of the solution's problem data are stored as they are (not copied).
  trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
  testPartitions[solution.Model] = solution.ProblemData.TestPartition;
  trainingEvaluationCache.Clear();
  testEvaluationCache.Clear();
  evaluationCache.Clear();
}
/// <summary>
/// Removes the model of <paramref name="solution"/> from the ensemble model, drops the partitions
/// recorded for it and empties all evaluation caches.
/// </summary>
/// <exception cref="ArgumentException">The model is not part of the ensemble model.</exception>
private void RemoveGAssistSolution(IGAssistSolution solution) {
  if (!Model.Models.Contains(solution.Model))
    throw new ArgumentException("The model of the solution is not part of the ensemble.", "solution");
  Model.Remove(solution.Model);
  trainingPartitions.Remove(solution.Model);
  testPartitions.Remove(solution.Model);
  trainingEvaluationCache.Clear();
  testEvaluationCache.Clear();
  evaluationCache.Clear();
}
#region IGAssistSolution Members
/// <summary>
/// Explicit interface implementation: exposes the ensemble model as a plain IGAssistModel.
/// </summary>
IGAssistModel IGAssistSolution.Model {
get { return Model; }
}
/// <summary>
/// Explicit interface implementation: returns the ensemble problem data; on assignment the given
/// problem data is wrapped in a new GAssistEnsembleProblemData and passed to the ProblemData setter.
/// </summary>
IGAssistProblemData IGAssistSolution.ProblemData {
get { return ProblemData; }
set { ProblemData = new GAssistEnsembleProblemData(value); }
}
/// <summary>
/// Gets the sum of TrainingNumberOfAliveRules over all member solutions.
/// </summary>
public int TrainingNumberOfAliveRules {
get { return gassistSolutions.Sum(x => x.TrainingNumberOfAliveRules); }
}
/// <summary>
/// Gets the sum of TrainingTheoryLength over all member solutions.
/// </summary>
public double TrainingTheoryLength {
get { return gassistSolutions.Sum(x => x.TrainingTheoryLength); }
}
/// <summary>
/// Gets 105.0 minus the training accuracy scaled to percent.
/// </summary>
// NOTE(review): the constant 105.0 is not explained in this file — presumably it follows the
// exceptions-length term of the GAssist MDL fitness; confirm against the single-solution implementation.
public double TrainingExceptionsLength {
get { return 105.0 - TrainingAccuracy * 100.0; }
}
/// <summary>
/// Gets the number of classes as reported by the problem data.
/// </summary>
public int Classes {
get { return ProblemData.Classes; }
}
#endregion
}
}