- Timestamp:
- 12/05/12 18:54:25 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs
r9002 r9003 39 39 private int k; 40 40 [Storable] 41 private List<KeyValuePair<double, Dictionary<double, int>>> trainedTargetPair; 41 private List<double> trainedClasses; 42 [Storable] 43 private List<double> trainedEstimatedValues; 44 42 45 [Storable] 43 46 private ClassFrequencyComparer frequencyComparer; … … 48 51 : base(original, cloner) { 49 52 k = original.k; 50 trainedTargetPair = original.trainedTargetPair.Select(x => new KeyValuePair<double, Dictionary<double, int>>(x.Key, new Dictionary<double, int>(x.Value))).ToList();51 53 frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer); 54 trainedEstimatedValues = new List<double>(original.trainedEstimatedValues); 55 trainedClasses = new List<double>(original.trainedClasses); 52 56 } 53 57 public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue) 54 58 : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) { 55 59 this.k = k; 56 trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>();57 60 frequencyComparer = new ClassFrequencyComparer(); 61 58 62 } 59 63 … … 65 69 var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows) 66 70 .LimitToRange(LowerEstimationLimit, UpperEstimationLimit); 67 var neighborClasses = new Dictionary<double, int>();68 71 foreach (var ev in estimatedValues) { 72 // find the range [lower, upper[ of trainedTargetValues that contains the k closest neighbours 73 // the range can span more than k elements when there are equal estimated values 74 69 75 // find the index of the training-point to which distance is shortest 70 var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, Dictionary<double, int>>(ev, null), new KeyValuePairKeyComparer<Dictionary<double, int>>()); 71 if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item 72 var lower = upper - 1; 73 neighborClasses.Clear(); 74 // continue to the left and right of this index and look for the nearest neighbors 75 for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) { 76 bool lowerIsCloser = upper >= trainedTargetPair.Count || (lower >= 0 && ev - trainedTargetPair[lower].Key <= trainedTargetPair[upper].Key - ev); 77 bool upperIsCloser = lower < 0 || (upper < trainedTargetPair.Count && ev - trainedTargetPair[lower].Key >= trainedTargetPair[upper].Key - ev); 76 int lower = trainedEstimatedValues.BinarySearch(ev); 77 int upper; 78 // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item 79 if (lower < 0) { 80 lower = ~lower; 81 // lower is not necessarily the closer one 82 // determine which element is closer to ev (lower - 1) or (lower) 83 if (lower == trainedEstimatedValues.Count || 84 (lower > 0 && Math.Abs(ev - trainedEstimatedValues[lower - 1]) < Math.Abs(ev - trainedEstimatedValues[lower]))) { 85 lower = lower - 1; 86 } 87 } 88 upper = lower + 1; 89 // at this point we have a range [lower, upper[ that includes only the closest element to ev 90 91 // expand the range to left or right looking for the nearest neighbors 92 while (upper - lower < Math.Min(k, trainedEstimatedValues.Count)) { 93 bool lowerIsCloser = upper >= trainedEstimatedValues.Count || 94 (lower > 0 && ev - trainedEstimatedValues[lower] <= trainedEstimatedValues[upper] - ev); 95 bool upperIsCloser = lower <= 0 || 96 (upper < trainedEstimatedValues.Count && 97 ev - trainedEstimatedValues[lower] >= trainedEstimatedValues[upper] - ev); 98 if (!lowerIsCloser && !upperIsCloser) break; 78 99 if (lowerIsCloser) { 79 // the nearer neighbor is to the left80 var lowerClassSamples = trainedTargetPair[lower].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value81 var lowerClasses = trainedTargetPair[lower].Value;82 foreach (var lowerClass in lowerClasses) {83 if (!neighborClasses.ContainsKey(lowerClass.Key)) neighborClasses[lowerClass.Key] = lowerClass.Value;84 else neighborClasses[lowerClass.Key] += lowerClass.Value;85 }86 100 lower--; 87 i += (lowerClassSamples - 1); 101 // eat up all equal values 102 while (lower > 0 && trainedEstimatedValues[lower - 1].IsAlmost(trainedEstimatedValues[lower])) 103 lower--; 88 104 } 89 // they could, in very rare cases, be equally far apart90 105 if (upperIsCloser) { 91 // the nearer neighbor is to the right92 var upperClassSamples = trainedTargetPair[upper].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value93 var upperClasses = trainedTargetPair[upper].Value;94 foreach (var upperClass in upperClasses) {95 if (!neighborClasses.ContainsKey(upperClass.Key)) neighborClasses[upperClass.Key] = upperClass.Value;96 else neighborClasses[upperClass.Key] += upperClass.Value;97 }98 106 upper++; 99 i += (upperClassSamples - 1); 107 while (upper < trainedEstimatedValues.Count && 108 trainedEstimatedValues[upper - 1].IsAlmost(trainedEstimatedValues[upper])) 109 upper++; 100 110 } 101 111 } 102 112 // majority voting with preference for bigger class in case of tie 103 yield return neighborClasses.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key; 113 yield return Enumerable.Range(lower, upper - lower) 114 .Select(i => trainedClasses[i]) 115 .GroupBy(c => c) 116 .Select(g => new { Class = g.Key, Votes = g.Count() }) 117 .MaxItems(p => p.Votes) 118 .OrderByDescending(m => m.Class, frequencyComparer) 119 .First().Class; 104 120 } 105 121 } … … 109 125 .LimitToRange(LowerEstimationLimit, UpperEstimationLimit); 110 126 var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 111 var pair = estimatedValues.Zip(targetValues, (e, t) => new { Estimated = e, Target = t }); 127 var trainedClasses = targetValues.ToArray(); 128 var trainedEstimatedValues = estimatedValues.ToArray(); 112 129 113 // there could be more than one target value per estimated value 114 var dict = new Dictionary<double, Dictionary<double, int>>(); 115 var classFrequencies = new Dictionary<double, int>(); 116 foreach (var p in pair) { 117 // get target values for each estimated value (foreach estimated value there may be different target values in different amounts) 118 if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>(); 119 if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0; 120 dict[p.Estimated][p.Target]++; 121 // get class frequencies 122 if (!classFrequencies.ContainsKey(p.Target)) 123 classFrequencies[p.Target] = 1; 124 else classFrequencies[p.Target]++; 125 } 130 Array.Sort(trainedEstimatedValues, trainedClasses); 131 this.trainedClasses = new List<double>(trainedClasses); 132 this.trainedEstimatedValues = new List<double>(trainedEstimatedValues); 126 133 127 frequencyComparer = new ClassFrequencyComparer(classFrequencies); 128 129 trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>(); 130 foreach (var ev in dict) { 131 trainedTargetPair.Add(new KeyValuePair<double, Dictionary<double, int>>(ev.Key, ev.Value)); 132 } 133 trainedTargetPair = trainedTargetPair.OrderBy(x => x.Key).ToList(); 134 var freq = trainedClasses 135 .GroupBy(c => c) 136 .ToDictionary(g => g.Key, g => g.Count()); 137 this.frequencyComparer = new ClassFrequencyComparer(freq); 134 138 } 135 139 136 140 public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 137 141 return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData); 138 }139 }140 141 internal class KeyValuePairKeyComparer<T> : IComparer<KeyValuePair<double, T>> {142 public int Compare(KeyValuePair<double, T> x, KeyValuePair<double, T> y) {143 return x.Key.CompareTo(y.Key);144 142 } 145 143 }
Note: See TracChangeset
for help on using the changeset viewer.