1 | /*********************************************************************** |
---|
2 | * Software License Agreement (BSD License) |
---|
3 | * |
---|
4 | * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. |
---|
5 | * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. |
---|
6 | * |
---|
7 | * THE BSD LICENSE |
---|
8 | * |
---|
9 | * Redistribution and use in source and binary forms, with or without |
---|
10 | * modification, are permitted provided that the following conditions |
---|
11 | * are met: |
---|
12 | * |
---|
13 | * 1. Redistributions of source code must retain the above copyright |
---|
14 | * notice, this list of conditions and the following disclaimer. |
---|
15 | * 2. Redistributions in binary form must reproduce the above copyright |
---|
16 | * notice, this list of conditions and the following disclaimer in the |
---|
17 | * documentation and/or other materials provided with the distribution. |
---|
18 | * |
---|
19 | * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR |
---|
20 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES |
---|
21 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. |
---|
22 | * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, |
---|
23 | * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT |
---|
24 | * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
---|
25 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
---|
26 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
---|
27 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF |
---|
28 | * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
---|
29 | *************************************************************************/ |
---|
30 | |
---|
31 | #ifndef FLANN_KMEANS_INDEX_H_ |
---|
32 | #define FLANN_KMEANS_INDEX_H_ |
---|
33 | |
---|
34 | #include <algorithm> |
---|
35 | #include <string> |
---|
36 | #include <map> |
---|
37 | #include <cassert> |
---|
38 | #include <limits> |
---|
39 | #include <cmath> |
---|
40 | |
---|
41 | #include "flann/general.h" |
---|
42 | #include "flann/algorithms/nn_index.h" |
---|
43 | #include "flann/algorithms/dist.h" |
---|
44 | #include "flann/util/matrix.h" |
---|
45 | #include "flann/util/result_set.h" |
---|
46 | #include "flann/util/heap.h" |
---|
47 | #include "flann/util/allocator.h" |
---|
48 | #include "flann/util/random.h" |
---|
49 | #include "flann/util/saving.h" |
---|
50 | #include "flann/util/logger.h" |
---|
51 | |
---|
52 | |
---|
53 | namespace flann |
---|
54 | { |
---|
55 | |
---|
56 | struct KMeansIndexParams : public IndexParams |
---|
57 | { |
---|
58 | KMeansIndexParams(int branching = 32, int iterations = 11, |
---|
59 | flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 ) |
---|
60 | { |
---|
61 | (*this)["algorithm"] = FLANN_INDEX_KMEANS; |
---|
62 | // branching factor |
---|
63 | (*this)["branching"] = branching; |
---|
64 | // max iterations to perform in one kmeans clustering (kmeans tree) |
---|
65 | (*this)["iterations"] = iterations; |
---|
66 | // algorithm used for picking the initial cluster centers for kmeans tree |
---|
67 | (*this)["centers_init"] = centers_init; |
---|
68 | // cluster boundary index. Used when searching the kmeans tree |
---|
69 | (*this)["cb_index"] = cb_index; |
---|
70 | } |
---|
71 | }; |
---|
72 | |
---|
73 | |
---|
74 | /** |
---|
75 | * Hierarchical kmeans index |
---|
76 | * |
---|
77 | * Contains a tree constructed through a hierarchical kmeans clustering |
---|
78 | * and other information for indexing a set of points for nearest-neighbour matching. |
---|
79 | */ |
---|
80 | template <typename Distance> |
---|
81 | class KMeansIndex : public NNIndex<Distance> |
---|
82 | { |
---|
83 | public: |
---|
// Element type of the vectors handled by the distance functor.
typedef typename Distance::ElementType ElementType;
// Type of the values returned by the distance functor.
typedef typename Distance::ResultType DistanceType;

// Marker typedef; it is not referenced in the visible part of this file.
// NOTE(review): presumably inspected by code elsewhere to detect indices that
// need a vector-space distance - confirm against the callers.
typedef bool needs_vector_space_distance;


// Pointer-to-member type matching the chooseCenters* methods:
// (k, indices, indices_length, centers, centers_length).
typedef void (KMeansIndex::* centersAlgFunction)(int, int*, int, int*, int&);

/**
 * The function used for choosing the cluster centers.
 * It is assigned in the constructor according to the "centers_init" parameter.
 */
centersAlgFunction chooseCenters;
---|
96 | |
---|
97 | |
---|
98 | |
---|
99 | /** |
---|
100 | * Chooses the initial centers in the k-means clustering in a random manner. |
---|
101 | * |
---|
102 | * Params: |
---|
103 | * k = number of centers |
---|
104 | * vecs = the dataset of points |
---|
105 | * indices = indices in the dataset |
---|
106 | * indices_length = length of indices vector |
---|
107 | * |
---|
108 | */ |
---|
109 | void chooseCentersRandom(int k, int* indices, int indices_length, int* centers, int& centers_length) |
---|
110 | { |
---|
111 | UniqueRandom r(indices_length); |
---|
112 | |
---|
113 | int index; |
---|
114 | for (index=0; index<k; ++index) { |
---|
115 | bool duplicate = true; |
---|
116 | int rnd; |
---|
117 | while (duplicate) { |
---|
118 | duplicate = false; |
---|
119 | rnd = r.next(); |
---|
120 | if (rnd<0) { |
---|
121 | centers_length = index; |
---|
122 | return; |
---|
123 | } |
---|
124 | |
---|
125 | centers[index] = indices[rnd]; |
---|
126 | |
---|
127 | for (int j=0; j<index; ++j) { |
---|
128 | DistanceType sq = distance_(dataset_[centers[index]], dataset_[centers[j]], dataset_.cols); |
---|
129 | if (sq<1e-16) { |
---|
130 | duplicate = true; |
---|
131 | } |
---|
132 | } |
---|
133 | } |
---|
134 | } |
---|
135 | |
---|
136 | centers_length = index; |
---|
137 | } |
---|
138 | |
---|
139 | |
---|
140 | /** |
---|
141 | * Chooses the initial centers in the k-means using Gonzales' algorithm |
---|
142 | * so that the centers are spaced apart from each other. |
---|
143 | * |
---|
144 | * Params: |
---|
145 | * k = number of centers |
---|
146 | * vecs = the dataset of points |
---|
147 | * indices = indices in the dataset |
---|
148 | * Returns: |
---|
149 | */ |
---|
150 | void chooseCentersGonzales(int k, int* indices, int indices_length, int* centers, int& centers_length) |
---|
151 | { |
---|
152 | int n = indices_length; |
---|
153 | |
---|
154 | int rnd = rand_int(n); |
---|
155 | assert(rnd >=0 && rnd < n); |
---|
156 | |
---|
157 | centers[0] = indices[rnd]; |
---|
158 | |
---|
159 | int index; |
---|
160 | for (index=1; index<k; ++index) { |
---|
161 | |
---|
162 | int best_index = -1; |
---|
163 | DistanceType best_val = 0; |
---|
164 | for (int j=0; j<n; ++j) { |
---|
165 | DistanceType dist = distance_(dataset_[centers[0]],dataset_[indices[j]],dataset_.cols); |
---|
166 | for (int i=1; i<index; ++i) { |
---|
167 | DistanceType tmp_dist = distance_(dataset_[centers[i]],dataset_[indices[j]],dataset_.cols); |
---|
168 | if (tmp_dist<dist) { |
---|
169 | dist = tmp_dist; |
---|
170 | } |
---|
171 | } |
---|
172 | if (dist>best_val) { |
---|
173 | best_val = dist; |
---|
174 | best_index = j; |
---|
175 | } |
---|
176 | } |
---|
177 | if (best_index!=-1) { |
---|
178 | centers[index] = indices[best_index]; |
---|
179 | } |
---|
180 | else { |
---|
181 | break; |
---|
182 | } |
---|
183 | } |
---|
184 | centers_length = index; |
---|
185 | } |
---|
186 | |
---|
187 | |
---|
188 | /** |
---|
189 | * Chooses the initial centers in the k-means using the algorithm |
---|
190 | * proposed in the KMeans++ paper: |
---|
191 | * Arthur, David; Vassilvitskii, Sergei - k-means++: The Advantages of Careful Seeding |
---|
192 | * |
---|
193 | * Implementation of this function was converted from the one provided in Arthur's code. |
---|
194 | * |
---|
195 | * Params: |
---|
196 | * k = number of centers |
---|
197 | * vecs = the dataset of points |
---|
198 | * indices = indices in the dataset |
---|
199 | * Returns: |
---|
200 | */ |
---|
201 | void chooseCentersKMeanspp(int k, int* indices, int indices_length, int* centers, int& centers_length) |
---|
202 | { |
---|
203 | int n = indices_length; |
---|
204 | |
---|
205 | double currentPot = 0; |
---|
206 | DistanceType* closestDistSq = new DistanceType[n]; |
---|
207 | |
---|
208 | // Choose one random center and set the closestDistSq values |
---|
209 | int index = rand_int(n); |
---|
210 | assert(index >=0 && index < n); |
---|
211 | centers[0] = indices[index]; |
---|
212 | |
---|
213 | for (int i = 0; i < n; i++) { |
---|
214 | closestDistSq[i] = distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols); |
---|
215 | currentPot += closestDistSq[i]; |
---|
216 | } |
---|
217 | |
---|
218 | |
---|
219 | const int numLocalTries = 1; |
---|
220 | |
---|
221 | // Choose each center |
---|
222 | int centerCount; |
---|
223 | for (centerCount = 1; centerCount < k; centerCount++) { |
---|
224 | |
---|
225 | // Repeat several trials |
---|
226 | double bestNewPot = -1; |
---|
227 | int bestNewIndex = -1; |
---|
228 | for (int localTrial = 0; localTrial < numLocalTries; localTrial++) { |
---|
229 | |
---|
230 | // Choose our center - have to be slightly careful to return a valid answer even accounting |
---|
231 | // for possible rounding errors |
---|
232 | double randVal = rand_double(currentPot); |
---|
233 | for (index = 0; index < n-1; index++) { |
---|
234 | if (randVal <= closestDistSq[index]) break; |
---|
235 | else randVal -= closestDistSq[index]; |
---|
236 | } |
---|
237 | |
---|
238 | // Compute the new potential |
---|
239 | double newPot = 0; |
---|
240 | for (int i = 0; i < n; i++) newPot += std::min( distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols), closestDistSq[i] ); |
---|
241 | |
---|
242 | // Store the best result |
---|
243 | if ((bestNewPot < 0)||(newPot < bestNewPot)) { |
---|
244 | bestNewPot = newPot; |
---|
245 | bestNewIndex = index; |
---|
246 | } |
---|
247 | } |
---|
248 | |
---|
249 | // Add the appropriate center |
---|
250 | centers[centerCount] = indices[bestNewIndex]; |
---|
251 | currentPot = bestNewPot; |
---|
252 | for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance_(dataset_[indices[i]], dataset_[indices[bestNewIndex]], dataset_.cols), closestDistSq[i] ); |
---|
253 | } |
---|
254 | |
---|
255 | centers_length = centerCount; |
---|
256 | |
---|
257 | delete[] closestDistSq; |
---|
258 | } |
---|
259 | |
---|
260 | |
---|
261 | |
---|
262 | public: |
---|
263 | |
---|
264 | flann_algorithm_t getType() const |
---|
265 | { |
---|
266 | return FLANN_INDEX_KMEANS; |
---|
267 | } |
---|
268 | |
---|
269 | /** |
---|
270 | * Index constructor |
---|
271 | * |
---|
272 | * Params: |
---|
273 | * inputData = dataset with the input features |
---|
274 | * params = parameters passed to the hierarchical k-means algorithm |
---|
275 | */ |
---|
276 | KMeansIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KMeansIndexParams(), |
---|
277 | Distance d = Distance()) |
---|
278 | : dataset_(inputData), index_params_(params), root_(NULL), indices_(NULL), distance_(d) |
---|
279 | { |
---|
280 | memoryCounter_ = 0; |
---|
281 | |
---|
282 | size_ = dataset_.rows; |
---|
283 | veclen_ = dataset_.cols; |
---|
284 | |
---|
285 | branching_ = get_param(params,"branching",32); |
---|
286 | iterations_ = get_param(params,"iterations",11); |
---|
287 | if (iterations_<0) { |
---|
288 | iterations_ = (std::numeric_limits<int>::max)(); |
---|
289 | } |
---|
290 | centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM); |
---|
291 | |
---|
292 | if (centers_init_==FLANN_CENTERS_RANDOM) { |
---|
293 | chooseCenters = &KMeansIndex::chooseCentersRandom; |
---|
294 | } |
---|
295 | else if (centers_init_==FLANN_CENTERS_GONZALES) { |
---|
296 | chooseCenters = &KMeansIndex::chooseCentersGonzales; |
---|
297 | } |
---|
298 | else if (centers_init_==FLANN_CENTERS_KMEANSPP) { |
---|
299 | chooseCenters = &KMeansIndex::chooseCentersKMeanspp; |
---|
300 | } |
---|
301 | else { |
---|
302 | throw FLANNException("Unknown algorithm for choosing initial centers."); |
---|
303 | } |
---|
304 | cb_index_ = 0.4f; |
---|
305 | |
---|
306 | } |
---|
307 | |
---|
308 | |
---|
// Copy constructor and copy assignment are only declared here; no definition
// appears in the part of the file visible to this review.
// NOTE(review): presumably left undefined on purpose so that copying an index
// (which holds raw pointers to its tree) fails at link time - confirm.
KMeansIndex(const KMeansIndex&);
KMeansIndex& operator=(const KMeansIndex&);
---|
311 | |
---|
312 | |
---|
313 | /** |
---|
314 | * Index destructor. |
---|
315 | * |
---|
316 | * Release the memory used by the index. |
---|
317 | */ |
---|
318 | virtual ~KMeansIndex() |
---|
319 | { |
---|
320 | if (root_ != NULL) { |
---|
321 | free_centers(root_); |
---|
322 | } |
---|
323 | if (indices_!=NULL) { |
---|
324 | delete[] indices_; |
---|
325 | } |
---|
326 | } |
---|
327 | |
---|
328 | /** |
---|
329 | * Returns size of index. |
---|
330 | */ |
---|
331 | size_t size() const |
---|
332 | { |
---|
333 | return size_; |
---|
334 | } |
---|
335 | |
---|
336 | /** |
---|
337 | * Returns the length of an index feature. |
---|
338 | */ |
---|
339 | size_t veclen() const |
---|
340 | { |
---|
341 | return veclen_; |
---|
342 | } |
---|
343 | |
---|
344 | |
---|
345 | void set_cb_index( float index) |
---|
346 | { |
---|
347 | cb_index_ = index; |
---|
348 | } |
---|
349 | |
---|
350 | /** |
---|
351 | * Computes the inde memory usage |
---|
352 | * Returns: memory used by the index |
---|
353 | */ |
---|
354 | int usedMemory() const |
---|
355 | { |
---|
356 | return pool_.usedMemory+pool_.wastedMemory+memoryCounter_; |
---|
357 | } |
---|
358 | |
---|
359 | /** |
---|
360 | * Builds the index |
---|
361 | */ |
---|
362 | void buildIndex() |
---|
363 | { |
---|
364 | if (branching_<2) { |
---|
365 | throw FLANNException("Branching factor must be at least 2"); |
---|
366 | } |
---|
367 | |
---|
368 | indices_ = new int[size_]; |
---|
369 | for (size_t i=0; i<size_; ++i) { |
---|
370 | indices_[i] = int(i); |
---|
371 | } |
---|
372 | |
---|
373 | root_ = pool_.allocate<KMeansNode>(); |
---|
374 | computeNodeStatistics(root_, indices_, (int)size_); |
---|
375 | computeClustering(root_, indices_, (int)size_, branching_,0); |
---|
376 | } |
---|
377 | |
---|
378 | |
---|
379 | void saveIndex(FILE* stream) |
---|
380 | { |
---|
381 | save_value(stream, branching_); |
---|
382 | save_value(stream, iterations_); |
---|
383 | save_value(stream, memoryCounter_); |
---|
384 | save_value(stream, cb_index_); |
---|
385 | save_value(stream, *indices_, (int)size_); |
---|
386 | |
---|
387 | save_tree(stream, root_); |
---|
388 | } |
---|
389 | |
---|
390 | |
---|
391 | void loadIndex(FILE* stream) |
---|
392 | { |
---|
393 | load_value(stream, branching_); |
---|
394 | load_value(stream, iterations_); |
---|
395 | load_value(stream, memoryCounter_); |
---|
396 | load_value(stream, cb_index_); |
---|
397 | if (indices_!=NULL) { |
---|
398 | delete[] indices_; |
---|
399 | } |
---|
400 | indices_ = new int[size_]; |
---|
401 | load_value(stream, *indices_, size_); |
---|
402 | |
---|
403 | if (root_!=NULL) { |
---|
404 | free_centers(root_); |
---|
405 | } |
---|
406 | load_tree(stream, root_); |
---|
407 | |
---|
408 | index_params_["algorithm"] = getType(); |
---|
409 | index_params_["branching"] = branching_; |
---|
410 | index_params_["iterations"] = iterations_; |
---|
411 | index_params_["centers_init"] = centers_init_; |
---|
412 | index_params_["cb_index"] = cb_index_; |
---|
413 | |
---|
414 | } |
---|
415 | |
---|
416 | |
---|
417 | /** |
---|
418 | * Find set of nearest neighbors to vec. Their indices are stored inside |
---|
419 | * the result object. |
---|
420 | * |
---|
421 | * Params: |
---|
422 | * result = the result object in which the indices of the nearest-neighbors are stored |
---|
423 | * vec = the vector for which to search the nearest neighbors |
---|
424 | * searchParams = parameters that influence the search algorithm (checks, cb_index) |
---|
425 | */ |
---|
426 | void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) |
---|
427 | { |
---|
428 | |
---|
429 | int maxChecks = searchParams.checks; |
---|
430 | |
---|
431 | if (maxChecks==FLANN_CHECKS_UNLIMITED) { |
---|
432 | findExactNN(root_, result, vec); |
---|
433 | } |
---|
434 | else { |
---|
435 | // Priority queue storing intermediate branches in the best-bin-first search |
---|
436 | Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_); |
---|
437 | |
---|
438 | int checks = 0; |
---|
439 | findNN(root_, result, vec, checks, maxChecks, heap); |
---|
440 | |
---|
441 | BranchSt branch; |
---|
442 | while (heap->popMin(branch) && (checks<maxChecks || !result.full())) { |
---|
443 | KMeansNodePtr node = branch.node; |
---|
444 | findNN(node, result, vec, checks, maxChecks, heap); |
---|
445 | } |
---|
446 | |
---|
447 | delete heap; |
---|
448 | } |
---|
449 | |
---|
450 | } |
---|
451 | |
---|
452 | /** |
---|
453 | * Clustering function that takes a cut in the hierarchical k-means |
---|
454 | * tree and return the clusters centers of that clustering. |
---|
455 | * Params: |
---|
456 | * numClusters = number of clusters to have in the clustering computed |
---|
457 | * Returns: number of cluster centers |
---|
458 | */ |
---|
459 | int getClusterCenters(Matrix<DistanceType>& centers) |
---|
460 | { |
---|
461 | int numClusters = centers.rows; |
---|
462 | if (numClusters<1) { |
---|
463 | throw FLANNException("Number of clusters must be at least 1"); |
---|
464 | } |
---|
465 | |
---|
466 | DistanceType variance; |
---|
467 | KMeansNodePtr* clusters = new KMeansNodePtr[numClusters]; |
---|
468 | |
---|
469 | int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance); |
---|
470 | |
---|
471 | Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount); |
---|
472 | |
---|
473 | for (int i=0; i<clusterCount; ++i) { |
---|
474 | DistanceType* center = clusters[i]->pivot; |
---|
475 | for (size_t j=0; j<veclen_; ++j) { |
---|
476 | centers[i][j] = center[j]; |
---|
477 | } |
---|
478 | } |
---|
479 | delete[] clusters; |
---|
480 | |
---|
481 | return clusterCount; |
---|
482 | } |
---|
483 | |
---|
484 | IndexParams getParameters() const |
---|
485 | { |
---|
486 | return index_params_; |
---|
487 | } |
---|
488 | |
---|
489 | |
---|
490 | private: |
---|
/**
 * Structure representing a node in the hierarchical k-means tree.
 *
 * NOTE(review): save_tree/load_tree pass the whole struct to
 * save_value/load_value, so the field order and types are presumably part
 * of the saved-index format - confirm before changing the layout.
 */
