#region License Information
/* HeuristicLab
* Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
namespace HeuristicLab.Problems.DataAnalysis.Classification {
/// <summary>
/// Represents a solution for a symbolic classification problem which can be visualized in the GUI.
/// </summary>
[Item("SymbolicClassificationSolution", "Represents a solution for a symbolic classification problem which can be visualized in the GUI.")]
[StorableClass]
public class SymbolicClassificationSolution : SymbolicRegressionSolution, IClassificationSolution {
  /// <summary>
  /// Gets or sets the problem data of this solution, typed as <see cref="ClassificationProblemData"/>.
  /// Hides the base property; the getter casts the value stored in the base class.
  /// </summary>
  public new ClassificationProblemData ProblemData {
    get { return (ClassificationProblemData)base.ProblemData; }
    set { base.ProblemData = value; }
  }

  #region properties
  // Thresholds found by the last call to RecalculateClassIntermediates.
  private List<double> optimalThresholds;
  // Thresholds used by GetEstimatedClassValues; either the optimal ones or values assigned via Thresholds.
  private List<double> actualThresholds;

  /// <summary>
  /// Gets or sets the thresholds that separate the classes on the estimated-value axis.
  /// The getter triggers <see cref="RecalculateEstimatedValues"/> when no thresholds exist yet.
  /// The setter stores a copy of the given values and raises <see cref="ThresholdsChanged"/>,
  /// unless the values are equal to the current thresholds.
  /// </summary>
  /// <exception cref="ArgumentNullException">The assigned value is null.</exception>
  public IEnumerable<double> Thresholds {
    get {
      if (actualThresholds == null) RecalculateEstimatedValues();
      return actualThresholds;
    }
    set {
      // Materialize once so that a lazily evaluated sequence is not enumerated twice
      // (once for the comparison and once for the copy).
      var newThresholds = new List<double>(value);
      if (actualThresholds != null && actualThresholds.SequenceEqual(newThresholds))
        return;
      actualThresholds = newThresholds;
      OnThresholdsChanged();
    }
  }

  /// <summary>
  /// Gets the estimated class values for all rows of the dataset.
  /// </summary>
  public IEnumerable<double> EstimatedClassValues {
    get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
  }

  /// <summary>
  /// Gets the estimated class values for the rows listed in ProblemData.TrainingIndizes.
  /// </summary>
  public IEnumerable<double> EstimatedTrainingClassValues {
    get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
  }

  /// <summary>
  /// Gets the estimated class values for the rows listed in ProblemData.TestIndizes.
  /// </summary>
  public IEnumerable<double> EstimatedTestClassValues {
    get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
  }

  [StorableConstructor]
  protected SymbolicClassificationSolution(bool deserializing) : base(deserializing) { }
  protected SymbolicClassificationSolution(SymbolicClassificationSolution original, Cloner cloner) : base(original, cloner) { }
  public SymbolicClassificationSolution(ClassificationProblemData problemData, SymbolicRegressionModel model, double lowerEstimationLimit, double upperEstimationLimit)
    : base(problemData, model, lowerEstimationLimit, upperEstimationLimit) {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new SymbolicClassificationSolution(this, cloner);
  }

  /// <summary>
  /// Evaluates the model on all rows of the dataset, clamps every estimate to
  /// [LowerEstimationLimit, UpperEstimationLimit] (NaN is replaced by UpperEstimationLimit),
  /// recalculates the class thresholds and finally signals that the estimated values changed.
  /// </summary>
  protected override void RecalculateEstimatedValues() {
    estimatedValues =
      (from x in Model.GetEstimatedValues(ProblemData, 0, ProblemData.Dataset.Rows)
       let boundedX = Math.Min(UpperEstimationLimit, Math.Max(LowerEstimationLimit, x))
       select double.IsNaN(boundedX) ? UpperEstimationLimit : boundedX).ToList();
    RecalculateClassIntermediates();
    OnEstimatedValuesChanged();
  }

  /// <summary>
  /// Determines one threshold between each pair of neighbouring class values and stores the
  /// result in both <c>optimalThresholds</c> and <c>actualThresholds</c>.
  /// The first and last threshold are negative and positive infinity. For every boundary,
  /// candidate thresholds are sampled between the two neighbouring class values and each
  /// candidate is scored over the training rows with the entries of
  /// ProblemData.MisclassificationMatrix, divided by the number of instances of the class;
  /// the candidate with the lowest score is kept.
  /// </summary>
  private void RecalculateClassIntermediates() {
    int slices = 100;

    // Number of instances per class over the whole target column. The groups are ordered by
    // class value so that classInstances[k] refers to the same class as originalClasses[k];
    // a plain group-by would return the groups in order of first occurrence in the dataset.
    List<int> classInstances = (from classValue in ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable.Value)
                                group classValue by classValue into grouping
                                orderby grouping.Key
                                select grouping.Count()).ToList();

    // Key = estimated value, Value = original target value, for every training row.
    List<KeyValuePair<double, double>> estimatedTargetValues =
      (from row in ProblemData.TrainingIndizes
       select new KeyValuePair<double, double>(
         estimatedValues[row],
         ProblemData.Dataset[ProblemData.TargetVariable.Value, row])).ToList();

    List<double> originalClasses = ProblemData.SortedClassValues.ToList();
    var misclassificationMatrix = ProblemData.MisclassificationMatrix;

    double[] thresholds = new double[ProblemData.NumberOfClasses + 1];
    thresholds[0] = double.NegativeInfinity;
    thresholds[thresholds.Length - 1] = double.PositiveInfinity;

    for (int i = 1; i < thresholds.Length - 1; i++) {
      // Threshold i separates class i - 1 (below) from class i (above).
      double lowerThreshold = thresholds[i - 1];
      double actualThreshold = originalClasses[i - 1];
      double thresholdIncrement = (originalClasses[i] - originalClasses[i - 1]) / slices;

      double bestThreshold = double.NaN;
      double bestClassificationScore = double.PositiveInfinity;

      while (actualThreshold < originalClasses[i]) {
        double classificationScore = 0.0;

        foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
          bool estimatedAsLowerClass = estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold;

          //all positives
          if (estimatedTarget.Value.IsAlmost(originalClasses[i - 1])) {
            if (estimatedAsLowerClass)
              //true positive
              classificationScore += misclassificationMatrix[i - 1, i - 1] / classInstances[i - 1];
            else
              //false negative
              classificationScore += misclassificationMatrix[i, i - 1] / classInstances[i - 1];
          }
          //all negatives
          else {
            if (estimatedAsLowerClass)
              //false positive
              classificationScore += misclassificationMatrix[i - 1, i] / classInstances[i];
            else
              //true negative, consider only upper class
              classificationScore += misclassificationMatrix[i, i] / classInstances[i];
          }
        }

        // Strict comparison: on ties the first (lowest) candidate is kept.
        if (classificationScore < bestClassificationScore) {
          bestClassificationScore = classificationScore;
          bestThreshold = actualThreshold;
        }
        actualThreshold += thresholdIncrement;
      }
      thresholds[i] = bestThreshold;
    }

    this.optimalThresholds = new List<double>(thresholds);
    this.actualThresholds = optimalThresholds;
  }

  /// <summary>
  /// Maps the estimated value of each given row to a class value by walking the current
  /// thresholds upwards until the estimate no longer exceeds the next threshold.
  /// The sequence is evaluated lazily.
  /// </summary>
  /// <param name="rows">Indices of the rows for which class values are returned.</param>
  /// <returns>One class value from ProblemData.SortedClassValues per row, in the order of <paramref name="rows"/>.</returns>
  public IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    // Same lazy initialization as the Thresholds getter; without it the threshold
    // lookup below dereferences null when no thresholds have been calculated yet.
    if (actualThresholds == null) RecalculateEstimatedValues();

    double[] classValues = ProblemData.SortedClassValues.ToArray();
    foreach (int row in rows) {
      double value = estimatedValues[row];
      int classIndex = 0;
      while (value > actualThresholds[classIndex + 1])
        classIndex++;
      yield return classValues[classIndex];
    }
  }
  #endregion

  /// <summary>
  /// Raised after new thresholds were assigned through <see cref="Thresholds"/>.
  /// </summary>
  public event EventHandler ThresholdsChanged;

  private void OnThresholdsChanged() {
    // Invoke the local copy: the event field may become null between the check and the call
    // if the last subscriber detaches concurrently.
    var handler = ThresholdsChanged;
    if (handler != null)
      handler(this, EventArgs.Empty);
  }
}
}