Changeset 17931 for trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs
- Timestamp:
- 04/09/21 19:41:33 (4 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs
r17180 r17931 32 32 /// Represents a nearest neighbour model for regression and classification 33 33 /// </summary> 34 [StorableType(" A76C0823-3077-4ACE-8A40-E9B717C7DB60")]34 [StorableType("04A07DF6-6EB5-4D29-B7AE-5BE204CAF6BC")] 35 35 [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")] 36 36 public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel { 37 37 38 private readonly object kdTreeLockObject = new object(); 39 40 private alglib.nearestneighbor.kdtree kdTree; 41 public alglib.nearestneighbor.kdtree KDTree { 42 get { return kdTree; } 43 set { 44 if (value != kdTree) { 45 if (value == null) throw new ArgumentNullException(); 46 kdTree = value; 47 OnChanged(EventArgs.Empty); 48 } 49 } 38 private alglib.knnmodel model; 39 [Storable] 40 private string SerializedModel { 41 get { alglib.knnserialize(model, out var ser); return ser; } 42 set { if (value != null) alglib.knnunserialize(value, out model); } 50 43 } 51 44 … … 60 53 [Storable] 61 54 private int k; 62 [Storable(DefaultValue = false)] 63 private bool selfMatch; 64 [Storable(DefaultValue = null)] 65 private double[] weights; // not set for old versions loaded from disk 66 [Storable(DefaultValue = null)] 67 private double[] offsets; // not set for old versions loaded from disk 55 [Storable] 56 private double[] weights; 57 [Storable] 58 private double[] offsets; 68 59 69 60 [StorableConstructor] 70 private NearestNeighbourModel(StorableConstructorFlag _) : base(_) { 71 kdTree = new alglib.nearestneighbor.kdtree(); 72 } 61 private NearestNeighbourModel(StorableConstructorFlag _) : base(_) { } 73 62 private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner) 74 63 : base(original, cloner) { 75 kdTree = new alglib.nearestneighbor.kdtree(); 76 kdTree.approxf = original.kdTree.approxf; 77 kdTree.boxmax = (double[])original.kdTree.boxmax.Clone(); 78 kdTree.boxmin = (double[])original.kdTree.boxmin.Clone(); 79 
kdTree.buf = (double[])original.kdTree.buf.Clone(); 80 kdTree.curboxmax = (double[])original.kdTree.curboxmax.Clone(); 81 kdTree.curboxmin = (double[])original.kdTree.curboxmin.Clone(); 82 kdTree.curdist = original.kdTree.curdist; 83 kdTree.debugcounter = original.kdTree.debugcounter; 84 kdTree.idx = (int[])original.kdTree.idx.Clone(); 85 kdTree.kcur = original.kdTree.kcur; 86 kdTree.kneeded = original.kdTree.kneeded; 87 kdTree.n = original.kdTree.n; 88 kdTree.nodes = (int[])original.kdTree.nodes.Clone(); 89 kdTree.normtype = original.kdTree.normtype; 90 kdTree.nx = original.kdTree.nx; 91 kdTree.ny = original.kdTree.ny; 92 kdTree.r = (double[])original.kdTree.r.Clone(); 93 kdTree.rneeded = original.kdTree.rneeded; 94 kdTree.selfmatch = original.kdTree.selfmatch; 95 kdTree.splits = (double[])original.kdTree.splits.Clone(); 96 kdTree.tags = (int[])original.kdTree.tags.Clone(); 97 kdTree.x = (double[])original.kdTree.x.Clone(); 98 kdTree.xy = (double[,])original.kdTree.xy.Clone(); 99 selfMatch = original.selfMatch; 64 if (original.model != null) 65 model = (alglib.knnmodel)original.model.make_copy(); 100 66 k = original.k; 101 isCompatibilityLoaded = original.IsCompatibilityLoaded; 102 if (!IsCompatibilityLoaded) { 103 weights = new double[original.weights.Length]; 104 Array.Copy(original.weights, weights, weights.Length); 105 offsets = new double[original.offsets.Length]; 106 Array.Copy(original.offsets, this.offsets, this.offsets.Length); 107 } 67 weights = new double[original.weights.Length]; 68 Array.Copy(original.weights, weights, weights.Length); 69 offsets = new double[original.offsets.Length]; 70 Array.Copy(original.offsets, this.offsets, this.offsets.Length); 71 108 72 allowedInputVariables = (string[])original.allowedInputVariables.Clone(); 109 73 if (original.classValues != null) 110 74 this.classValues = (double[])original.classValues.Clone(); 111 75 } 112 public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, bool selfMatch,string 
targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)76 public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null) 113 77 : base(targetVariable) { 114 78 Name = ItemName; 115 79 Description = ItemDescription; 116 this.selfMatch = selfMatch;117 80 this.k = k; 118 81 this.allowedInputVariables = allowedInputVariables.ToArray(); 119 82 double[,] inputMatrix; 120 if (IsCompatibilityLoaded) { 121 // no scaling 122 inputMatrix = dataset.ToArray( 123 this.allowedInputVariables.Concat(new string[] { targetVariable }), 124 rows); 83 this.offsets = this.allowedInputVariables 84 .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1) 85 .Concat(new double[] { 0 }) // no offset for target variable 86 .ToArray(); 87 if (weights == null) { 88 // automatic determination of weights (all features should have variance = 1) 89 this.weights = this.allowedInputVariables 90 .Select(name => { 91 var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop(); 92 return pop.IsAlmost(0) ? 1.0 : 1.0 / pop; 93 }) 94 .Concat(new double[] { 1.0 }) // no scaling for target variable 95 .ToArray(); 125 96 } else { 126 this.offsets = this.allowedInputVariables 127 .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1) 128 .Concat(new double[] { 0 }) // no offset for target variable 129 .ToArray(); 130 if (weights == null) { 131 // automatic determination of weights (all features should have variance = 1) 132 this.weights = this.allowedInputVariables 133 .Select(name => { 134 var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop(); 135 return pop.IsAlmost(0) ? 
1.0 : 1.0 / pop; 136 }) 137 .Concat(new double[] { 1.0 }) // no scaling for target variable 138 .ToArray(); 139 } else { 140 // user specified weights (+ 1 for target) 141 this.weights = weights.Concat(new double[] { 1.0 }).ToArray(); 142 if (this.weights.Length - 1 != this.allowedInputVariables.Length) 143 throw new ArgumentException("The number of elements in the weight vector must match the number of input variables"); 144 } 145 inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights); 146 } 97 // user specified weights (+ 1 for target) 98 this.weights = weights.Concat(new double[] { 1.0 }).ToArray(); 99 if (this.weights.Length - 1 != this.allowedInputVariables.Length) 100 throw new ArgumentException("The number of elements in the weight vector must match the number of input variables"); 101 } 102 inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights); 147 103 148 104 if (inputMatrix.ContainsNanOrInfinity()) 149 105 throw new NotSupportedException( 150 106 "Nearest neighbour model does not support NaN or infinity values in the input dataset."); 151 152 this.kdTree = new alglib.nearestneighbor.kdtree();153 107 154 108 var nRows = inputMatrix.GetLength(0); … … 167 121 } 168 122 } 169 alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree); 123 124 alglib.knnbuildercreate(out var knnbuilder); 125 if (classValues == null) { 126 alglib.knnbuildersetdatasetreg(knnbuilder, inputMatrix, nRows, nFeatures, nout: 1); 127 } else { 128 alglib.knnbuildersetdatasetcls(knnbuilder, inputMatrix, nRows, nFeatures, classValues.Length); 129 } 130 alglib.knnbuilderbuildknnmodel(knnbuilder, k, eps: 0.0, out model, out var report); // eps=0 (exact k-nn search is performed) 131 170 132 } 171 133 … … 184 146 public IEnumerable<double> GetEstimatedValues(IDataset dataset, 
IEnumerable<int> rows) { 185 147 double[,] inputData; 186 if (IsCompatibilityLoaded) { 187 inputData = dataset.ToArray(allowedInputVariables, rows); 188 } else { 189 inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights); 190 } 148 inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights); 191 149 192 150 int n = inputData.GetLength(0); 193 151 int columns = inputData.GetLength(1); 194 152 double[] x = new double[columns]; 195 double[] dists = new double[k]; 196 double[,] neighbours = new double[k, columns + 1];197 153 154 alglib.knncreatebuffer(model, out var buf); 155 var y = new double[1]; 198 156 for (int row = 0; row < n; row++) { 199 157 for (int column = 0; column < columns; column++) { 200 158 x[column] = inputData[row, column]; 201 159 } 202 int numNeighbours; 203 lock (kdTreeLockObject) { // gkronber: the following calls change the kdTree data structure 204 numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch); 205 alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists); 206 alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours); 207 } 208 if (selfMatch) { 209 // weights for neighbours are 1/d. 210 // override distances (=0) of exact matches using 1% of the distance of the next closest non-self-match neighbour -> selfmatches weight 100x more than the next closest neighbor. 211 // if all k neighbours are selfmatches then they all have weight 0.01. 
212 double minDist = dists[0] + 1; 213 for (int i = 0; i < numNeighbours; i++) { 214 if ((minDist > dists[i]) && (dists[i] != 0)) { 215 minDist = dists[i]; 216 } 217 } 218 minDist /= 100.0; 219 for (int i = 0; i < numNeighbours; i++) { 220 if (dists[i] == 0) { 221 dists[i] = minDist; 222 } 223 } 224 } 225 double distanceWeightedValue = 0.0; 226 double distsSum = 0.0; 227 for (int i = 0; i < numNeighbours; i++) { 228 distanceWeightedValue += neighbours[i, columns] / dists[i]; 229 distsSum += 1.0 / dists[i]; 230 } 231 yield return distanceWeightedValue / distsSum; 160 alglib.knntsprocess(model, buf, x, ref y); // thread-safe process 161 yield return y[0]; 232 162 } 233 163 } … … 236 166 if (classValues == null) throw new InvalidOperationException("No class values are defined."); 237 167 double[,] inputData; 238 if (IsCompatibilityLoaded) { 239 inputData = dataset.ToArray(allowedInputVariables, rows); 240 } else { 241 inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights); 242 } 168 inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights); 169 243 170 int n = inputData.GetLength(0); 244 171 int columns = inputData.GetLength(1); 245 172 double[] x = new double[columns]; 246 int[] y = new int[classValues.Length]; 247 double[] dists = new double[k]; 248 double[,] neighbours = new double[k, columns + 1]; 249 173 174 alglib.knncreatebuffer(model, out var buf); 175 var y = new double[classValues.Length]; 250 176 for (int row = 0; row < n; row++) { 251 177 for (int column = 0; column < columns; column++) { 252 178 x[column] = inputData[row, column]; 253 179 } 254 int numNeighbours; 255 lock (kdTreeLockObject) { 256 // gkronber: the following calls change the kdTree data structure 257 numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch); 258 alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists); 259 alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours); 260 } 261 
Array.Clear(y, 0, y.Length); 262 for (int i = 0; i < numNeighbours; i++) { 263 int classValue = (int)Math.Round(neighbours[i, columns]); 264 y[classValue]++; 265 } 266 267 // find class with the largest probability value 268 int maxProbClassIndex = 0; 269 double maxProb = y[0]; 270 for (int i = 1; i < y.Length; i++) { 271 if (maxProb < y[i]) { 272 maxProb = y[i]; 273 maxProbClassIndex = i; 274 } 275 } 276 yield return classValues[maxProbClassIndex]; 180 alglib.knntsprocess(model, buf, x, ref y); // thread-safe process 181 // find most probable class 182 var maxC = 0; 183 for (int i = 1; i < y.Length; i++) 184 if (maxC < y[i]) maxC = i; 185 yield return classValues[maxC]; 277 186 278 187 } … … 303 212 return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData)); 304 213 } 305 306 #region events307 public event EventHandler Changed;308 private void OnChanged(EventArgs e) {309 var handlers = Changed;310 if (handlers != null)311 handlers(this, e);312 }313 #endregion314 315 316 // BackwardsCompatibility3.3317 #region Backwards compatible code, remove with 3.4318 319 private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true320 [Storable(DefaultValue = true)]321 public bool IsCompatibilityLoaded {322 get { return isCompatibilityLoaded; }323 set { isCompatibilityLoaded = value; }324 }325 #endregion326 #region persistence327 [Storable]328 public double KDTreeApproxF {329 get { return kdTree.approxf; }330 set { kdTree.approxf = value; }331 }332 [Storable]333 public double[] KDTreeBoxMax {334 get { return kdTree.boxmax; }335 set { kdTree.boxmax = value; }336 }337 [Storable]338 public double[] KDTreeBoxMin {339 get { return kdTree.boxmin; }340 set { kdTree.boxmin = value; }341 }342 [Storable]343 public double[] KDTreeBuf {344 get { return kdTree.buf; }345 set { kdTree.buf = value; }346 }347 [Storable]348 public double[] KDTreeCurBoxMax {349 get { return 
kdTree.curboxmax; }350 set { kdTree.curboxmax = value; }351 }352 [Storable]353 public double[] KDTreeCurBoxMin {354 get { return kdTree.curboxmin; }355 set { kdTree.curboxmin = value; }356 }357 [Storable]358 public double KDTreeCurDist {359 get { return kdTree.curdist; }360 set { kdTree.curdist = value; }361 }362 [Storable]363 public int KDTreeDebugCounter {364 get { return kdTree.debugcounter; }365 set { kdTree.debugcounter = value; }366 }367 [Storable]368 public int[] KDTreeIdx {369 get { return kdTree.idx; }370 set { kdTree.idx = value; }371 }372 [Storable]373 public int KDTreeKCur {374 get { return kdTree.kcur; }375 set { kdTree.kcur = value; }376 }377 [Storable]378 public int KDTreeKNeeded {379 get { return kdTree.kneeded; }380 set { kdTree.kneeded = value; }381 }382 [Storable]383 public int KDTreeN {384 get { return kdTree.n; }385 set { kdTree.n = value; }386 }387 [Storable]388 public int[] KDTreeNodes {389 get { return kdTree.nodes; }390 set { kdTree.nodes = value; }391 }392 [Storable]393 public int KDTreeNormType {394 get { return kdTree.normtype; }395 set { kdTree.normtype = value; }396 }397 [Storable]398 public int KDTreeNX {399 get { return kdTree.nx; }400 set { kdTree.nx = value; }401 }402 [Storable]403 public int KDTreeNY {404 get { return kdTree.ny; }405 set { kdTree.ny = value; }406 }407 [Storable]408 public double[] KDTreeR {409 get { return kdTree.r; }410 set { kdTree.r = value; }411 }412 [Storable]413 public double KDTreeRNeeded {414 get { return kdTree.rneeded; }415 set { kdTree.rneeded = value; }416 }417 [Storable]418 public bool KDTreeSelfMatch {419 get { return kdTree.selfmatch; }420 set { kdTree.selfmatch = value; }421 }422 [Storable]423 public double[] KDTreeSplits {424 get { return kdTree.splits; }425 set { kdTree.splits = value; }426 }427 [Storable]428 public int[] KDTreeTags {429 get { return kdTree.tags; }430 set { kdTree.tags = value; }431 }432 [Storable]433 public double[] KDTreeX {434 get { return kdTree.x; }435 set { kdTree.x = 
value; }436 }437 [Storable]438 public double[,] KDTreeXY {439 get { return kdTree.xy; }440 set { kdTree.xy = value; }441 }442 #endregion443 214 } 444 215 }
Note: See TracChangeset
for help on using the changeset viewer.