source: branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/OneR.cs @ 10569

Last change on this file since 10569 was 10569, checked in by mkommend, 8 years ago

#1998: Reimplemented OneR classification algorithm.

File size: 11.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// 1R classification algorithm.
  /// </summary>
  /// <remarks>
  /// For every allowed input variable a single-variable rule is learned: the training rows are
  /// sorted by the input value, partitioned into intervals of at least MinBucketSize rows
  /// (rows with equal input values are never separated) and each interval is labeled with its
  /// majority class. Rows whose input value is NaN are assigned one separate class. The variable
  /// whose rule classifies the most training rows correctly is used for the final model.
  /// </remarks>
  [Item("OneR", "1R classification algorithm.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class OneR : FixedDataAnalysisAlgorithm<IClassificationProblem> {

    /// <summary>Minimum number of rows per interval (the rightmost interval may be smaller).</summary>
    public IValueParameter<IntValue> MinBucketSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters["MinBucketSize"]; }
    }
    /// <summary>Random number generator used to break ties between equally frequent classes.</summary>
    public IValueParameter<IRandom> RandomParameter {
      get { return (IValueParameter<IRandom>)Parameters["Random"]; }
    }

    [StorableConstructor]
    private OneR(bool deserializing) : base(deserializing) { }
    private OneR(OneR original, Cloner cloner)
      : base(original, cloner) {
    }
    public OneR()
      : base() {
      Parameters.Add(new ValueParameter<IntValue>("MinBucketSize", "Minimum size of a bucket for numerical values. (Except for the rightmost bucket)", new IntValue(6)));
      Parameters.Add(new ValueParameter<IRandom>("Random", "Random number generator", new FastRandom()));
      Problem = new ClassificationProblem();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new OneR(this, cloner);
    }

    protected override void Run() {
      // Stopwatch is monotonic and high-resolution; differences of DateTime.Now are coarse
      // and jump when the system clock is adjusted.
      var stopwatch = System.Diagnostics.Stopwatch.StartNew();
      var solution = CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value, RandomParameter.Value);
      stopwatch.Stop();
      Results.Add(new Result("OneR solution", "The 1R classifier.", solution));
      Results.Add(new Result("OneR Execution Time", "", new TimeSpanValue(stopwatch.Elapsed)));

      // NOTE(review): OneRTest is defined outside this file; it appears to be an alternative
      // implementation that is run here for comparison — confirm whether it should stay.
      stopwatch.Restart();
      var solution3 = OneRTest.CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value);
      stopwatch.Stop();
      Results.Add(new Result("OneR Test2 solution", "The 1R classifier.", solution3));
      Results.Add(new Result("OneR Test2 Execution", "", new TimeSpanValue(stopwatch.Elapsed)));
    }

    /// <summary>
    /// Learns a 1R classifier from the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Provides the dataset, target variable, allowed input variables, class values and training rows.</param>
    /// <param name="minBucketSize">Minimum number of rows per interval (except the rightmost one); values below 1 are treated as 1.</param>
    /// <param name="random">Used to break ties between equally frequent classes.</param>
    /// <returns>A solution wrapping the rule of the input variable that classifies the most training rows correctly.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="problemData"/> has no allowed input variables.</exception>
    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize, IRandom random) {
      Dataset dataset = problemData.Dataset;
      var inputVariables = problemData.AllowedInputVariables.ToArray();
      if (inputVariables.Length == 0)
        throw new ArgumentException("The problem data must contain at least one allowed input variable.", "problemData");

      // materialize once: the training rows are read for every input variable
      var trainingIndices = problemData.TrainingIndices.ToArray();
      var classValues = problemData.ClassValues.ToArray();
      var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices).ToArray();

      // a bucket size below one would make the bucket loop step backwards
      minBucketSize = Math.Max(1, minBucketSize);

      string bestVariable = null;
      Dictionary<double, double> bestSplits = null;
      double bestMissingValuesClass = double.NaN;
      // below any reachable count, so the first variable is always accepted and bestSplits is never null
      int bestCorrectClassified = -1;

      foreach (var inputVariable in inputVariables) {
        var inputValues = dataset.GetDoubleValues(inputVariable, trainingIndices).ToArray();
        // Array.Sort reorders the class values together with the keys, so each variable needs its own copy
        var sortedTargetValues = (double[])targetValues.Clone();
        Array.Sort(inputValues, sortedTargetValues);

        double missingValuesClass;
        int correctClassified;
        var splits = CreateSplits(inputValues, sortedTargetValues, classValues, minBucketSize, random,
          out missingValuesClass, out correctClassified);

        if (correctClassified > bestCorrectClassified) {
          bestCorrectClassified = correctClassified;
          bestVariable = inputVariable;
          bestSplits = splits;
          bestMissingValuesClass = missingValuesClass;
        }
      }

      // merge adjacent intervals with the same class to simplify the resulting model
      double[] thresholds, classes;
      MergeSplits(bestSplits, out thresholds, out classes);

      var model = new OneRClassificationModel(bestVariable, thresholds, classes, bestMissingValuesClass);
      var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());

      return solution;
    }

    /// <summary>
    /// Builds the interval rule for a single input variable.
    /// </summary>
    /// <param name="inputValues">Input values sorted ascending; Array.Sort places NaN values first.</param>
    /// <param name="targetValues">Class values of the same rows, in the same order as <paramref name="inputValues"/>.</param>
    /// <param name="classValues">All possible class values.</param>
    /// <param name="minBucketSize">Minimum number of rows per interval (at least 1).</param>
    /// <param name="random">Used to break ties between equally frequent classes.</param>
    /// <param name="missingValuesClass">Class assigned to rows with a NaN input value, or NaN if there are none.</param>
    /// <param name="correctClassified">Number of rows (including NaN rows) this rule classifies correctly.</param>
    /// <returns>Map from exclusive upper interval bound to class label; the last bound is +Infinity.</returns>
    private static Dictionary<double, double> CreateSplits(double[] inputValues, double[] targetValues, double[] classValues,
      int minBucketSize, IRandom random, out double missingValuesClass, out int correctClassified) {
      var splits = new Dictionary<double, double>();
      Dictionary<double, int> classCount = PrepareClassCountDictionary(classValues);
      double dominatingClass;
      double curSplit;
      int curRow = 0;
      correctClassified = 0;
      missingValuesClass = double.NaN;

      // rows with a missing (NaN) input value are sorted to the front and get their own class
      if (curRow < inputValues.Length && double.IsNaN(inputValues[curRow])) {
        while (curRow < inputValues.Length && double.IsNaN(inputValues[curRow])) {
          classCount[targetValues[curRow]] += 1;
          curRow++;
        }
        if (ExistsDominatingClass(classCount, out dominatingClass)) {
          missingValuesClass = dominatingClass;
        } else {
          missingValuesClass = GetRandomMaxClass(classCount, random);
        }
        correctClassified += classCount[missingValuesClass];
        classCount = PrepareClassCountDictionary(classValues);
      }

      bool newBucket = true;
      bool done = false;
      while (curRow < inputValues.Length) {
        if (newBucket) {
          // fill the bucket with minBucketSize rows; curRow then points at the last filled row
          for (int i = 0; i < minBucketSize && curRow + i < inputValues.Length; i++) {
            classCount[targetValues[curRow + i]] += 1;
          }
          curRow += minBucketSize - 1;
          if (curRow >= inputValues.Length) {
            // fewer than minBucketSize rows were left; they form the rightmost bucket
            break;
          }
          // rows with the same input value must not be split across buckets
          curSplit = inputValues[curRow];
          curRow = SetCurRowToEndOfSplit(curRow, inputValues, targetValues, classCount, curSplit);
          newBucket = false;
        }

        if (ExistsDominatingClass(classCount, out dominatingClass)) {
          // extend the bucket while the next distinct value only contains the dominating class
          while (curRow + 1 < inputValues.Length &&
                 IsNextSplitStillDominatingClass(curRow, inputValues, targetValues, dominatingClass)) {
            curRow++;
            curSplit = inputValues[curRow];
            classCount[targetValues[curRow]] += 1;
            curRow = SetCurRowToEndOfSplit(curRow, inputValues, targetValues, classCount, curSplit);
          }

          correctClassified += classCount[dominatingClass];

          if (curRow >= inputValues.Length - 1) {
            // this bucket reaches the last row
            splits.Add(double.PositiveInfinity, dominatingClass);
            done = true;
            break;
          }

          // the next row starts a new bucket; its value is the exclusive upper bound of this one
          curRow++;
          splits.Add(inputValues[curRow], dominatingClass);
          classCount = PrepareClassCountDictionary(classValues);
          newBucket = true;
        } else {
          // no single majority yet: add the next distinct input value to the bucket
          if (curRow + 1 >= inputValues.Length) {
            break;
          }
          curRow++;
          curSplit = inputValues[curRow];
          classCount[targetValues[curRow]] += 1;
          curRow = SetCurRowToEndOfSplit(curRow, inputValues, targetValues, classCount, curSplit);
        }
      }

      if (!done) {
        // the rightmost bucket has no unique majority (or is too small); break ties randomly
        double lastClass = GetRandomMaxClass(classCount, random);
        splits.Add(double.PositiveInfinity, lastClass);
        correctClassified += classCount[lastClass];
      }

      return splits;
    }

    /// <summary>
    /// Returns one of the classes with the highest count; ties are broken by
    /// <paramref name="random"/>. A random number is drawn even if the maximum is unique.
    /// </summary>
    private static double GetRandomMaxClass(Dictionary<double, int> classCount, IRandom random) {
      IList<double> possibleClasses = new List<double>();
      int max = 0;
      foreach (var item in classCount) {
        if (max < item.Value) {
          max = item.Value;
          possibleClasses = new List<double>();
          possibleClasses.Add(item.Key);
        } else if (max == item.Value) {
          possibleClasses.Add(item.Key);
        }
      }
      int classindex = random.Next(possibleClasses.Count);
      return possibleClasses[classindex];
    }

    /// <summary>
    /// Checks whether every row carrying the next distinct input value (the value at
    /// <paramref name="curRow"/> + 1) belongs to <paramref name="dominatingClass"/>.
    /// </summary>
    private static bool IsNextSplitStillDominatingClass(int curRow, double[] inputValues, double[] targetValues, double dominatingClass) {
      if (curRow + 1 >= inputValues.Length) {
        return false;
      }
      double nextSplit = inputValues[curRow + 1];
      for (int row = curRow + 1; row < inputValues.Length && inputValues[row] == nextSplit; row++) {
        if (targetValues[row] != dominatingClass) {
          // the next split also contains rows of a non-dominating class
          return false;
        }
      }
      return true;
    }

    /// <summary>
    /// Advances <paramref name="curRow"/> over all following rows whose input value equals
    /// <paramref name="curSplit"/>, adding their classes to <paramref name="classCount"/>.
    /// Needed because a variable may contain the same value several times.
    /// </summary>
    /// <returns>The index of the last row with input value <paramref name="curSplit"/>.</returns>
    private static int SetCurRowToEndOfSplit(int curRow, double[] inputValues, double[] targetValues, Dictionary<double, int> classCount, double curSplit) {
      while (curRow + 1 < inputValues.Length && inputValues[curRow + 1] == curSplit) {
        curRow++;
        classCount[targetValues[curRow]] += 1;
      }
      return curRow;
    }

    /// <summary>
    /// Merges adjacent intervals that predict the same class. Because thresholds are exclusive
    /// upper bounds, a merged interval keeps the upper bound of its last part.
    /// </summary>
    /// <param name="splits">Map from exclusive upper bound to class label.</param>
    /// <param name="thresholds">Merged upper bounds in ascending order.</param>
    /// <param name="classes">Class label of each merged interval, aligned with <paramref name="thresholds"/>.</param>
    private static void MergeSplits(Dictionary<double, double> splits, out double[] thresholds, out double[] classes) {
      var mergedThresholds = new List<double>();
      var mergedClasses = new List<double>();
      // Dictionary enumeration order is unspecified, so order by threshold explicitly
      foreach (var split in splits.OrderBy(s => s.Key)) {
        int last = mergedClasses.Count - 1;
        if (last >= 0 && mergedClasses[last] == split.Value) {
          mergedThresholds[last] = split.Key;
        } else {
          mergedThresholds.Add(split.Key);
          mergedClasses.Add(split.Value);
        }
      }
      thresholds = mergedThresholds.ToArray();
      classes = mergedClasses.ToArray();
    }

    /// <summary>
    /// Determines whether exactly one class has the highest count.
    /// </summary>
    /// <param name="dominatingClass">That class, or NaN if the highest count is shared (or all counts are zero).</param>
    private static bool ExistsDominatingClass(Dictionary<double, int> classCount, out double dominatingClass) {
      bool dominating = false;
      int count = 0;
      dominatingClass = double.NaN;
      foreach (var item in classCount) {
        if (item.Value > count) {
          dominatingClass = item.Key;
          count = item.Value;
          dominating = true;
        } else if (item.Value == count) {
          // a tie with the current maximum means no single class dominates (unless a larger count follows)
          dominatingClass = double.NaN;
          dominating = false;
        }
      }
      return dominating;
    }

    /// <summary>Creates a count dictionary with a zero entry for every class value.</summary>
    private static Dictionary<double, int> PrepareClassCountDictionary(double[] classValues) {
      Dictionary<double, int> classCount = new Dictionary<double, int>();
      for (int i = 0; i < classValues.Length; i++) {
        classCount[classValues[i]] = 0;
      }
      return classCount;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.