#region License Information
/* HeuristicLab
* Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HEAL.Fossil;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
/// <summary>
/// Represents a nearest neighbour model for symbolic classification.
/// The symbolic expression tree maps each row to a single estimated value; a row is
/// classified by majority vote among the k training rows whose estimated values are
/// closest to it.
/// </summary>
[StorableType("B9F8A753-B102-4356-8821-76E31634A0C6")]
[Item("SymbolicNearestNeighbourClassificationModel", "Represents a nearest neighbour model for symbolic classification.")]
public sealed class SymbolicNearestNeighbourClassificationModel : SymbolicClassificationModel {
  // number of neighbours that take part in the vote
  [Storable]
  private int k;
  // class values of the training rows, ordered like trainedEstimatedValues
  [Storable]
  private List<double> trainedClasses;
  // estimated values of the training rows in ascending order (required by BinarySearch)
  [Storable]
  private List<double> trainedEstimatedValues;
  // tie breaker for the vote: prefers the class that is more frequent in the training data
  [Storable]
  private ClassFrequencyComparer frequencyComparer;

  [StorableConstructor]
  private SymbolicNearestNeighbourClassificationModel(StorableConstructorFlag _) : base(_) { }

  /// <summary>
  /// Copy constructor used by <see cref="Clone(Cloner)"/>. All training state is copied so that
  /// the clone does not share mutable collections with <paramref name="original"/>.
  /// </summary>
  private SymbolicNearestNeighbourClassificationModel(SymbolicNearestNeighbourClassificationModel original, Cloner cloner)
    : base(original, cloner) {
    k = original.k;
    // the training state is null until RecalculateModelParameters has run;
    // an untrained model must still be cloneable
    frequencyComparer = original.frequencyComparer != null
      ? new ClassFrequencyComparer(original.frequencyComparer)
      : new ClassFrequencyComparer();
    if (original.trainedEstimatedValues != null)
      trainedEstimatedValues = new List<double>(original.trainedEstimatedValues);
    if (original.trainedClasses != null)
      trainedClasses = new List<double>(original.trainedClasses);
  }

  /// <summary>
  /// Creates an untrained model. <see cref="RecalculateModelParameters"/> has to be called
  /// before class values can be estimated.
  /// </summary>
  /// <param name="targetVariable">Name of the target variable (forwarded to the base class).</param>
  /// <param name="k">Number of nearest neighbours that vote on the class.</param>
  /// <param name="tree">Symbolic expression tree that produces the estimated values.</param>
  /// <param name="interpreter">Interpreter used to evaluate <paramref name="tree"/>.</param>
  /// <param name="lowerEstimationLimit">Lower limit for the estimated values (forwarded to the base class).</param>
  /// <param name="upperEstimationLimit">Upper limit for the estimated values (forwarded to the base class).</param>
  public SymbolicNearestNeighbourClassificationModel(string targetVariable, int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
    : base(targetVariable, tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
    this.k = k;
    frequencyComparer = new ClassFrequencyComparer();
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new SymbolicNearestNeighbourClassificationModel(this, cloner);
  }

  /// <summary>
  /// Lazily estimates a class value for each of the given rows by majority vote among the
  /// nearest training points (nearest with respect to the tree's estimated value).
  /// </summary>
  /// <param name="dataset">Dataset on which the expression tree is evaluated.</param>
  /// <param name="rows">Rows of <paramref name="dataset"/> to classify.</param>
  /// <returns>One class value per row, in the order of <paramref name="rows"/>.</returns>
  /// <exception cref="InvalidOperationException">
  /// Thrown during enumeration when a row has to be classified but the model holds no training data.
  /// </exception>
  public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows)
                                     .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
    foreach (var ev in estimatedValues) {
      var values = trainedEstimatedValues;
      var classes = trainedClasses;
      // checked per row so that an empty row set still yields an empty result for an untrained model
      if (values == null || classes == null || values.Count == 0)
        throw new InvalidOperationException("The nearest neighbour model has no training data. Call RecalculateModelParameters first.");

      // find the range [lower, upper[ of training points that contains the k closest neighbours
      // the range can span more than k elements when there are equal estimated values

      // find the index of the training point to which the distance is shortest
      int lower = values.BinarySearch(ev);
      // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
      if (lower < 0) {
        lower = ~lower;
        // lower is not necessarily the closer one
        // determine which element is closer to ev: (lower - 1) or (lower)
        if (lower == values.Count ||
            (lower > 0 && Math.Abs(ev - values[lower - 1]) < Math.Abs(ev - values[lower]))) {
          lower = lower - 1;
        }
      }
      int upper = lower + 1;

      // at this point the range [lower, upper[ includes only the closest element to ev
      // expand the range to the left or right looking for the nearest neighbours
      int neighbours = Math.Min(k, values.Count);
      while (upper - lower < neighbours) {
        // The candidates for extending the range are values[lower - 1] on the left and
        // values[upper] on the right. The left distance must be taken from the candidate
        // (lower - 1) and not from the already included element (lower); otherwise a far
        // away left point is preferred over a close right point.
        // The loop condition guarantees that at least one side can still grow, so
        // 'upper >= Count' implies 'lower > 0' and 'lower <= 0' implies 'upper < Count'.
        bool lowerIsCloser = upper >= values.Count ||
                             (lower > 0 && ev - values[lower - 1] <= values[upper] - ev);
        bool upperIsCloser = lower <= 0 ||
                             (upper < values.Count && ev - values[lower - 1] >= values[upper] - ev);
        // neither comparison holds (e.g. NaN distances): stop instead of looping forever
        if (!lowerIsCloser && !upperIsCloser) break;
        if (lowerIsCloser) {
          lower--;
          // eat up all equal values
          while (lower > 0 && values[lower - 1].IsAlmost(values[lower]))
            lower--;
        }
        if (upperIsCloser) {
          upper++;
          // eat up all equal values
          while (upper < values.Count && values[upper - 1].IsAlmost(values[upper]))
            upper++;
        }
      }

      // majority voting with preference for the bigger (more frequent) class in case of a tie
      yield return Enumerable.Range(lower, upper - lower)
        .Select(i => classes[i])
        .GroupBy(c => c)
        .Select(g => new { Class = g.Key, Votes = g.Count() })
        .MaxItems(p => p.Votes)
        .OrderByDescending(m => m.Class, frequencyComparer)
        .First().Class;
    }
  }

  /// <summary>
  /// Trains the model: stores the estimated value and the class of every training row
  /// (sorted by estimated value) and the frequency of each class.
  /// </summary>
  /// <param name="problemData">Problem data that provides the dataset and the target variable.</param>
  /// <param name="rows">Rows used for training.</param>
  public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
    var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows)
                                     .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
    var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);

    var sortedClasses = targetValues.ToArray();
    var sortedEstimatedValues = estimatedValues.ToArray();
    // sort the estimated values ascending and reorder the classes accordingly
    Array.Sort(sortedEstimatedValues, sortedClasses);
    this.trainedClasses = new List<double>(sortedClasses);
    this.trainedEstimatedValues = new List<double>(sortedEstimatedValues);

    var freq = sortedClasses
      .GroupBy(c => c)
      .ToDictionary(g => g.Key, g => g.Count());
    this.frequencyComparer = new ClassFrequencyComparer(freq);
  }

  public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData);
  }
}
/// <summary>
/// Compares class values by the number of times they occur in the training data.
/// A class with a known frequency is considered greater than a class without one;
/// two classes without a known frequency are considered equal.
/// </summary>
[StorableType("523AFB5D-3758-4547-BD6E-1181A01A02B4")]
internal sealed class ClassFrequencyComparer : IComparer<double> {
  // class value -> number of occurrences
  [Storable]
  private readonly Dictionary<double, int> classFrequencies;

  [StorableConstructor]
  private ClassFrequencyComparer(StorableConstructorFlag _) { }

  /// <summary>Creates a comparer without any known class frequencies.</summary>
  public ClassFrequencyComparer() {
    classFrequencies = new Dictionary<double, int>();
  }

  /// <summary>Creates a comparer that uses (and keeps a reference to) the given frequencies.</summary>
  /// <param name="frequencies">Number of occurrences per class value.</param>
  public ClassFrequencyComparer(Dictionary<double, int> frequencies) {
    classFrequencies = frequencies;
  }

  /// <summary>Creates a comparer with a copy of the frequencies of <paramref name="original"/>.</summary>
  public ClassFrequencyComparer(ClassFrequencyComparer original) {
    classFrequencies = new Dictionary<double, int>(original.classFrequencies);
  }

  /// <summary>
  /// Compares two class values by their frequency.
  /// </summary>
  /// <returns>
  /// The comparison of the two frequencies when both classes are known; 1 when only
  /// <paramref name="x"/> is known; -1 when only <paramref name="y"/> is known; 0 when neither is known.
  /// </returns>
  public int Compare(double x, double y) {
    int frequencyX, frequencyY;
    bool knownX = classFrequencies.TryGetValue(x, out frequencyX);
    bool knownY = classFrequencies.TryGetValue(y, out frequencyY);
    if (knownX && knownY)
      return frequencyX.CompareTo(frequencyY);
    if (knownX) return 1;
    if (knownY) return -1;
    // neither class is known: they must compare as equal, otherwise Compare(a, a) != 0
    // and Compare(a, b) == Compare(b, a), which violates the IComparer contract
    return 0;
  }
}
}