#region License Information
/* HeuristicLab
* Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;
namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// 1R classification algorithm. For every allowed input variable a one-level rule is built
/// (a list of value intervals, each mapped to one class). The variable whose rule classifies
/// the most training rows correctly is used for the resulting model.
/// </summary>
[Item("OneR", "1R classification algorithm.")]
[Creatable("Data Analysis")]
[StorableClass]
public sealed class OneR : FixedDataAnalysisAlgorithm<IClassificationProblem> {
  /// <summary>Parameter holding the minimum number of rows per bucket.</summary>
  public IValueParameter<IntValue> MinBucketSizeParameter {
    get { return (IValueParameter<IntValue>)Parameters["MinBucketSize"]; }
  }

  /// <summary>Parameter holding the random number generator used to break ties between classes.</summary>
  public IValueParameter<IRandom> RandomParameter {
    get { return (IValueParameter<IRandom>)Parameters["Random"]; }
  }

  [StorableConstructor]
  private OneR(bool deserializing) : base(deserializing) { }

  private OneR(OneR original, Cloner cloner)
    : base(original, cloner) {
  }

  /// <summary>
  /// Creates the algorithm with a minimum bucket size of 6, a <c>FastRandom</c> generator
  /// and an empty classification problem.
  /// </summary>
  public OneR()
    : base() {
    Parameters.Add(new ValueParameter<IntValue>("MinBucketSize", "Minimum size of a bucket for numerical values. (Except for the rightmost bucket)", new IntValue(6)));
    Parameters.Add(new ValueParameter<IRandom>("Random", "Random number generator", new FastRandom()));
    Problem = new ClassificationProblem();
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new OneR(this, cloner);
  }

  /// <summary>
  /// Builds the 1R solution for the current problem and stores it, together with the
  /// measured execution time, in the results collection.
  /// </summary>
  protected override void Run() {
    // UTC timestamps, so a local clock change (e.g. daylight saving) cannot distort the duration
    var startTime = DateTime.UtcNow;
    var solution = CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value, RandomParameter.Value);
    Results.Add(new Result("OneR solution", "The 1R classifier.", solution));
    Results.Add(new Result("OneR Execution Time", "", new TimeSpanValue(DateTime.UtcNow - startTime)));

    // NOTE(review): OneRTest is not defined in this file; this second run looks like a
    // comparison against another implementation - confirm that it is still wanted.
    startTime = DateTime.UtcNow;
    var referenceSolution = OneRTest.CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value);
    Results.Add(new Result("OneR Test2 solution", "The 1R classifier.", referenceSolution));
    Results.Add(new Result("OneR Test2 Execution", "", new TimeSpanValue(DateTime.UtcNow - startTime)));
  }

  /// <summary>
  /// Builds a 1R classification solution from the training partition of <paramref name="problemData"/>.
  /// </summary>
  /// <param name="problemData">The classification problem data to learn from.</param>
  /// <param name="minBucketSize">Minimum number of rows per bucket (the rightmost bucket may be smaller). Values below 1 are treated as 1.</param>
  /// <param name="random">Random number generator used to pick a class when several classes are equally frequent.</param>
  /// <returns>A solution wrapping the rule of the input variable with the most correctly classified training rows.</returns>
  /// <exception cref="ArgumentNullException">Thrown if <paramref name="problemData"/> is null.</exception>
  /// <exception cref="ArgumentException">Thrown if there are no training rows or no allowed input variables.</exception>
  public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize, IRandom random) {
    if (problemData == null) {
      throw new ArgumentNullException("problemData");
    }
    // a bucket needs at least one row, otherwise the bucket construction could not advance
    minBucketSize = Math.Max(1, minBucketSize);

    var dataset = problemData.Dataset;
    var trainingIndices = problemData.TrainingIndices.ToArray();
    if (trainingIndices.Length == 0) {
      throw new ArgumentException("The problem data does not contain any training rows.", "problemData");
    }
    var classValues = problemData.ClassValues.ToArray();
    // the target column is the same for every input variable, so it is read only once
    var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices).ToArray();

    string bestVariable = null;
    List<KeyValuePair<double, double>> bestSplits = null;
    double bestMissingValuesClass = double.NaN;
    int bestCorrectClassified = 0;

    foreach (string variable in problemData.AllowedInputVariables) {
      var inputValues = dataset.GetDoubleValues(variable, trainingIndices).ToArray();
      // Array.Sort reorders the items array in place, hence a fresh copy per variable
      var classes = (double[])targetValues.Clone();
      Array.Sort(inputValues, classes);

      List<KeyValuePair<double, double>> splits;
      double missingValuesClass;
      int correctClassified = BuildSplits(inputValues, classes, classValues, minBucketSize, random, out splits, out missingValuesClass);

      // remember score, rule and missing-value class together, so that all three
      // always belong to the same variable
      if (correctClassified > bestCorrectClassified) {
        bestCorrectClassified = correctClassified;
        bestVariable = variable;
        bestSplits = splits;
        bestMissingValuesClass = missingValuesClass;
      }
    }

    if (bestSplits == null) {
      throw new ArgumentException("The problem data does not contain any allowed input variables.", "problemData");
    }

    //merge intervals to simplify symbolic expression tree
    var mergedSplits = MergeSplits(bestSplits);
    var model = new OneRClassificationModel(bestVariable,
      mergedSplits.Select(s => s.Key).ToArray(),
      mergedSplits.Select(s => s.Value).ToArray(),
      bestMissingValuesClass);
    return new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
  }

  /// <summary>
  /// Builds the rule for a single input variable.
  /// </summary>
  /// <param name="inputValues">Values of the input variable, sorted ascending; relies on missing values (NaN) being sorted to the front.</param>
  /// <param name="classes">Class value of each row, in the same order as <paramref name="inputValues"/>.</param>
  /// <param name="classValues">All class values of the problem.</param>
  /// <param name="minBucketSize">Minimum number of rows per bucket, at least 1.</param>
  /// <param name="random">Random number generator used to break ties between classes.</param>
  /// <param name="splits">Receives the rule as ordered (exclusive upper bound, class) pairs; the last bound is positive infinity.</param>
  /// <param name="missingValuesClass">Receives the class predicted for missing values, or NaN if the variable has no missing values.</param>
  /// <returns>The number of rows that the rule classifies correctly.</returns>
  private static int BuildSplits(double[] inputValues, double[] classes, double[] classValues, int minBucketSize, IRandom random,
    out List<KeyValuePair<double, double>> splits, out double missingValuesClass) {
    splits = new List<KeyValuePair<double, double>>();
    missingValuesClass = double.NaN;
    int rows = inputValues.Length;
    int correctClassified = 0;
    var classCount = PrepareClassCountDictionary(classValues);

    // missing values get their own class; they form the prefix of the sorted values
    int curRow = 0;
    while (curRow < rows && Double.IsNaN(inputValues[curRow])) {
      classCount[classes[curRow]] += 1;
      curRow++;
    }
    if (curRow > 0) {
      if (!ExistsDominatingClass(classCount, out missingValuesClass)) {
        missingValuesClass = GetRandomMaxClass(classCount, random);
      }
      correctClassified += classCount[missingValuesClass];
      classCount = PrepareClassCountDictionary(classValues);
    }

    // each iteration creates one bucket; curRow is the first row that is not yet part of a bucket
    while (curRow < rows) {
      // take the minimum number of rows (or whatever is left) ...
      int last = Math.Min(curRow + minBucketSize, rows) - 1;
      for (int row = curRow; row <= last; row++) {
        classCount[classes[row]] += 1;
      }
      // ... but never cut through a run of identical values
      last = SetCurRowToEndOfSplit(last, inputValues, classes, classCount, inputValues[last]);

      // grow the bucket by whole runs of identical values until one class dominates
      double dominatingClass;
      bool dominated = ExistsDominatingClass(classCount, out dominatingClass);
      while (!dominated && last + 1 < rows) {
        last++;
        classCount[classes[last]] += 1;
        last = SetCurRowToEndOfSplit(last, inputValues, classes, classCount, inputValues[last]);
        dominated = ExistsDominatingClass(classCount, out dominatingClass);
      }

      if (dominated) {
        // absorb the following runs as long as they consist of the dominating class only
        while (last + 1 < rows && IsNextSplitStillDominatingClass(last, inputValues, classes, dominatingClass)) {
          last++;
          classCount[classes[last]] += 1;
          last = SetCurRowToEndOfSplit(last, inputValues, classes, classCount, inputValues[last]);
        }
      } else {
        // the rows ran out before a class could dominate: break the tie randomly
        dominatingClass = GetRandomMaxClass(classCount, random);
      }

      correctClassified += classCount[dominatingClass];
      curRow = last + 1;
      //intervals exclude end and include start
      double upperBound = curRow < rows ? inputValues[curRow] : Double.PositiveInfinity;
      splits.Add(new KeyValuePair<double, double>(upperBound, dominatingClass));
      classCount = PrepareClassCountDictionary(classValues);
    }

    // only missing values (or no rows at all): still provide one interval covering everything
    if (splits.Count == 0) {
      splits.Add(new KeyValuePair<double, double>(Double.PositiveInfinity, GetRandomMaxClass(classCount, random)));
    }
    return correctClassified;
  }

  /// <summary>
  /// Returns the most frequent class; if several classes share the highest count,
  /// one of them is chosen with <paramref name="random"/>.
  /// </summary>
  private static double GetRandomMaxClass(Dictionary<double, int> classCount, IRandom random) {
    var possibleClasses = new List<double>();
    int max = 0;
    foreach (var item in classCount) {
      if (item.Value > max) {
        // new maximum: all candidates collected so far are obsolete
        max = item.Value;
        possibleClasses.Clear();
        possibleClasses.Add(item.Key);
      } else if (item.Value == max) {
        possibleClasses.Add(item.Key);
      }
    }
    return possibleClasses[random.Next(possibleClasses.Count)];
  }

  /// <summary>
  /// Checks whether all rows sharing the input value of row <c>curRow + 1</c> belong to
  /// <paramref name="dominatingClass"/>. Returns false if there is no next row.
  /// </summary>
  private static bool IsNextSplitStillDominatingClass(int curRow, double[] inputVariableValues, double[] classValuesInDataset, double dominatingClass) {
    int next = curRow + 1;
    if (next >= inputVariableValues.Length) {
      return false;
    }
    double nextSplit = inputVariableValues[next];
    while (next < inputVariableValues.Length && inputVariableValues[next] == nextSplit) {
      if (classValuesInDataset[next] != dominatingClass) {
        // the next split would also contain values of a class which is not dominating
        return false;
      }
      next++;
    }
    return true;
  }

  /// <summary>
  /// Advances <paramref name="curRow"/> to the last row whose input value equals
  /// <paramref name="curSplit"/> and counts the class of every row passed on the way.
  /// Needed if a variable contains the same value several times.
  /// </summary>
  private static int SetCurRowToEndOfSplit(int curRow, double[] inputVariableValues, double[] classValuesInDataset, Dictionary<double, int> classCount, double curSplit) {
    while (curRow + 1 < inputVariableValues.Length && inputVariableValues[curRow + 1] == curSplit) {
      curRow++;
      classCount[classValuesInDataset[curRow]] += 1;
    }
    return curRow;
  }

  /// <summary>
  /// Merges neighbouring intervals that predict the same class: of such a run only the
  /// last (largest) upper bound is kept. The order of the intervals is preserved.
  /// </summary>
  private static List<KeyValuePair<double, double>> MergeSplits(List<KeyValuePair<double, double>> splits) {
    var mergedSplits = new List<KeyValuePair<double, double>>();
    foreach (var split in splits) {
      int last = mergedSplits.Count - 1;
      if (last >= 0 && mergedSplits[last].Value == split.Value) {
        // same class as the previous interval: extend it to the new upper bound
        mergedSplits[last] = split;
      } else {
        mergedSplits.Add(split);
      }
    }
    return mergedSplits;
  }

  /// <summary>
  /// Determines whether exactly one class has the highest count, with a count above zero.
  /// </summary>
  /// <param name="classCount">Number of rows per class.</param>
  /// <param name="dominatingClass">Receives the dominating class, or NaN if there is none.</param>
  /// <returns>True if a dominating class exists.</returns>
  private static bool ExistsDominatingClass(Dictionary<double, int> classCount, out double dominatingClass) {
    bool dominating = false;
    int count = 0;
    dominatingClass = double.NaN;
    foreach (var item in classCount) {
      if (item.Value > count) {
        dominatingClass = item.Key;
        count = item.Value;
        dominating = true;
      } else if (item.Value == count) {
        // tie with the current maximum: nobody dominates (for now)
        dominatingClass = double.NaN;
        dominating = false;
      }
    }
    return dominating;
  }

  /// <summary>
  /// Creates a counter dictionary with one zero-initialized entry per class value.
  /// </summary>
  private static Dictionary<double, int> PrepareClassCountDictionary(double[] classValues) {
    var classCount = new Dictionary<double, int>(classValues.Length);
    foreach (double classValue in classValues) {
      classCount[classValue] = 0;
    }
    return classCount;
  }
}
}