#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// 1R classification algorithm.
  /// </summary>
  /// <remarks>
  /// For every allowed input variable the sorted training values are partitioned into
  /// buckets, each predicting a single class. The variable whose rule classifies the most
  /// training rows correctly (including rows with a missing input value) becomes the model.
  /// </remarks>
  [Item("OneR", "1R classification algorithm.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class OneR : FixedDataAnalysisAlgorithm<IClassificationProblem> {

    public IValueParameter<IntValue> MinBucketSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters["MinBucketSize"]; }
    }

    public IValueParameter<IRandom> RandomParameter {
      get { return (IValueParameter<IRandom>)Parameters["Random"]; }
    }

    [StorableConstructor]
    private OneR(bool deserializing) : base(deserializing) { }

    private OneR(OneR original, Cloner cloner)
      : base(original, cloner) {
    }

    public OneR()
      : base() {
      Parameters.Add(new ValueParameter<IntValue>("MinBucketSize", "Minimum size of a bucket for numerical values. (Except for the rightmost bucket)", new IntValue(6)));
      Parameters.Add(new ValueParameter<IRandom>("Random", "Random number generator", new FastRandom()));
      Problem = new ClassificationProblem();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new OneR(this, cloner);
    }

    protected override void Run() {
      // Stopwatch measures elapsed time independently of wall-clock adjustments.
      var stopwatch = Stopwatch.StartNew();
      var solution = CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value, RandomParameter.Value);
      stopwatch.Stop();
      Results.Add(new Result("OneR solution", "The 1R classifier.", solution));
      Results.Add(new Result("OneR Execution Time", "", new TimeSpanValue(stopwatch.Elapsed)));

      // NOTE(review): comparison run against OneRTest (defined elsewhere); this looks like
      // leftover test scaffolding - confirm whether it should remain in the production algorithm.
      stopwatch = Stopwatch.StartNew();
      var solution3 = OneRTest.CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value);
      stopwatch.Stop();
      Results.Add(new Result("OneR Test2 solution", "The 1R classifier.", solution3));
      Results.Add(new Result("OneR Test2 Execution", "", new TimeSpanValue(stopwatch.Elapsed)));
    }

    /// <summary>
    /// Builds a 1R classifier on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Classification problem data; its allowed input variables are the candidates.</param>
    /// <param name="minBucketSize">Minimum number of rows per bucket (the rightmost bucket may hold fewer). Must be at least 1.</param>
    /// <param name="random">Used to break ties between equally frequent classes.</param>
    /// <returns>A solution whose model uses the single best input variable.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="minBucketSize"/> is less than 1.</exception>
    /// <exception cref="ArgumentException">The problem data has no allowed input variables.</exception>
    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize, IRandom random) {
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (minBucketSize < 1) throw new ArgumentOutOfRangeException("minBucketSize", "The minimum bucket size must be at least 1.");

      Dataset dataset = problemData.Dataset;
      var trainingIndices = problemData.TrainingIndices.ToArray();
      var inputVariables = problemData.AllowedInputVariables.ToArray();
      var classValues = problemData.ClassValues.ToArray();
      // The target column is identical for every candidate variable, so fetch it once;
      // it is copied per variable because Array.Sort reorders it in place.
      var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices).ToArray();

      string bestVariable = null;
      Dictionary<double, double> bestSplits = null;
      double bestMissingValuesClass = double.NaN;
      // -1 guarantees that the first variable is always accepted, even with zero hits.
      int bestCorrectClassified = -1;

      foreach (string variable in inputVariables) {
        var inputVariableValues = dataset.GetDoubleValues(variable, trainingIndices).ToArray();
        var classValuesInDataset = (double[])targetValues.Clone();
        // Sorting by input value; Double comparison orders NaN (missing) values first.
        Array.Sort(inputVariableValues, classValuesInDataset);

        double missingValuesClass;
        int correctClassified;
        Dictionary<double, double> splits = FindSplits(inputVariableValues, classValuesInDataset, classValues,
          minBucketSize, random, out missingValuesClass, out correctClassified);

        // Strict '>' keeps the earliest variable on ties.
        if (correctClassified > bestCorrectClassified) {
          bestCorrectClassified = correctClassified;
          bestVariable = variable;
          bestSplits = splits;
          bestMissingValuesClass = missingValuesClass;
        }
      }

      if (bestVariable == null)
        throw new ArgumentException("The problem data does not contain any allowed input variables.", "problemData");

      // merge intervals to simplify symbolic expression tree
      Dictionary<double, double> mergedSplits = MergeSplits(bestSplits);
      // Dictionary enumeration order is undefined, so order the thresholds explicitly.
      var orderedSplits = mergedSplits.OrderBy(split => split.Key).ToArray();
      var model = new OneRClassificationModel(bestVariable,
        orderedSplits.Select(split => split.Key).ToArray(),
        orderedSplits.Select(split => split.Value).ToArray(),
        bestMissingValuesClass);
      var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
      return solution;
    }

    /// <summary>
    /// Computes the 1R rule for a single input variable.
    /// </summary>
    /// <param name="inputVariableValues">Input values sorted ascending with NaN (missing) values first.</param>
    /// <param name="classValuesInDataset">Target values, permuted in the same order as the inputs.</param>
    /// <param name="classValues">All possible class values.</param>
    /// <param name="minBucketSize">Minimum bucket size (at least 1).</param>
    /// <param name="random">Tie breaker for equally frequent classes.</param>
    /// <param name="missingValuesClass">Class predicted for missing input values, or NaN if there are none.</param>
    /// <param name="correctClassified">Number of training rows this rule classifies correctly.</param>
    /// <returns>
    /// Map from split value to predicted class. A split value is the exclusive upper bound of its
    /// interval (the first value of the next bucket); the rightmost split is +Infinity.
    /// </returns>
    private static Dictionary<double, double> FindSplits(double[] inputVariableValues, double[] classValuesInDataset,
      double[] classValues, int minBucketSize, IRandom random, out double missingValuesClass, out int correctClassified) {
      int rowCount = inputVariableValues.Length;
      var splits = new Dictionary<double, double>();
      Dictionary<double, int> classCount = PrepareClassCountDictionary(classValues);
      double dominatingClass;
      missingValuesClass = double.NaN;
      correctClassified = 0;

      // Missing (NaN) inputs sit at the front; they are all assigned a single class.
      int curRow = 0;
      while (curRow < rowCount && double.IsNaN(inputVariableValues[curRow])) {
        classCount[classValuesInDataset[curRow]] += 1;
        curRow++;
      }
      if (curRow > 0) {
        missingValuesClass = ExistsDominatingClass(classCount, out dominatingClass)
          ? dominatingClass
          : GetRandomMaxClass(classCount, random);
        correctClassified += classCount[missingValuesClass];
        classCount = PrepareClassCountDictionary(classValues);
      }

      // Loop invariant: at the top of each iteration curRow is the first row of a new bucket,
      // and it is also the first row of a group of identical input values.
      while (curRow < rowCount) {
        // 1. Take at least minBucketSize rows (fewer only if the data runs out).
        //    Written as rowCount - curRow to avoid overflow for huge minBucketSize values.
        int lastRow = curRow + Math.Min(minBucketSize, rowCount - curRow) - 1;
        for (int row = curRow; row <= lastRow; row++)
          classCount[classValuesInDataset[row]] += 1;
        curRow = lastRow; // from here on curRow is the last row counted into the bucket
        curRow = SetCurRowToEndOfSplit(curRow, inputVariableValues, classValuesInDataset, classCount, inputVariableValues[curRow]);

        // 2. Grow by whole value groups until a single class is strictly most frequent.
        bool hasDominatingClass = ExistsDominatingClass(classCount, out dominatingClass);
        while (!hasDominatingClass && curRow + 1 < rowCount) {
          curRow++;
          classCount[classValuesInDataset[curRow]] += 1;
          curRow = SetCurRowToEndOfSplit(curRow, inputVariableValues, classValuesInDataset, classCount, inputVariableValues[curRow]);
          hasDominatingClass = ExistsDominatingClass(classCount, out dominatingClass);
        }
        if (!hasDominatingClass) {
          // Data exhausted without a unique majority: the rightmost bucket gets a random max class.
          double randomClass = GetRandomMaxClass(classCount, random);
          splits.Add(double.PositiveInfinity, randomClass);
          correctClassified += classCount[randomClass];
          return splits;
        }

        // 3. Absorb following value groups that consist only of the dominating class.
        while (IsNextSplitStillDominatingClass(curRow, inputVariableValues, classValuesInDataset, dominatingClass)) {
          curRow++;
          classCount[classValuesInDataset[curRow]] += 1;
          curRow = SetCurRowToEndOfSplit(curRow, inputVariableValues, classValuesInDataset, classCount, inputVariableValues[curRow]);
        }
        correctClassified += classCount[dominatingClass];

        if (curRow >= rowCount - 1) {
          splits.Add(double.PositiveInfinity, dominatingClass);
          return splits;
        }

        // Close the bucket: intervals exclude their end, so the split value is the first
        // value of the next bucket, which in turn includes that value as its start.
        curRow++;
        splits.Add(inputVariableValues[curRow], dominatingClass);
        classCount = PrepareClassCountDictionary(classValues);
      }

      // Only reached when there are no non-missing input values at all.
      double fallbackClass = GetRandomMaxClass(classCount, random);
      splits.Add(double.PositiveInfinity, fallbackClass);
      correctClassified += classCount[fallbackClass];
      return splits;
    }

    /// <summary>
    /// Returns one of the classes with the highest count, chosen uniformly at random among ties.
    /// </summary>
    private static double GetRandomMaxClass(Dictionary<double, int> classCount, IRandom random) {
      IList<double> possibleClasses = new List<double>();
      int max = 0;
      foreach (var item in classCount) {
        if (max < item.Value) {
          max = item.Value;
          possibleClasses = new List<double>();
          possibleClasses.Add(item.Key);
        } else if (max == item.Value) {
          possibleClasses.Add(item.Key);
        }
      }
      int classindex = random.Next(possibleClasses.Count);
      return possibleClasses[classindex];
    }

    /// <summary>
    /// Checks whether the group of identical input values directly after <paramref name="curRow"/>
    /// consists exclusively of <paramref name="dominatingClass"/>. Returns false if there is no next row.
    /// </summary>
    private static bool IsNextSplitStillDominatingClass(int curRow, double[] inputVariableValues, double[] classValuesInDataset, double dominatingClass) {
      // Guard the curRow + 1 access below.
      if (curRow + 1 >= classValuesInDataset.Length) {
        return false;
      }
      double nextSplit = inputVariableValues[curRow + 1];
      int i = 1;
      while (curRow + i < classValuesInDataset.Length && inputVariableValues[curRow + i] == nextSplit && classValuesInDataset[curRow + i] == dominatingClass) {
        i++;
      }
      if (curRow + i >= classValuesInDataset.Length) {
        return true;
      }
      if (inputVariableValues[curRow + i] != nextSplit) {
        return true;
      }
      // the next split would also contain values of a class which
      // is not dominating (classValuesInDataset[curRow + i] != dominatingClass)
      return false;
    }

    /// <summary>
    /// Advances past all following rows whose input value equals <paramref name="curSplit"/>,
    /// counting their classes, and returns the index of the last such row.
    /// Needed if a variable contains the same value several times, so that identical values
    /// never end up in different buckets.
    /// </summary>
    private static int SetCurRowToEndOfSplit(int curRow, double[] inputVariableValues, double[] classValuesInDataset, Dictionary<double, int> classCount, double curSplit) {
      while (curRow + 1 < inputVariableValues.Length && inputVariableValues[curRow + 1] == curSplit) {
        curRow++;
        classCount[classValuesInDataset[curRow]] += 1;
      }
      return curRow;
    }

    /// <summary>
    /// Merges adjacent intervals (in ascending split order) that predict the same class, keeping the
    /// largest split value of each merged run. Returns an empty dictionary for empty input.
    /// </summary>
    private static Dictionary<double, double> MergeSplits(Dictionary<double, double> splits) {
      var mergedSplits = new Dictionary<double, double>();
      bool hasPending = false;
      double pendingSplit = double.NaN;
      double pendingClass = double.NaN;
      // Dictionary enumeration order is undefined, so sort by split value explicitly.
      foreach (var split in splits.OrderBy(s => s.Key)) {
        if (hasPending && pendingClass != split.Value) {
          mergedSplits.Add(pendingSplit, pendingClass);
        }
        // Same class: extend the pending interval; new class: start a new one.
        pendingSplit = split.Key;
        pendingClass = split.Value;
        hasPending = true;
      }
      if (hasPending) {
        mergedSplits.Add(pendingSplit, pendingClass);
      }
      return mergedSplits;
    }

    /// <summary>
    /// Returns true if exactly one class has the strictly highest count (and that count is above zero);
    /// <paramref name="dominatingClass"/> is then that class, otherwise NaN.
    /// </summary>
    private static bool ExistsDominatingClass(Dictionary<double, int> classCount, out double dominatingClass) {
      bool dominating = false;
      int count = 0;
      dominatingClass = double.NaN;

      foreach (var item in classCount) {
        if (item.Value > count) {
          dominatingClass = item.Key;
          count = item.Value;
          dominating = true;
        } else if (item.Value == count) {
          // tie with the current maximum: no single dominating class (yet)
          dominatingClass = double.NaN;
          dominating = false;
        }
      }

      return dominating;
    }

    /// <summary>
    /// Creates a counter with a zero entry for every class value.
    /// </summary>
    private static Dictionary<double, int> PrepareClassCountDictionary(double[] classValues) {
      Dictionary<double, int> classCount = new Dictionary<double, int>();
      for (int i = 0; i < classValues.Length; i++) {
        classCount[classValues[i]] = 0;
      }
      return classCount;
    }
  }
}