#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;

namespace HeuristicLab.Problems.DataAnalysis.Classification {
  /// <summary>
  /// Represents a solution for a symbolic classification problem which can be visualized in the GUI.
  /// </summary>
  /// <remarks>
  /// The underlying symbolic regression model produces a continuous estimate per row. That estimate
  /// is mapped to a class value via a list of ascending thresholds: thresholds[0] is -Infinity and the
  /// last threshold is +Infinity; an estimate falls into class k when it lies above thresholds[k]
  /// and not above thresholds[k + 1].
  /// </remarks>
  [Item("SymbolicClassificationSolution", "Represents a solution for a symbolic classification problem which can be visualized in the GUI.")]
  [StorableClass]
  public class SymbolicClassificationSolution : SymbolicRegressionSolution, IClassificationSolution {
    /// <summary>
    /// Number of equally spaced candidate thresholds evaluated between two adjacent class values
    /// when searching for the best separating threshold.
    /// </summary>
    private const int ThresholdSearchSlices = 100;

    /// <summary>
    /// Gets or sets the problem data, typed as <see cref="ClassificationProblemData"/>.
    /// </summary>
    public new ClassificationProblemData ProblemData {
      get { return (ClassificationProblemData)base.ProblemData; }
      set { base.ProblemData = value; }
    }

    #region properties
    // Thresholds found by the last run of RecalculateClassIntermediates.
    private List<double> optimalThresholds;
    // Thresholds currently used for classification: a copy of the optimal ones, or values assigned
    // through the Thresholds setter. Not persisted; null until first computed.
    private List<double> actualThresholds;

    /// <summary>
    /// Gets or sets the class thresholds used to map estimated values to class values.
    /// The getter computes the estimated values and thresholds on first access.
    /// The setter raises <see cref="ThresholdsChanged"/> only when the new sequence differs
    /// from the current one.
    /// </summary>
    /// <exception cref="ArgumentNullException">The assigned value is null.</exception>
    public IEnumerable<double> Thresholds {
      get {
        if (actualThresholds == null) RecalculateEstimatedValues();
        // Read-only view so callers cannot mutate the internal list behind the setter's change check.
        return actualThresholds.AsReadOnly();
      }
      set {
        if (value == null) throw new ArgumentNullException("value");
        if (actualThresholds != null && actualThresholds.SequenceEqual(value)) return;
        actualThresholds = new List<double>(value);
        OnThresholdsChanged();
      }
    }

    /// <summary>Gets the estimated class value for every row of the dataset.</summary>
    public IEnumerable<double> EstimatedClassValues {
      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    }

    /// <summary>Gets the estimated class values for the training partition.</summary>
    public IEnumerable<double> EstimatedTrainingClassValues {
      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
    }

    /// <summary>Gets the estimated class values for the test partition.</summary>
    public IEnumerable<double> EstimatedTestClassValues {
      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    }

    [StorableConstructor]
    protected SymbolicClassificationSolution(bool deserializing) : base(deserializing) { }

    protected SymbolicClassificationSolution(SymbolicClassificationSolution original, Cloner cloner)
      : base(original, cloner) {
      // Thresholds are plain value lists, so independent copies are a sufficient deep clone.
      // A null list stays null and is computed lazily on first use.
      if (original.optimalThresholds != null) optimalThresholds = new List<double>(original.optimalThresholds);
      if (original.actualThresholds != null) actualThresholds = new List<double>(original.actualThresholds);
    }

    public SymbolicClassificationSolution(ClassificationProblemData problemData, SymbolicRegressionModel model, double lowerEstimationLimit, double upperEstimationLimit)
      : base(problemData, model, lowerEstimationLimit, upperEstimationLimit) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicClassificationSolution(this, cloner);
    }

    /// <summary>
    /// Recomputes the estimated values for all dataset rows and then the class thresholds.
    /// Estimates are clamped to [LowerEstimationLimit, UpperEstimationLimit]; NaN estimates
    /// are replaced by UpperEstimationLimit. Raises the estimated-values-changed notification.
    /// Note: this replaces any thresholds previously assigned through <see cref="Thresholds"/>.
    /// </summary>
    protected override void RecalculateEstimatedValues() {
      estimatedValues = (from x in Model.GetEstimatedValues(ProblemData, 0, ProblemData.Dataset.Rows)
                         let boundedX = Math.Min(UpperEstimationLimit, Math.Max(LowerEstimationLimit, x))
                         select double.IsNaN(boundedX) ? UpperEstimationLimit : boundedX).ToList();
      RecalculateClassIntermediates();
      OnEstimatedValuesChanged();
    }

