- Timestamp:
- 11/30/12 12:13:26 (12 years ago)
- Location:
- trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/ModelCreators/NearestNeighborModelCreator.cs
r8606 r8978 48 48 public NearestNeighborModelCreator() 49 49 : base() { 50 Parameters.Add(new FixedValueParameter<IntValue>("K", "The number of neighbours to use to determine the class.", new IntValue( 3)));50 Parameters.Add(new FixedValueParameter<IntValue>("K", "The number of neighbours to use to determine the class.", new IntValue(11))); 51 51 } 52 52 -
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs
r8606 r8978 40 40 [Storable] 41 41 private List<KeyValuePair<double, double>> trainedTargetPair; 42 [Storable] 43 private ClassFrequencyComparer frequencyComparer; 42 44 43 45 [StorableConstructor] … … 47 49 k = original.k; 48 50 trainedTargetPair = new List<KeyValuePair<double, double>>(original.trainedTargetPair); 51 frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer); 49 52 } 50 53 public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue) … … 52 55 this.k = k; 53 56 this.trainedTargetPair = new List<KeyValuePair<double, double>>(); 57 frequencyComparer = new ClassFrequencyComparer(); 54 58 } 55 59 … … 58 62 } 59 63 64 [StorableHook(HookType.AfterDeserialization)] 65 private void AfterDeserialization() { 66 if (frequencyComparer == null) { 67 var dict = trainedTargetPair 68 .GroupBy(x => x.Value) 69 .ToDictionary(x => x.Key, y => y.Count()); 70 frequencyComparer = new ClassFrequencyComparer(dict); 71 } 72 } 73 60 74 public override IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 61 var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows); 62 var neighbors = new Dictionary<double, int>(); 75 var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows) 76 .LimitToRange(LowerEstimationLimit, UpperEstimationLimit); 77 var neighborClasses = new Dictionary<double, int>(); 63 78 foreach (var ev in estimatedValues) { 64 int lower = 0, upper = 1; 65 double sdist = Math.Abs(ev - trainedTargetPair[0].Key); 66 for (int i = 1; i < trainedTargetPair.Count; i++) { 67 double d = Math.Abs(ev - trainedTargetPair[i].Key); 68 if (d > sdist) break; 69 lower = i; 70 upper = i + 1; 71 sdist = d; 72 } 73 neighbors.Clear(); 74 neighbors[trainedTargetPair[lower].Value] = 1; 75 lower--; 76 for (int i = 1; i < Math.Min(k, trainedTargetPair.Count); i++) { 79 // find the index of the training-point to which distance is shortest 80 var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, double>(ev, double.NaN), new KeyValuePairKeyComparer()); 81 if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item 82 var lower = upper - 1; 83 neighborClasses.Clear(); 84 // continue to the left and right of this index and look for the nearest neighbors 85 for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) { 77 86 if (upper >= trainedTargetPair.Count || (lower > 0 && ev - trainedTargetPair[lower].Key < trainedTargetPair[upper].Key - ev)) { 78 if (!neighbors.ContainsKey(trainedTargetPair[lower].Value)) 79 neighbors[trainedTargetPair[lower].Value] = 1; 80 else neighbors[trainedTargetPair[lower].Value]++; 87 // the nearer neighbor is to the left 88 var lowerClass = trainedTargetPair[lower].Value; 89 if (!neighborClasses.ContainsKey(lowerClass)) neighborClasses[lowerClass] = 1; 90 else neighborClasses[lowerClass]++; 81 91 lower--; 82 92 } else { 83 if (!neighbors.ContainsKey(trainedTargetPair[upper].Value)) 84 neighbors[trainedTargetPair[upper].Value] = 1; 85 else neighbors[trainedTargetPair[upper].Value]++; 93 // the nearer neighbor is to the right 94 var upperClass = trainedTargetPair[upper].Value; 95 if (!neighborClasses.ContainsKey(upperClass)) neighborClasses[upperClass] = 1; 96 else neighborClasses[upperClass]++; 86 97 upper++; 87 98 } 88 99 } 89 yield return neighbors.MaxItems(x => x.Value).First().Key; 100 // majority voting with preference for bigger class in case of tie 101 yield return neighborClasses.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key; 90 102 } 91 103 } 92 104 93 105 public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) { 94 var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows); 106 var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows) 107 .LimitToRange(LowerEstimationLimit, UpperEstimationLimit); 95 108 var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 96 109 var pair = estimatedValues.Zip(targetValues, (e, t) => new { Estimated = e, Target = t }); … … 98 111 // there could be more than one target value per estimated value 99 112 var dict = new Dictionary<double, Dictionary<double, int>>(); 113 var classFrequencies = new Dictionary<double, int>(); 100 114 foreach (var p in pair) { 101 115 if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>(); 102 116 if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0; 103 117 dict[p.Estimated][p.Target]++; 118 119 if (!classFrequencies.ContainsKey(p.Target)) 120 classFrequencies[p.Target] = 1; 121 else classFrequencies[p.Target]++; 104 122 } 123 124 frequencyComparer = new ClassFrequencyComparer(classFrequencies); 105 125 106 126 trainedTargetPair = new List<KeyValuePair<double, double>>(); 107 127 foreach (var ev in dict) { 108 var target = ev.Value.MaxItems(x => x.Value). First().Key;128 var target = ev.Value.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key; 109 129 trainedTargetPair.Add(new KeyValuePair<double, double>(ev.Key, target)); 110 130 } … … 116 136 } 117 137 } 138 139 internal class KeyValuePairKeyComparer : IComparer<KeyValuePair<double, double>> { 140 public int Compare(KeyValuePair<double, double> x, KeyValuePair<double, double> y) { 141 return x.Key.CompareTo(y.Key); 142 } 143 } 144 145 [StorableClass] 146 internal class ClassFrequencyComparer : IComparer<double> { 147 [Storable] 148 private Dictionary<double, int> classFrequencies; 149 150 [StorableConstructor] 151 private ClassFrequencyComparer(bool deserializing) { } 152 public ClassFrequencyComparer() { 153 classFrequencies = new Dictionary<double, int>(); 154 } 155 public ClassFrequencyComparer(Dictionary<double, int> frequencies) { 156 classFrequencies = frequencies; 157 } 158 public ClassFrequencyComparer(ClassFrequencyComparer original) { 159 classFrequencies = new Dictionary<double, int>(original.classFrequencies); 160 } 161 162 public int Compare(double x, double y) { 163 bool cx = classFrequencies.ContainsKey(x), cy = classFrequencies.ContainsKey(y); 164 if (cx && cy) 165 return classFrequencies[x].CompareTo(classFrequencies[y]); 166 if (cx) return 1; 167 return -1; 168 } 169 } 118 170 }
Note: See TracChangeset
for help on using the changeset viewer.