struct KMeansNode
{
    /**
     * The cluster center (array of veclen_ values).
     */
    DistanceType* pivot;
    /**
     * The cluster radius: the largest distance value from the center to a
     * point of the cluster.
     */
    DistanceType radius;
    /**
     * The cluster mean radius.
     */
    DistanceType mean_radius;
    /**
     * The cluster variance.
     */
    DistanceType variance;
    /**
     * The cluster size (number of points in the cluster)
     */
    int size;
    /**
     * Child nodes (only for non-terminal nodes); NULL marks a terminal node.
     */
    KMeansNode** childs;
    /**
     * Node points (only for terminal nodes): start of this node's slice of
     * the index array.
     */
    int* indices;
    /**
     * Level of the node in the tree (the root is level 0).
     */
    int level;
};
typedef KMeansNode* KMeansNodePtr;

/**
 * Alias definition for a nicer syntax.
 */
typedef BranchStruct<KMeansNodePtr, DistanceType> BranchSt;
---|
535 | |
---|
536 | |
---|
537 | |
---|
538 | |
---|
539 | void save_tree(FILE* stream, KMeansNodePtr node) |
---|
540 | { |
---|
541 | save_value(stream, *node); |
---|
542 | save_value(stream, *(node->pivot), (int)veclen_); |
---|
543 | if (node->childs==NULL) { |
---|
544 | int indices_offset = (int)(node->indices - indices_); |
---|
545 | save_value(stream, indices_offset); |
---|
546 | } |
---|
547 | else { |
---|
548 | for(int i=0; i<branching_; ++i) { |
---|
549 | save_tree(stream, node->childs[i]); |
---|
550 | } |
---|
551 | } |
---|
552 | } |
---|
553 | |
---|
554 | |
---|
555 | void load_tree(FILE* stream, KMeansNodePtr& node) |
---|
556 | { |
---|
557 | node = pool_.allocate<KMeansNode>(); |
---|
558 | load_value(stream, *node); |
---|
559 | node->pivot = new DistanceType[veclen_]; |
---|
560 | load_value(stream, *(node->pivot), (int)veclen_); |
---|
561 | if (node->childs==NULL) { |
---|
562 | int indices_offset; |
---|
563 | load_value(stream, indices_offset); |
---|
564 | node->indices = indices_ + indices_offset; |
---|
565 | } |
---|
566 | else { |
---|
567 | node->childs = pool_.allocate<KMeansNodePtr>(branching_); |
---|
568 | for(int i=0; i<branching_; ++i) { |
---|
569 | load_tree(stream, node->childs[i]); |
---|
570 | } |
---|
571 | } |
---|
572 | } |
---|
573 | |
---|
574 | |
---|
575 | /** |
---|
576 | * Helper function |
---|
577 | */ |
---|
578 | void free_centers(KMeansNodePtr node) |
---|
579 | { |
---|
580 | delete[] node->pivot; |
---|
581 | if (node->childs!=NULL) { |
---|
582 | for (int k=0; k<branching_; ++k) { |
---|
583 | free_centers(node->childs[k]); |
---|
584 | } |
---|
585 | } |
---|
586 | } |
---|
587 | |
---|
588 | /** |
---|
589 | * Computes the statistics of a node (mean, radius, variance). |
---|
590 | * |
---|
591 | * Params: |
---|
592 | * node = the node to use |
---|
593 | * indices = the indices of the points belonging to the node |
---|
594 | */ |
---|
595 | void computeNodeStatistics(KMeansNodePtr node, int* indices, int indices_length) |
---|
596 | { |
---|
597 | |
---|
598 | DistanceType radius = 0; |
---|
599 | DistanceType variance = 0; |
---|
600 | DistanceType* mean = new DistanceType[veclen_]; |
---|
601 | memoryCounter_ += int(veclen_*sizeof(DistanceType)); |
---|
602 | |
---|
603 | memset(mean,0,veclen_*sizeof(DistanceType)); |
---|
604 | |
---|
605 | for (size_t i=0; i<size_; ++i) { |
---|
606 | ElementType* vec = dataset_[indices[i]]; |
---|
607 | for (size_t j=0; j<veclen_; ++j) { |
---|
608 | mean[j] += vec[j]; |
---|
609 | } |
---|
610 | variance += distance_(vec, ZeroIterator<ElementType>(), veclen_); |
---|
611 | } |
---|
612 | for (size_t j=0; j<veclen_; ++j) { |
---|
613 | mean[j] /= size_; |
---|
614 | } |
---|
615 | variance /= size_; |
---|
616 | variance -= distance_(mean, ZeroIterator<ElementType>(), veclen_); |
---|
617 | |
---|
618 | DistanceType tmp = 0; |
---|
619 | for (int i=0; i<indices_length; ++i) { |
---|
620 | tmp = distance_(mean, dataset_[indices[i]], veclen_); |
---|
621 | if (tmp>radius) { |
---|
622 | radius = tmp; |
---|
623 | } |
---|
624 | } |
---|
625 | |
---|
626 | node->variance = variance; |
---|
627 | node->radius = radius; |
---|
628 | node->pivot = mean; |
---|
629 | } |
---|
630 | |
---|
631 | |
---|
632 | /** |
---|
633 | * The method responsible with actually doing the recursive hierarchical |
---|
634 | * clustering |
---|
635 | * |
---|
636 | * Params: |
---|
637 | * node = the node to cluster |
---|
638 | * indices = indices of the points belonging to the current node |
---|
639 | * branching = the branching factor to use in the clustering |
---|
640 | * |
---|
641 | * TODO: for 1-sized clusters don't store a cluster center (it's the same as the single cluster point) |
---|
642 | */ |
---|
643 | void computeClustering(KMeansNodePtr node, int* indices, int indices_length, int branching, int level) |
---|
644 | { |
---|
645 | node->size = indices_length; |
---|
646 | node->level = level; |
---|
647 | |
---|
648 | if (indices_length < branching) { |
---|
649 | node->indices = indices; |
---|
650 | std::sort(node->indices,node->indices+indices_length); |
---|
651 | node->childs = NULL; |
---|
652 | return; |
---|
653 | } |
---|
654 | |
---|
655 | int* centers_idx = new int[branching]; |
---|
656 | int centers_length; |
---|
657 | (this->*chooseCenters)(branching, indices, indices_length, centers_idx, centers_length); |
---|
658 | |
---|
659 | if (centers_length<branching) { |
---|
660 | node->indices = indices; |
---|
661 | std::sort(node->indices,node->indices+indices_length); |
---|
662 | node->childs = NULL; |
---|
663 | delete [] centers_idx; |
---|
664 | return; |
---|
665 | } |
---|
666 | |
---|
667 | |
---|
668 | Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_); |
---|
669 | for (int i=0; i<centers_length; ++i) { |
---|
670 | ElementType* vec = dataset_[centers_idx[i]]; |
---|
671 | for (size_t k=0; k<veclen_; ++k) { |
---|
672 | dcenters[i][k] = double(vec[k]); |
---|
673 | } |
---|
674 | } |
---|
675 | delete[] centers_idx; |
---|
676 | |
---|
677 | DistanceType* radiuses = new DistanceType[branching]; |
---|
678 | int* count = new int[branching]; |
---|
679 | for (int i=0; i<branching; ++i) { |
---|
680 | radiuses[i] = 0; |
---|
681 | count[i] = 0; |
---|
682 | } |
---|
683 | |
---|
684 | // assign points to clusters |
---|
685 | int* belongs_to = new int[indices_length]; |
---|
686 | for (int i=0; i<indices_length; ++i) { |
---|
687 | |
---|
688 | DistanceType sq_dist = distance_(dataset_[indices[i]], dcenters[0], veclen_); |
---|
689 | belongs_to[i] = 0; |
---|
690 | for (int j=1; j<branching; ++j) { |
---|
691 | DistanceType new_sq_dist = distance_(dataset_[indices[i]], dcenters[j], veclen_); |
---|
692 | if (sq_dist>new_sq_dist) { |
---|
693 | belongs_to[i] = j; |
---|
694 | sq_dist = new_sq_dist; |
---|
695 | } |
---|
696 | } |
---|
697 | if (sq_dist>radiuses[belongs_to[i]]) { |
---|
698 | radiuses[belongs_to[i]] = sq_dist; |
---|
699 | } |
---|
700 | count[belongs_to[i]]++; |
---|
701 | } |
---|
702 | |
---|
703 | bool converged = false; |
---|
704 | int iteration = 0; |
---|
705 | while (!converged && iteration<iterations_) { |
---|
706 | converged = true; |
---|
707 | iteration++; |
---|
708 | |
---|
709 | // compute the new cluster centers |
---|
710 | for (int i=0; i<branching; ++i) { |
---|
711 | memset(dcenters[i],0,sizeof(double)*veclen_); |
---|
712 | radiuses[i] = 0; |
---|
713 | } |
---|
714 | for (int i=0; i<indices_length; ++i) { |
---|
715 | ElementType* vec = dataset_[indices[i]]; |
---|
716 | double* center = dcenters[belongs_to[i]]; |
---|
717 | for (size_t k=0; k<veclen_; ++k) { |
---|
718 | center[k] += vec[k]; |
---|
719 | } |
---|
720 | } |
---|
721 | for (int i=0; i<branching; ++i) { |
---|
722 | int cnt = count[i]; |
---|
723 | for (size_t k=0; k<veclen_; ++k) { |
---|
724 | dcenters[i][k] /= cnt; |
---|
725 | } |
---|
726 | } |
---|
727 | |
---|
728 | // reassign points to clusters |
---|
729 | for (int i=0; i<indices_length; ++i) { |
---|
730 | DistanceType sq_dist = distance_(dataset_[indices[i]], dcenters[0], veclen_); |
---|
731 | int new_centroid = 0; |
---|
732 | for (int j=1; j<branching; ++j) { |
---|
733 | DistanceType new_sq_dist = distance_(dataset_[indices[i]], dcenters[j], veclen_); |
---|
734 | if (sq_dist>new_sq_dist) { |
---|
735 | new_centroid = j; |
---|
736 | sq_dist = new_sq_dist; |
---|
737 | } |
---|
738 | } |
---|
739 | if (sq_dist>radiuses[new_centroid]) { |
---|
740 | radiuses[new_centroid] = sq_dist; |
---|
741 | } |
---|
742 | if (new_centroid != belongs_to[i]) { |
---|
743 | count[belongs_to[i]]--; |
---|
744 | count[new_centroid]++; |
---|
745 | belongs_to[i] = new_centroid; |
---|
746 | |
---|
747 | converged = false; |
---|
748 | } |
---|
749 | } |
---|
750 | |
---|
751 | for (int i=0; i<branching; ++i) { |
---|
752 | // if one cluster converges to an empty cluster, |
---|
753 | // move an element into that cluster |
---|
754 | if (count[i]==0) { |
---|
755 | int j = (i+1)%branching; |
---|
756 | while (count[j]<=1) { |
---|
757 | j = (j+1)%branching; |
---|
758 | } |
---|
759 | |
---|
760 | for (int k=0; k<indices_length; ++k) { |
---|
761 | if (belongs_to[k]==j) { |
---|
762 | belongs_to[k] = i; |
---|
763 | count[j]--; |
---|
764 | count[i]++; |
---|
765 | break; |
---|
766 | } |
---|
767 | } |
---|
768 | converged = false; |
---|
769 | } |
---|
770 | } |
---|
771 | |
---|
772 | } |
---|
773 | |
---|
774 | DistanceType** centers = new DistanceType*[branching]; |
---|
775 | |
---|
776 | for (int i=0; i<branching; ++i) { |
---|
777 | centers[i] = new DistanceType[veclen_]; |
---|
778 | memoryCounter_ += veclen_*sizeof(DistanceType); |
---|
779 | for (size_t k=0; k<veclen_; ++k) { |
---|
780 | centers[i][k] = (DistanceType)dcenters[i][k]; |
---|
781 | } |
---|
782 | } |
---|
783 | |
---|
784 | |
---|
785 | // compute kmeans clustering for each of the resulting clusters |
---|
786 | node->childs = pool_.allocate<KMeansNodePtr>(branching); |
---|
787 | int start = 0; |
---|
788 | int end = start; |
---|
789 | for (int c=0; c<branching; ++c) { |
---|
790 | int s = count[c]; |
---|
791 | |
---|
792 | DistanceType variance = 0; |
---|
793 | DistanceType mean_radius =0; |
---|
794 | for (int i=0; i<indices_length; ++i) { |
---|
795 | if (belongs_to[i]==c) { |
---|
796 | DistanceType d = distance_(dataset_[indices[i]], ZeroIterator<ElementType>(), veclen_); |
---|
797 | variance += d; |
---|
798 | mean_radius += sqrt(d); |
---|
799 | std::swap(indices[i],indices[end]); |
---|
800 | std::swap(belongs_to[i],belongs_to[end]); |
---|
801 | end++; |
---|
802 | } |
---|
803 | } |
---|
804 | variance /= s; |
---|
805 | mean_radius /= s; |
---|
806 | variance -= distance_(centers[c], ZeroIterator<ElementType>(), veclen_); |
---|
807 | |
---|
808 | node->childs[c] = pool_.allocate<KMeansNode>(); |
---|
809 | node->childs[c]->radius = radiuses[c]; |
---|
810 | node->childs[c]->pivot = centers[c]; |
---|
811 | node->childs[c]->variance = variance; |
---|
812 | node->childs[c]->mean_radius = mean_radius; |
---|
813 | node->childs[c]->indices = NULL; |
---|
814 | computeClustering(node->childs[c],indices+start, end-start, branching, level+1); |
---|
815 | start=end; |
---|
816 | } |
---|
817 | |
---|
818 | delete[] dcenters.ptr(); |
---|
819 | delete[] centers; |
---|
820 | delete[] radiuses; |
---|
821 | delete[] count; |
---|
822 | delete[] belongs_to; |
---|
823 | } |
---|
824 | |
---|
825 | |
---|
826 | |
---|
827 | /** |
---|
828 | * Performs one descent in the hierarchical k-means tree. The branches not |
---|
829 | * visited are stored in a priority queue. |
---|
830 | * |
---|
831 | * Params: |
---|
832 | * node = node to explore |
---|
833 | * result = container for the k-nearest neighbors found |
---|
834 | * vec = query points |
---|
835 | * checks = how many points in the dataset have been checked so far |
---|
836 | * maxChecks = maximum dataset points to checks |
---|
837 | */ |
---|
838 | |
---|
839 | |
---|
840 | void findNN(KMeansNodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks, |
---|
841 | Heap<BranchSt>* heap) |
---|
842 | { |
---|
843 | // Ignore those clusters that are too far away |
---|
844 | { |
---|
845 | DistanceType bsq = distance_(vec, node->pivot, veclen_); |
---|
846 | DistanceType rsq = node->radius; |
---|
847 | DistanceType wsq = result.worstDist(); |
---|
848 | |
---|
849 | DistanceType val = bsq-rsq-wsq; |
---|
850 | DistanceType val2 = val*val-4*rsq*wsq; |
---|
851 | |
---|
852 | //if (val>0) { |
---|
853 | if ((val>0)&&(val2>0)) { |
---|
854 | return; |
---|
855 | } |
---|
856 | } |
---|
857 | |
---|
858 | if (node->childs==NULL) { |
---|
859 | if (checks>=maxChecks) { |
---|
860 | if (result.full()) return; |
---|
861 | } |
---|
862 | checks += node->size; |
---|
863 | for (int i=0; i<node->size; ++i) { |
---|
864 | int index = node->indices[i]; |
---|
865 | DistanceType dist = distance_(dataset_[index], vec, veclen_); |
---|
866 | result.addPoint(dist, index); |
---|
867 | } |
---|
868 | } |
---|
869 | else { |
---|
870 | DistanceType* domain_distances = new DistanceType[branching_]; |
---|
871 | int closest_center = exploreNodeBranches(node, vec, domain_distances, heap); |
---|
872 | delete[] domain_distances; |
---|
873 | findNN(node->childs[closest_center],result,vec, checks, maxChecks, heap); |
---|
874 | } |
---|
875 | } |
---|
876 | |
---|
877 | /** |
---|
878 | * Helper function that computes the nearest childs of a node to a given query point. |
---|
879 | * Params: |
---|
880 | * node = the node |
---|
881 | * q = the query point |
---|
882 | * distances = array with the distances to each child node. |
---|
883 | * Returns: |
---|
884 | */ |
---|
885 | int exploreNodeBranches(KMeansNodePtr node, const ElementType* q, DistanceType* domain_distances, Heap<BranchSt>* heap) |
---|
886 | { |
---|
887 | |
---|
888 | int best_index = 0; |
---|
889 | domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_); |
---|
890 | for (int i=1; i<branching_; ++i) { |
---|
891 | domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_); |
---|
892 | if (domain_distances[i]<domain_distances[best_index]) { |
---|
893 | best_index = i; |
---|
894 | } |
---|
895 | } |
---|
896 | |
---|
897 | // float* best_center = node->childs[best_index]->pivot; |
---|
898 | for (int i=0; i<branching_; ++i) { |
---|
899 | if (i != best_index) { |
---|
900 | domain_distances[i] -= cb_index_*node->childs[i]->variance; |
---|
901 | |
---|
902 | // float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q); |
---|
903 | // if (domain_distances[i]<dist_to_border) { |
---|
904 | // domain_distances[i] = dist_to_border; |
---|
905 | // } |
---|
906 | heap->insert(BranchSt(node->childs[i],domain_distances[i])); |
---|
907 | } |
---|
908 | } |
---|
909 | |
---|
910 | return best_index; |
---|
911 | } |
---|
912 | |
---|
913 | |
---|
914 | /** |
---|
915 | * Function the performs exact nearest neighbor search by traversing the entire tree. |
---|
916 | */ |
---|
917 | void findExactNN(KMeansNodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) |
---|
918 | { |
---|
919 | // Ignore those clusters that are too far away |
---|
920 | { |
---|
921 | DistanceType bsq = distance_(vec, node->pivot, veclen_); |
---|
922 | DistanceType rsq = node->radius; |
---|
923 | DistanceType wsq = result.worstDist(); |
---|
924 | |
---|
925 | DistanceType val = bsq-rsq-wsq; |
---|
926 | DistanceType val2 = val*val-4*rsq*wsq; |
---|
927 | |
---|
928 | // if (val>0) { |
---|
929 | if ((val>0)&&(val2>0)) { |
---|
930 | return; |
---|
931 | } |
---|
932 | } |
---|
933 | |
---|
934 | |
---|
935 | if (node->childs==NULL) { |
---|
936 | for (int i=0; i<node->size; ++i) { |
---|
937 | int index = node->indices[i]; |
---|
938 | DistanceType dist = distance_(dataset_[index], vec, veclen_); |
---|
939 | result.addPoint(dist, index); |
---|
940 | } |
---|
941 | } |
---|
942 | else { |
---|
943 | int* sort_indices = new int[branching_]; |
---|
944 | |
---|
945 | getCenterOrdering(node, vec, sort_indices); |
---|
946 | |
---|
947 | for (int i=0; i<branching_; ++i) { |
---|
948 | findExactNN(node->childs[sort_indices[i]],result,vec); |
---|
949 | } |
---|
950 | |
---|
951 | delete[] sort_indices; |
---|
952 | } |
---|
953 | } |
---|
954 | |
---|
955 | |
---|
956 | /** |
---|
957 | * Helper function. |
---|
958 | * |
---|
959 | * I computes the order in which to traverse the child nodes of a particular node. |
---|
960 | */ |
---|
961 | void getCenterOrdering(KMeansNodePtr node, const ElementType* q, int* sort_indices) |
---|
962 | { |
---|
963 | DistanceType* domain_distances = new DistanceType[branching_]; |
---|
964 | for (int i=0; i<branching_; ++i) { |
---|
965 | DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_); |
---|
966 | |
---|
967 | int j=0; |
---|
968 | while (domain_distances[j]<dist && j<i) j++; |
---|
969 | for (int k=i; k>j; --k) { |
---|
970 | domain_distances[k] = domain_distances[k-1]; |
---|
971 | sort_indices[k] = sort_indices[k-1]; |
---|
972 | } |
---|
973 | domain_distances[j] = dist; |
---|
974 | sort_indices[j] = i; |
---|
975 | } |
---|
976 | delete[] domain_distances; |
---|
977 | } |
---|
978 | |
---|
979 | /** |
---|
980 | * Method that computes the squared distance from the query point q |
---|
981 | * from inside region with center c to the border between this |
---|
982 | * region and the region with center p |
---|
983 | */ |
---|
984 | DistanceType getDistanceToBorder(DistanceType* p, DistanceType* c, DistanceType* q) |
---|
985 | { |
---|
986 | DistanceType sum = 0; |
---|
987 | DistanceType sum2 = 0; |
---|
988 | |
---|
989 | for (int i=0; i<veclen_; ++i) { |
---|
990 | DistanceType t = c[i]-p[i]; |
---|
991 | sum += t*(q[i]-(c[i]+p[i])/2); |
---|
992 | sum2 += t*t; |
---|
993 | } |
---|
994 | |
---|
995 | return sum*sum/sum2; |
---|
996 | } |
---|
997 | |
---|
998 | |
---|
999 | /** |
---|
1000 | * Helper function the descends in the hierarchical k-means tree by spliting those clusters that minimize |
---|
1001 | * the overall variance of the clustering. |
---|
1002 | * Params: |
---|
1003 | * root = root node |
---|
1004 | * clusters = array with clusters centers (return value) |
---|
1005 | * varianceValue = variance of the clustering (return value) |
---|
1006 | * Returns: |
---|
1007 | */ |
---|
1008 | int getMinVarianceClusters(KMeansNodePtr root, KMeansNodePtr* clusters, int clusters_length, DistanceType& varianceValue) |
---|
1009 | { |
---|
1010 | int clusterCount = 1; |
---|
1011 | clusters[0] = root; |
---|
1012 | |
---|
1013 | DistanceType meanVariance = root->variance*root->size; |
---|
1014 | |
---|
1015 | while (clusterCount<clusters_length) { |
---|
1016 | DistanceType minVariance = (std::numeric_limits<DistanceType>::max)(); |
---|
1017 | int splitIndex = -1; |
---|
1018 | |
---|
1019 | for (int i=0; i<clusterCount; ++i) { |
---|
1020 | if (clusters[i]->childs != NULL) { |
---|
1021 | |
---|
1022 | DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size; |
---|
1023 | |
---|
1024 | for (int j=0; j<branching_; ++j) { |
---|
1025 | variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size; |
---|
1026 | } |
---|
1027 | if (variance<minVariance) { |
---|
1028 | minVariance = variance; |
---|
1029 | splitIndex = i; |
---|
1030 | } |
---|
1031 | } |
---|
1032 | } |
---|
1033 | |
---|
1034 | if (splitIndex==-1) break; |
---|
1035 | if ( (branching_+clusterCount-1) > clusters_length) break; |
---|
1036 | |
---|
1037 | meanVariance = minVariance; |
---|
1038 | |
---|
1039 | // split node |
---|
1040 | KMeansNodePtr toSplit = clusters[splitIndex]; |
---|
1041 | clusters[splitIndex] = toSplit->childs[0]; |
---|
1042 | for (int i=1; i<branching_; ++i) { |
---|
1043 | clusters[clusterCount++] = toSplit->childs[i]; |
---|
1044 | } |
---|
1045 | } |
---|
1046 | |
---|
1047 | varianceValue = meanVariance/root->size; |
---|
1048 | return clusterCount; |
---|
1049 | } |
---|
1050 | |
---|
1051 | private: |
---|
1052 | /** The branching factor used in the hierarchical k-means clustering */ |
---|
1053 | int branching_; |
---|
1054 | |
---|
1055 | /** Maximum number of iterations to use when performing k-means clustering */ |
---|
1056 | int iterations_; |
---|
1057 | |
---|
1058 | /** Algorithm for choosing the cluster centers */ |
---|
1059 | flann_centers_init_t centers_init_; |
---|
1060 | |
---|
1061 | /** |
---|
1062 | * Cluster border index. This is used in the tree search phase when determining |
---|
1063 | * the closest cluster to explore next. A zero value takes into account only |
---|
1064 | * the cluster centres, a value greater than zero also takes into account the size |
---|
1065 | * of the cluster. |
---|
1066 | */ |
---|
1067 | float cb_index_; |
---|
1068 | |
---|
1069 | /** |
---|
1070 | * The dataset used by this index |
---|
1071 | */ |
---|
1072 | const Matrix<ElementType> dataset_; |
---|
1073 | |
---|
1074 | /** Index parameters */ |
---|
1075 | IndexParams index_params_; |
---|
1076 | |
---|
1077 | /** |
---|
1078 | * Number of features in the dataset. |
---|
1079 | */ |
---|
1080 | size_t size_; |
---|
1081 | |
---|
1082 | /** |
---|
1083 | * Length of each feature. |
---|
1084 | */ |
---|
1085 | size_t veclen_; |
---|
1086 | |
---|
1087 | /** |
---|
1088 | * The root node in the tree. |
---|
1089 | */ |
---|
1090 | KMeansNodePtr root_; |
---|
1091 | |
---|
1092 | /** |
---|
1093 | * Array of indices to vectors in the dataset. |
---|
1094 | */ |
---|
1095 | int* indices_; |
---|
1096 | |
---|
1097 | /** |
---|
1098 | * The distance functor used to compare feature vectors. |
---|
1099 | */ |
---|
1100 | Distance distance_; |
---|
1101 | |
---|
1102 | /** |
---|
1103 | * Pooled memory allocator. |
---|
1104 | */ |
---|
1105 | PooledAllocator pool_; |
---|
1106 | |
---|
1107 | /** |
---|
1108 | * Memory occupied by the index. |
---|
1109 | */ |
---|
1110 | int memoryCounter_; |
---|
1111 | }; |
---|
1112 | |
---|
1113 | } |
---|
1114 | |
---|
1115 | #endif //FLANN_KMEANS_INDEX_H_ |
---|