Changeset 9002
- Timestamp:
- 12/05/12 17:03:53 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs
r8979 r9002 39 39 private int k; 40 40 [Storable] 41 private List<KeyValuePair<double, double>> trainedTargetPair;41 private List<KeyValuePair<double, Dictionary<double, int>>> trainedTargetPair; 42 42 [Storable] 43 43 private ClassFrequencyComparer frequencyComparer; … … 48 48 : base(original, cloner) { 49 49 k = original.k; 50 trainedTargetPair = new List<KeyValuePair<double, double>>(original.trainedTargetPair);50 trainedTargetPair = original.trainedTargetPair.Select(x => new KeyValuePair<double, Dictionary<double, int>>(x.Key, new Dictionary<double, int>(x.Value))).ToList(); 51 51 frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer); 52 52 } … … 54 54 : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) { 55 55 this.k = k; 56 t his.trainedTargetPair = new List<KeyValuePair<double, double>>();56 trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>(); 57 57 frequencyComparer = new ClassFrequencyComparer(); 58 58 } … … 60 60 public override IDeepCloneable Clone(Cloner cloner) { 61 61 return new SymbolicNearestNeighbourClassificationModel(this, cloner); 62 }63 64 [StorableHook(HookType.AfterDeserialization)]65 private void AfterDeserialization() {66 if (frequencyComparer == null) {67 var dict = trainedTargetPair68 .GroupBy(x => x.Value)69 .ToDictionary(x => x.Key, y => y.Count());70 frequencyComparer = new ClassFrequencyComparer(dict);71 }72 62 } 73 63 … … 78 68 foreach (var ev in estimatedValues) { 79 69 // find the index of the training-point to which distance is shortest 80 var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, double>(ev, double.NaN), new KeyValuePairKeyComparer());70 var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, Dictionary<double, int>>(ev, null), new KeyValuePairKeyComparer<Dictionary<double, int>>()); 81 71 if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item 82 72 var lower = upper - 1; … … 84 74 // continue to the left and right of this index and look for the nearest neighbors 85 75 for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) { 86 if (upper >= trainedTargetPair.Count || (lower > 0 && ev - trainedTargetPair[lower].Key < trainedTargetPair[upper].Key - ev)) { 76 bool lowerIsCloser = upper >= trainedTargetPair.Count || (lower >= 0 && ev - trainedTargetPair[lower].Key <= trainedTargetPair[upper].Key - ev); 77 bool upperIsCloser = lower < 0 || (upper < trainedTargetPair.Count && ev - trainedTargetPair[lower].Key >= trainedTargetPair[upper].Key - ev); 78 if (lowerIsCloser) { 87 79 // the nearer neighbor is to the left 88 var lowerClass = trainedTargetPair[lower].Value; 89 if (!neighborClasses.ContainsKey(lowerClass)) neighborClasses[lowerClass] = 1; 90 else neighborClasses[lowerClass]++; 80 var lowerClassSamples = trainedTargetPair[lower].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value 81 var lowerClasses = trainedTargetPair[lower].Value; 82 foreach (var lowerClass in lowerClasses) { 83 if (!neighborClasses.ContainsKey(lowerClass.Key)) neighborClasses[lowerClass.Key] = lowerClass.Value; 84 else neighborClasses[lowerClass.Key] += lowerClass.Value; 85 } 91 86 lower--; 92 } else { 87 i += (lowerClassSamples - 1); 88 } 89 // they could, in very rare cases, be equally far apart 90 if (upperIsCloser) { 93 91 // the nearer neighbor is to the right 94 var upperClass = trainedTargetPair[upper].Value; 95 if (!neighborClasses.ContainsKey(upperClass)) neighborClasses[upperClass] = 1; 96 else neighborClasses[upperClass]++; 92 var upperClassSamples = trainedTargetPair[upper].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value 93 var upperClasses = trainedTargetPair[upper].Value; 94 foreach (var upperClass in upperClasses) { 95 if (!neighborClasses.ContainsKey(upperClass.Key)) neighborClasses[upperClass.Key] = upperClass.Value; 96 else neighborClasses[upperClass.Key] += upperClass.Value; 97 } 97 98 upper++; 99 i += (upperClassSamples - 1); 98 100 } 99 101 } … … 113 115 var classFrequencies = new Dictionary<double, int>(); 114 116 foreach (var p in pair) { 117 // get target values for each estimated value (foreach estimated value there may be different target values in different amounts) 115 118 if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>(); 116 119 if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0; 117 120 dict[p.Estimated][p.Target]++; 118 121 // get class frequencies 119 122 if (!classFrequencies.ContainsKey(p.Target)) 120 123 classFrequencies[p.Target] = 1; … … 124 127 frequencyComparer = new ClassFrequencyComparer(classFrequencies); 125 128 126 trainedTargetPair = new List<KeyValuePair<double, double>>();129 trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>(); 127 130 foreach (var ev in dict) { 128 var target = ev.Value.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key; 129 trainedTargetPair.Add(new KeyValuePair<double, double>(ev.Key, target)); 131 trainedTargetPair.Add(new KeyValuePair<double, Dictionary<double, int>>(ev.Key, ev.Value)); 130 132 } 131 133 trainedTargetPair = trainedTargetPair.OrderBy(x => x.Key).ToList(); … … 133 135 134 136 public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 135 return new SymbolicClassificationSolution((ISymbolicClassificationModel) this.Clone(), problemData);137 return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData); 136 138 } 137 139 } 138 140 139 internal class KeyValuePairKeyComparer : IComparer<KeyValuePair<double, double>> {140 public int Compare(KeyValuePair<double, double> x, KeyValuePair<double, double> y) {141 internal class KeyValuePairKeyComparer<T> : IComparer<KeyValuePair<double, T>> { 142 public int Compare(KeyValuePair<double, T> x, KeyValuePair<double, T> y) { 141 143 return x.Key.CompareTo(y.Key); 142 144 } … … 146 148 internal sealed class ClassFrequencyComparer : IComparer<double> { 147 149 [Storable] 148 private Dictionary<double, int> classFrequencies;150 private readonly Dictionary<double, int> classFrequencies; 149 151 150 152 [StorableConstructor]
Note: See TracChangeset
for help on using the changeset viewer.