Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/OneR.cs @ 9674

Last change on this file since 9674 was 9135, checked in by sforsten, 12 years ago

#1998:

  • OneR handles missing values separately
  • adapted OneRClassificationModelView to show the class of missing values
  • with a double-click on the row header in ClassificationSolutionComparisonView the selected solution opens in a new view
  • put a try-catch block around the linear discriminant analysis solution (it is only shown if it doesn't throw an exception)
File size: 11.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// 1R classification algorithm.
  /// </summary>
  /// <remarks>
  /// For every allowed input variable the training rows are sorted by that variable and split into
  /// intervals, each labelled with the class that occurs most often in it. Rows whose input value is
  /// missing (NaN) are assigned a separate class. The variable whose rules classify the most training
  /// rows correctly is used for the final model.
  /// </remarks>
  [Item("OneR", "1R classification algorithm.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class OneR : FixedDataAnalysisAlgorithm<IClassificationProblem> {

    /// <summary>Minimum number of training rows per interval (except for the rightmost interval).</summary>
    public IValueParameter<IntValue> MinBucketSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters["MinBucketSize"]; }
    }
    /// <summary>Random number generator used to break ties between equally frequent classes.</summary>
    public IValueParameter<IRandom> RandomParameter {
      get { return (IValueParameter<IRandom>)Parameters["Random"]; }
    }

    [StorableConstructor]
    private OneR(bool deserializing) : base(deserializing) { }
    private OneR(OneR original, Cloner cloner)
      : base(original, cloner) {
    }
    public OneR()
      : base() {
      Parameters.Add(new ValueParameter<IntValue>("MinBucketSize", "Minimum size of a bucket for numerical values. (Except for the rightmost bucket)", new IntValue(6)));
      Parameters.Add(new ValueParameter<IRandom>("Random", "Random number generator", new FastRandom()));
      Problem = new ClassificationProblem();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new OneR(this, cloner);
    }

    /// <summary>
    /// Builds a 1R solution for the current problem and adds it to the results.
    /// </summary>
    protected override void Run() {
      var solution = CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value, RandomParameter.Value);
      Results.Add(new Result("OneR solution", "The 1R classifier.", solution));
    }

    /// <summary>
    /// Builds a 1R classification solution from the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Provides the dataset, target variable, class values and allowed input variables.</param>
    /// <param name="minBucketSize">Minimum number of rows per interval; values below 1 are treated as 1.</param>
    /// <param name="random">Used to break ties between equally frequent classes.</param>
    /// <returns>A solution wrapping the rules of the input variable that classifies the most training rows correctly.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="problemData"/> has no allowed input variables.</exception>
    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize, IRandom random) {
      Dataset dataset = problemData.Dataset;
      var trainingIndices = problemData.TrainingIndices.ToArray();
      var classValues = problemData.ClassValues.ToArray();
      // The target column does not depend on the input variable, so it is read only once.
      var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices).ToArray();

      string bestVariable = null;
      Dictionary<double, double> bestSplits = null;
      double bestMissingValuesClass = double.NaN;
      // -1 so that the first variable is accepted even if it classifies no row correctly (e.g. empty training set)
      int bestCorrectClassified = -1;

      foreach (string variable in problemData.AllowedInputVariables) {
        var inputVariableValues = dataset.GetDoubleValues(variable, trainingIndices).ToArray();
        // SplitVariable sorts both arrays in place, so every variable needs its own copy of the targets
        var classValuesInDataset = (double[])targetValues.Clone();

        Dictionary<double, double> splits;
        double missingValuesClass;
        int correctClassified = SplitVariable(inputVariableValues, classValuesInDataset, classValues, minBucketSize, random,
          out splits, out missingValuesClass);

        // strict comparison: on ties the variable seen first wins
        if (correctClassified > bestCorrectClassified) {
          bestCorrectClassified = correctClassified;
          bestVariable = variable;
          bestSplits = splits;
          bestMissingValuesClass = missingValuesClass;
        }
      }

      if (bestSplits == null) {
        throw new ArgumentException("OneR requires at least one allowed input variable.", "problemData");
      }

      //merge intervals to simplify symbolic expression tree
      List<KeyValuePair<double, double>> mergedSplits = MergeSplits(bestSplits);

      var model = new OneRClassificationModel(bestVariable,
        mergedSplits.Select(s => s.Key).ToArray(),
        mergedSplits.Select(s => s.Value).ToArray(),
        bestMissingValuesClass);
      return new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    }

    /// <summary>
    /// Discretizes one input variable into intervals and labels every interval with a class.
    /// </summary>
    /// <param name="inputVariableValues">Training values of the input variable; sorted in place.</param>
    /// <param name="classValuesInDataset">Target values of the same rows; reordered in place together with the inputs.</param>
    /// <param name="classValues">All possible class values.</param>
    /// <param name="minBucketSize">Minimum number of rows per interval; values below 1 are treated as 1.</param>
    /// <param name="random">Used to break ties between equally frequent classes.</param>
    /// <param name="splits">Maps the exclusive upper bound of each interval to its class; the last bound is +Infinity.</param>
    /// <param name="missingValuesClass">Class for rows with a missing (NaN) input value, or NaN if there were none.</param>
    /// <returns>Number of training rows classified correctly by <paramref name="splits"/> and <paramref name="missingValuesClass"/>.</returns>
    private static int SplitVariable(double[] inputVariableValues, double[] classValuesInDataset, double[] classValues, int minBucketSize, IRandom random,
      out Dictionary<double, double> splits, out double missingValuesClass) {
      // Sort rows by input value. Double.CompareTo orders NaN before every other value,
      // so all rows with a missing input value end up at the front.
      Array.Sort(inputVariableValues, classValuesInDataset);

      int rowCount = inputVariableValues.Length;
      // a non-positive bucket size would move curRow backwards
      int bucketSize = Math.Max(1, minBucketSize);
      int correctClassified = 0;
      int curRow = 0;
      double dominatingClass;
      Dictionary<double, int> classCount = PrepareClassCountDictionary(classValues);

      splits = new Dictionary<double, double>();
      missingValuesClass = double.NaN;

      // rows with a missing input value are handled separately and get their own class
      if (curRow < rowCount && Double.IsNaN(inputVariableValues[curRow])) {
        while (curRow < rowCount && Double.IsNaN(inputVariableValues[curRow])) {
          classCount[classValuesInDataset[curRow]] += 1;
          curRow++;
        }
        missingValuesClass = ExistsDominatingClass(classCount, out dominatingClass)
          ? dominatingClass
          : GetRandomMaxClass(classCount, random);
        correctClassified += classCount[missingValuesClass];
        classCount = PrepareClassCountDictionary(classValues);
      }

      // While newBucket is set, curRow is the first (not yet counted) row of the next interval;
      // otherwise curRow is the last row already counted for the current interval.
      bool newBucket = true;
      bool done = false;
      while (curRow < rowCount) {
        if (newBucket) {
          // start the interval with bucketSize rows (fewer if the data runs out)
          int bucketEnd = Math.Min(curRow + bucketSize, rowCount);
          for (int row = curRow; row < bucketEnd; row++) {
            classCount[classValuesInDataset[row]] += 1;
          }
          curRow = bucketEnd - 1;
          if (curRow >= rowCount - 1) {
            // all remaining rows are in this bucket; it becomes the rightmost interval
            break;
          }
          // never split between equal input values
          curRow = SetCurRowToEndOfSplit(curRow, inputVariableValues, classValuesInDataset, classCount, inputVariableValues[curRow]);
          newBucket = false;
        }

        if (ExistsDominatingClass(classCount, out dominatingClass)) {
          // grow the interval while the following value groups consist only of the dominating class
          while (curRow + 1 < rowCount &&
            IsNextSplitStillDominatingClass(curRow, inputVariableValues, classValuesInDataset, dominatingClass)) {
            curRow++;
            classCount[classValuesInDataset[curRow]] += 1;
            curRow = SetCurRowToEndOfSplit(curRow, inputVariableValues, classValuesInDataset, classCount, inputVariableValues[curRow]);
          }

          correctClassified += classCount[dominatingClass];
          done = curRow >= rowCount - 1;
          if (done) {
            splits.Add(Double.PositiveInfinity, dominatingClass);
            break;
          }

          // intervals exclude their end and include their start: the first value of the
          // next interval is the exclusive upper bound of the current one
          curRow++;
          splits.Add(inputVariableValues[curRow], dominatingClass);
          classCount = PrepareClassCountDictionary(classValues);
          newBucket = true;
        } else {
          // no unique majority yet: add the next group of equal input values to the interval
          if (curRow + 1 >= rowCount) {
            break;
          }
          curRow++;
          classCount[classValuesInDataset[curRow]] += 1;
          curRow = SetCurRowToEndOfSplit(curRow, inputVariableValues, classValuesInDataset, classCount, inputVariableValues[curRow]);
        }
      }

      if (!done) {
        // rightmost interval without a unique majority (or without rows): pick one of the most frequent classes
        double randomClass = GetRandomMaxClass(classCount, random);
        splits.Add(Double.PositiveInfinity, randomClass);
        correctClassified += classCount[randomClass];
      }
      return correctClassified;
    }

    /// <summary>
    /// Returns one of the classes with the highest count, chosen uniformly at random among ties.
    /// If all counts are zero, every class is a candidate.
    /// </summary>
    private static double GetRandomMaxClass(Dictionary<double, int> classCount, IRandom random) {
      IList<double> possibleClasses = new List<double>();
      int max = 0;
      foreach (var item in classCount) {
        if (max < item.Value) {
          max = item.Value;
          possibleClasses = new List<double>();
          possibleClasses.Add(item.Key);
        } else if (max == item.Value) {
          possibleClasses.Add(item.Key);
        }
      }
      int classindex = random.Next(possibleClasses.Count);
      return possibleClasses[classindex];
    }

    /// <summary>
    /// Returns true if every row of the group of equal input values that directly follows
    /// <paramref name="curRow"/> has the class <paramref name="dominatingClass"/>.
    /// </summary>
    private static bool IsNextSplitStillDominatingClass(int curRow, double[] inputVariableValues, double[] classValuesInDataset, double dominatingClass) {
      // the method inspects curRow + 1, so that row must exist
      if (curRow + 1 >= classValuesInDataset.Length) {
        return false;
      }
      double nextSplit = inputVariableValues[curRow + 1];
      for (int row = curRow + 1; row < classValuesInDataset.Length && inputVariableValues[row] == nextSplit; row++) {
        if (classValuesInDataset[row] != dominatingClass) {
          // the next split would also contain values of a class which is not dominating
          return false;
        }
      }
      return true;
    }

    /// <summary>
    /// Advances <paramref name="curRow"/> over all following rows whose input value equals
    /// <paramref name="curSplit"/>, counting their classes, so a split never separates equal values.
    /// </summary>
    /// <returns>The last row whose input value equals <paramref name="curSplit"/>.</returns>
    private static int SetCurRowToEndOfSplit(int curRow, double[] inputVariableValues, double[] classValuesInDataset, Dictionary<double, int> classCount, double curSplit) {
      while (curRow + 1 < inputVariableValues.Length && inputVariableValues[curRow + 1] == curSplit) {
        curRow++;
        classCount[classValuesInDataset[curRow]] += 1;
      }
      return curRow;
    }

    /// <summary>
    /// Merges adjacent intervals that predict the same class, keeping the upper bound of the last one.
    /// The bounds are sorted explicitly because dictionary enumeration order is unspecified.
    /// </summary>
    /// <returns>(exclusive upper bound, class) pairs in ascending order of the bound.</returns>
    private static List<KeyValuePair<double, double>> MergeSplits(Dictionary<double, double> splits) {
      var merged = new List<KeyValuePair<double, double>>();
      foreach (var split in splits.OrderBy(s => s.Key)) {
        int last = merged.Count - 1;
        if (last >= 0 && merged[last].Value == split.Value) {
          // same class as the previous interval: extend that interval up to this bound
          merged[last] = new KeyValuePair<double, double>(split.Key, split.Value);
        } else {
          merged.Add(split);
        }
      }
      return merged;
    }

    /// <summary>
    /// Determines whether exactly one class has the strictly highest count.
    /// </summary>
    /// <param name="dominatingClass">That class if it exists, otherwise NaN.</param>
    private static bool ExistsDominatingClass(Dictionary<double, int> classCount, out double dominatingClass) {
      bool dominating = false;
      int count = 0;
      dominatingClass = double.NaN;
      foreach (var item in classCount) {
        if (item.Value > count) {
          dominatingClass = item.Key;
          count = item.Value;
          dominating = true;
        } else if (item.Value == count) {
          // a tie at the current maximum: no single dominating class (unless a larger count follows)
          dominatingClass = double.NaN;
          dominating = false;
        }
      }
      return dominating;
    }

    /// <summary>
    /// Creates a count dictionary with an entry of zero for every class value.
    /// </summary>
    private static Dictionary<double, int> PrepareClassCountDictionary(double[] classValues) {
      Dictionary<double, int> classCount = new Dictionary<double, int>();
      for (int i = 0; i < classValues.Length; i++) {
        classCount[classValues[i]] = 0;
      }
      return classCount;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.