    /// <summary>
    /// Searches, for each pair of adjacent sorted class values, the threshold that minimizes the
    /// weighted misclassification score on the training partition. The result is stored as both
    /// the optimal and the actual thresholds.
    /// </summary>
    private void RecalculateClassIntermediates() {
      List<double> originalClasses = ProblemData.SortedClassValues.ToList();
      // Counts aligned index-by-index with originalClasses.
      List<int> classInstances = CountClassInstances(originalClasses);

      string targetVariable = ProblemData.TargetVariable.Value;
      // Pairs of (estimated value, actual target value) for every training row.
      List<KeyValuePair<double, double>> estimatedTargetValues =
        (from row in ProblemData.TrainingIndizes
         select new KeyValuePair<double, double>(estimatedValues[row], ProblemData.Dataset[targetVariable, row])).ToList();

      double[] thresholds = new double[ProblemData.NumberOfClasses + 1];
      thresholds[0] = double.NegativeInfinity;
      thresholds[thresholds.Length - 1] = double.PositiveInfinity;

      for (int i = 1; i < thresholds.Length - 1; i++) {
        double lowerThreshold = thresholds[i - 1];
        double lowerClass = originalClasses[i - 1];
        double thresholdIncrement = (originalClasses[i] - lowerClass) / ThresholdSearchSlices;
        double bestThreshold = double.NaN;
        double bestClassificationScore = double.PositiveInfinity;

        // Candidates are computed from an integer step index rather than by repeated addition.
        // This avoids rounding drift and guarantees termination even when the increment is
        // smaller than the floating-point resolution at lowerClass.
        for (int slice = 0; slice < ThresholdSearchSlices; slice++) {
          double candidateThreshold = lowerClass + slice * thresholdIncrement;
          double classificationScore = ComputeThresholdScore(estimatedTargetValues, classInstances, originalClasses,
                                                             i, lowerThreshold, candidateThreshold);
          // Strict comparison keeps the lowest candidate among equally good ones.
          if (classificationScore < bestClassificationScore) {
            bestClassificationScore = classificationScore;
            bestThreshold = candidateThreshold;
          }
        }
        thresholds[i] = bestThreshold;
      }

      optimalThresholds = new List<double>(thresholds);
      // Separate copy so later edits of the actual thresholds can never alias the optimal ones.
      actualThresholds = new List<double>(thresholds);
    }

    /// <summary>
    /// Counts how often each of the given class values occurs in the target variable over the whole
    /// dataset. The result has the same order as <paramref name="sortedClassValues"/>.
    /// </summary>
    /// <exception cref="KeyNotFoundException">A class value does not occur in the target variable.</exception>
    private List<int> CountClassInstances(IEnumerable<double> sortedClassValues) {
      var counts = ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable.Value)
        .GroupBy(classValue => classValue)
        .ToDictionary(grouping => grouping.Key, grouping => grouping.Count());
      // Look up by value: GroupBy yields groups in first-appearance order, which generally
      // differs from the ascending order of the class values.
      return sortedClassValues.Select(classValue => counts[classValue]).ToList();
    }

    /// <summary>
    /// Computes the misclassification score of separating class <paramref name="classIndex"/> - 1
    /// ("positives") from class <paramref name="classIndex"/> using the interval
    /// (<paramref name="lowerThreshold"/>, <paramref name="candidateThreshold"/>) for the lower class.
    /// Each sample contributes the corresponding MisclassificationMatrix entry divided by the
    /// instance count of its (positive or upper) class. All samples not of the lower class are
    /// scored as instances of the upper class.
    /// </summary>
    private double ComputeThresholdScore(IEnumerable<KeyValuePair<double, double>> estimatedTargetValues,
                                         IList<int> classInstances, IList<double> originalClasses,
                                         int classIndex, double lowerThreshold, double candidateThreshold) {
      int lower = classIndex - 1;
      int upper = classIndex;
      var misclassificationMatrix = ProblemData.MisclassificationMatrix;
      double score = 0.0;
      foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
        bool assignedToLowerClass = estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < candidateThreshold;
        if (estimatedTarget.Value.IsAlmost(originalClasses[lower])) {
          // Positive sample: true positive if assigned to the lower class, otherwise false negative.
          score += (assignedToLowerClass ? misclassificationMatrix[lower, lower] : misclassificationMatrix[upper, lower])
                   / classInstances[lower];
        } else {
          // Negative sample: false positive if assigned to the lower class, otherwise true negative
          // (only the upper class is considered).
          score += (assignedToLowerClass ? misclassificationMatrix[lower, upper] : misclassificationMatrix[upper, upper])
                   / classInstances[upper];
        }
      }
      return score;
    }

    /// <summary>
    /// Maps the estimated values of the given rows to class values using the actual thresholds.
    /// Estimated values and thresholds are computed on demand when not yet available, for example
    /// after cloning or deserialization. Evaluation is deferred until enumeration.
    /// </summary>
    /// <param name="rows">Dataset row indices to classify.</param>
    /// <returns>One class value (from the sorted class values) per requested row.</returns>
    public IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
      if (actualThresholds == null || estimatedValues == null) RecalculateEstimatedValues();
      double[] classValues = ProblemData.SortedClassValues.ToArray();
      // Snapshot so a concurrent threshold/estimate replacement cannot mix states mid-enumeration.
      List<double> thresholds = actualThresholds;
      var estimates = estimatedValues;
      // Upper bound for the class index; guards against user-assigned threshold lists that lack
      // the trailing +Infinity or do not match the number of classes.
      int maxClassIndex = Math.Min(classValues.Length, thresholds.Count) - 1;
      foreach (int row in rows) {
        double value = estimates[row];
        int classIndex = 0;
        while (classIndex < maxClassIndex && value > thresholds[classIndex + 1]) classIndex++;
        yield return classValues[classIndex];
      }
    }
    #endregion

    /// <summary>Raised when the actual thresholds are replaced via the <see cref="Thresholds"/> setter.</summary>
    public event EventHandler ThresholdsChanged;

    private void OnThresholdsChanged() {
      // Invoke the local copy, not the field: a subscriber removed between the null check and the
      // call must not cause a NullReferenceException.
      EventHandler handler = ThresholdsChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
  }
}