Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/ExpressionClustering/flann/include/flann/algorithms/kdtree_single_index.h @ 15840

Last change on this file since 15840 was 15840, checked in by gkronber, 6 years ago

#2886 added utility console program for clustering of expressions

File size: 18.6 KB
Line 
1/***********************************************************************
2 * Software License Agreement (BSD License)
3 *
4 * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6 *
7 * THE BSD LICENSE
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
30
31#ifndef FLANN_KDTREE_SINGLE_INDEX_H_
32#define FLANN_KDTREE_SINGLE_INDEX_H_
33
34#include <algorithm>
35#include <map>
36#include <cassert>
37#include <cstring>
38
39#include "flann/general.h"
40#include "flann/algorithms/nn_index.h"
41#include "flann/util/matrix.h"
42#include "flann/util/result_set.h"
43#include "flann/util/heap.h"
44#include "flann/util/allocator.h"
45#include "flann/util/random.h"
46#include "flann/util/saving.h"
47
48namespace flann
49{
50
51struct KDTreeSingleIndexParams : public IndexParams
52{
53    KDTreeSingleIndexParams(int leaf_max_size = 10, bool reorder = true)
54    {
55        (*this)["algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
56        (*this)["leaf_max_size"] = leaf_max_size;
57        (*this)["reorder"] = reorder;
58    }
59};
60
61
62/**
63 * Single kd-tree index
64 *
65 * Contains the k-d trees and other information for indexing a set of points
66 * for nearest-neighbor matching.
67 */
68template <typename Distance>
69class KDTreeSingleIndex : public NNIndex<Distance>
70{
71public:
72    typedef typename Distance::ElementType ElementType;
73    typedef typename Distance::ResultType DistanceType;
74
75    typedef bool needs_kdtree_distance;
76
77    /**
78     * KDTree constructor
79     *
80     * Params:
81     *          inputData = dataset with the input features
82     *          params = parameters passed to the kdtree algorithm
83     */
84    KDTreeSingleIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeSingleIndexParams(),
85                      Distance d = Distance() ) :
86        dataset_(inputData), index_params_(params), distance_(d)
87    {
88        size_ = dataset_.rows;
89        dim_ = dataset_.cols;
90        leaf_max_size_ = get_param(params,"leaf_max_size",10);
91        reorder_ = get_param(params,"reorder",true);
92
93        // Create a permutable array of indices to the input vectors.
94        vind_.resize(size_);
95        for (size_t i = 0; i < size_; i++) {
96            vind_[i] = i;
97        }
98    }
99
100    KDTreeSingleIndex(const KDTreeSingleIndex&);
101    KDTreeSingleIndex& operator=(const KDTreeSingleIndex&);
102
103    /**
104     * Standard destructor
105     */
106    ~KDTreeSingleIndex()
107    {
108        if (reorder_) delete[] data_.ptr();
109    }
110
111    /**
112     * Builds the index
113     */
114    void buildIndex()
115    {
116        computeBoundingBox(root_bbox_);
117        root_node_ = divideTree(0, size_, root_bbox_ );   // construct the tree
118
119        if (reorder_) {
120            data_ = flann::Matrix<ElementType>(new ElementType[size_*dim_], size_, dim_);
121            for (size_t i=0; i<size_; ++i) {
122                for (size_t j=0; j<dim_; ++j) {
123                    data_[i][j] = dataset_[vind_[i]][j];
124                }
125            }
126        }
127        else {
128            data_ = dataset_;
129        }
130    }
131
132    flann_algorithm_t getType() const
133    {
134        return FLANN_INDEX_KDTREE_SINGLE;
135    }
136
137
138    void saveIndex(FILE* stream)
139    {
140        save_value(stream, size_);
141        save_value(stream, dim_);
142        save_value(stream, root_bbox_);
143        save_value(stream, reorder_);
144        save_value(stream, leaf_max_size_);
145        save_value(stream, vind_);
146        if (reorder_) {
147            save_value(stream, data_);
148        }
149        save_tree(stream, root_node_);
150    }
151
152
153    void loadIndex(FILE* stream)
154    {
155        load_value(stream, size_);
156        load_value(stream, dim_);
157        load_value(stream, root_bbox_);
158        load_value(stream, reorder_);
159        load_value(stream, leaf_max_size_);
160        load_value(stream, vind_);
161        if (reorder_) {
162            load_value(stream, data_);
163        }
164        else {
165            data_ = dataset_;
166        }
167        load_tree(stream, root_node_);
168
169
170        index_params_["algorithm"] = getType();
171        index_params_["leaf_max_size"] = leaf_max_size_;
172        index_params_["reorder"] = reorder_;
173    }
174
175    /**
176     *  Returns size of index.
177     */
178    size_t size() const
179    {
180        return size_;
181    }
182
183    /**
184     * Returns the length of an index feature.
185     */
186    size_t veclen() const
187    {
188        return dim_;
189    }
190
191    /**
192     * Computes the inde memory usage
193     * Returns: memory used by the index
194     */
195    int usedMemory() const
196    {
197        return pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int);  // pool memory and vind array memory
198    }
199
200    IndexParams getParameters() const
201    {
202        return index_params_;
203    }
204
205    /**
206     * Find set of nearest neighbors to vec. Their indices are stored inside
207     * the result object.
208     *
209     * Params:
210     *     result = the result object in which the indices of the nearest-neighbors are stored
211     *     vec = the vector for which to search the nearest neighbors
212     *     maxCheck = the maximum number of restarts (in a best-bin-first manner)
213     */
214    void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
215    {
216        float epsError = 1+searchParams.eps;
217
218        std::vector<DistanceType> dists(dim_,0);
219        DistanceType distsq = computeInitialDistances(vec, dists);
220        searchLevel(result, vec, root_node_, distsq, dists, epsError);
221    }
222
223private:
224
225
226    /*--------------------- Internal Data Structures --------------------------*/
227    struct Node
228    {
229      /**
230       * Indices of points in leaf node
231       */
232      int left, right;
233      /**
234       * Dimension used for subdivision.
235       */
236      int divfeat;
237      /**
238       * The values used for subdivision.
239       */
240      DistanceType divlow, divhigh;
241        /**
242         * The child nodes.
243         */
244        Node* child1, * child2;
245    };
246    typedef Node* NodePtr;
247
248
249    struct Interval
250    {
251        DistanceType low, high;
252    };
253
254    typedef std::vector<Interval> BoundingBox;
255
256    typedef BranchStruct<NodePtr, DistanceType> BranchSt;
257    typedef BranchSt* Branch;
258
259
260
261
262    void save_tree(FILE* stream, NodePtr tree)
263    {
264        save_value(stream, *tree);
265        if (tree->child1!=NULL) {
266            save_tree(stream, tree->child1);
267        }
268        if (tree->child2!=NULL) {
269            save_tree(stream, tree->child2);
270        }
271    }
272
273
274    void load_tree(FILE* stream, NodePtr& tree)
275    {
276        tree = pool_.allocate<Node>();
277        load_value(stream, *tree);
278        if (tree->child1!=NULL) {
279            load_tree(stream, tree->child1);
280        }
281        if (tree->child2!=NULL) {
282            load_tree(stream, tree->child2);
283        }
284    }
285
286
287    void computeBoundingBox(BoundingBox& bbox)
288    {
289        bbox.resize(dim_);
290        for (size_t i=0; i<dim_; ++i) {
291            bbox[i].low = (DistanceType)dataset_[0][i];
292            bbox[i].high = (DistanceType)dataset_[0][i];
293        }
294        for (size_t k=1; k<dataset_.rows; ++k) {
295            for (size_t i=0; i<dim_; ++i) {
296                if (dataset_[k][i]<bbox[i].low) bbox[i].low = (DistanceType)dataset_[k][i];
297                if (dataset_[k][i]>bbox[i].high) bbox[i].high = (DistanceType)dataset_[k][i];
298            }
299        }
300    }
301
302
303    /**
304     * Create a tree node that subdivides the list of vecs from vind[first]
305     * to vind[last].  The routine is called recursively on each sublist.
306     * Place a pointer to this new tree node in the location pTree.
307     *
308     * Params: pTree = the new node to create
309     *                  first = index of the first vector
310     *                  last = index of the last vector
311     */
312    NodePtr divideTree(int left, int right, BoundingBox& bbox)
313    {
314        NodePtr node = pool_.allocate<Node>(); // allocate memory
315
316        /* If too few exemplars remain, then make this a leaf node. */
317        if ( (right-left) <= leaf_max_size_) {
318            node->child1 = node->child2 = NULL;    /* Mark as leaf node. */
319            node->left = left;
320            node->right = right;
321
322            // compute bounding-box of leaf points
323            for (size_t i=0; i<dim_; ++i) {
324                bbox[i].low = (DistanceType)dataset_[vind_[left]][i];
325                bbox[i].high = (DistanceType)dataset_[vind_[left]][i];
326            }
327            for (int k=left+1; k<right; ++k) {
328                for (size_t i=0; i<dim_; ++i) {
329                    if (bbox[i].low>dataset_[vind_[k]][i]) bbox[i].low=(DistanceType)dataset_[vind_[k]][i];
330                    if (bbox[i].high<dataset_[vind_[k]][i]) bbox[i].high=(DistanceType)dataset_[vind_[k]][i];
331                }
332            }
333        }
334        else {
335            int idx;
336            int cutfeat;
337            DistanceType cutval;
338            middleSplit(&vind_[0]+left, right-left, idx, cutfeat, cutval, bbox);
339
340            node->divfeat = cutfeat;
341
342            BoundingBox left_bbox(bbox);
343            left_bbox[cutfeat].high = cutval;
344            node->child1 = divideTree(left, left+idx, left_bbox);
345
346            BoundingBox right_bbox(bbox);
347            right_bbox[cutfeat].low = cutval;
348            node->child2 = divideTree(left+idx, right, right_bbox);
349
350            node->divlow = left_bbox[cutfeat].high;
351            node->divhigh = right_bbox[cutfeat].low;
352
353            for (size_t i=0; i<dim_; ++i) {
354              bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);
355              bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);
356            }
357        }
358
359        return node;
360    }
361
362    void computeMinMax(int* ind, int count, int dim, ElementType& min_elem, ElementType& max_elem)
363    {
364        min_elem = dataset_[ind[0]][dim];
365        max_elem = dataset_[ind[0]][dim];
366        for (int i=1; i<count; ++i) {
367            ElementType val = dataset_[ind[i]][dim];
368            if (val<min_elem) min_elem = val;
369            if (val>max_elem) max_elem = val;
370        }
371    }
372
373    void middleSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
374    {
375        // find the largest span from the approximate bounding box
376        ElementType max_span = bbox[0].high-bbox[0].low;
377        cutfeat = 0;
378        cutval = (bbox[0].high+bbox[0].low)/2;
379        for (size_t i=1; i<dim_; ++i) {
380            ElementType span = bbox[i].low-bbox[i].low;
381            if (span>max_span) {
382                max_span = span;
383                cutfeat = i;
384                cutval = (bbox[i].high+bbox[i].low)/2;
385            }
386        }
387
388        // compute exact span on the found dimension
389        ElementType min_elem, max_elem;
390        computeMinMax(ind, count, cutfeat, min_elem, max_elem);
391        cutval = (min_elem+max_elem)/2;
392        max_span = max_elem - min_elem;
393
394        // check if a dimension of a largest span exists
395        size_t k = cutfeat;
396        for (size_t i=0; i<dim_; ++i) {
397            if (i==k) continue;
398            ElementType span = bbox[i].high-bbox[i].low;
399            if (span>max_span) {
400                computeMinMax(ind, count, i, min_elem, max_elem);
401                span = max_elem - min_elem;
402                if (span>max_span) {
403                    max_span = span;
404                    cutfeat = i;
405                    cutval = (min_elem+max_elem)/2;
406                }
407            }
408        }
409        int lim1, lim2;
410        planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
411
412        if (lim1>count/2) index = lim1;
413        else if (lim2<count/2) index = lim2;
414        else index = count/2;
415    }
416
417
418    void middleSplit_(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
419    {
420        const float eps_val=0.00001f;
421        DistanceType max_span = bbox[0].high-bbox[0].low;
422        for (size_t i=1; i<dim_; ++i) {
423            DistanceType span = bbox[i].high-bbox[i].low;
424            if (span>max_span) {
425                max_span = span;
426            }
427        }
428        DistanceType max_spread = -1;
429        cutfeat = 0;
430        for (size_t i=0; i<dim_; ++i) {
431            DistanceType span = bbox[i].high-bbox[i].low;
432            if (span>(DistanceType)((1-eps_val)*max_span)) {
433                ElementType min_elem, max_elem;
434                computeMinMax(ind, count, cutfeat, min_elem, max_elem);
435                DistanceType spread = (DistanceType)(max_elem-min_elem);
436                if (spread>max_spread) {
437                    cutfeat = i;
438                    max_spread = spread;
439                }
440            }
441        }
442        // split in the middle
443        DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
444        ElementType min_elem, max_elem;
445        computeMinMax(ind, count, cutfeat, min_elem, max_elem);
446
447        if (split_val<min_elem) cutval = (DistanceType)min_elem;
448        else if (split_val>max_elem) cutval = (DistanceType)max_elem;
449        else cutval = split_val;
450
451        int lim1, lim2;
452        planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
453
454        if (lim1>count/2) index = lim1;
455        else if (lim2<count/2) index = lim2;
456        else index = count/2;
457    }
458
459
460    /**
461     *  Subdivide the list of points by a plane perpendicular on axe corresponding
462     *  to the 'cutfeat' dimension at 'cutval' position.
463     *
464     *  On return:
465     *  dataset[ind[0..lim1-1]][cutfeat]<cutval
466     *  dataset[ind[lim1..lim2-1]][cutfeat]==cutval
467     *  dataset[ind[lim2..count]][cutfeat]>cutval
468     */
469    void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
470    {
471        /* Move vector indices for left subtree to front of list. */
472        int left = 0;
473        int right = count-1;
474        for (;; ) {
475            while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
476            while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
477            if (left>right) break;
478            std::swap(ind[left], ind[right]); ++left; --right;
479        }
480        /* If either list is empty, it means that all remaining features
481         * are identical. Split in the middle to maintain a balanced tree.
482         */
483        lim1 = left;
484        right = count-1;
485        for (;; ) {
486            while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
487            while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
488            if (left>right) break;
489            std::swap(ind[left], ind[right]); ++left; --right;
490        }
491        lim2 = left;
492    }
493
494    DistanceType computeInitialDistances(const ElementType* vec, std::vector<DistanceType>& dists)
495    {
496        DistanceType distsq = 0.0;
497
498        for (size_t i = 0; i < dim_; ++i) {
499            if (vec[i] < root_bbox_[i].low) {
500                dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, i);
501                distsq += dists[i];
502            }
503            if (vec[i] > root_bbox_[i].high) {
504                dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, i);
505                distsq += dists[i];
506            }
507        }
508
509        return distsq;
510    }
511
512    /**
513     * Performs an exact search in the tree starting from a node.
514     */
515    void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindistsq,
516                     std::vector<DistanceType>& dists, const float epsError)
517    {
518        /* If this is a leaf node, then do check and return. */
519        if ((node->child1 == NULL)&&(node->child2 == NULL)) {
520            DistanceType worst_dist = result_set.worstDist();
521            for (int i=node->left; i<node->right; ++i) {
522                int index = reorder_ ? i : vind_[i];
523                DistanceType dist = distance_(vec, data_[index], dim_, worst_dist);
524                if (dist<worst_dist) {
525                    result_set.addPoint(dist,vind_[i]);
526                }
527            }
528            return;
529        }
530
531        /* Which child branch should be taken first? */
532        int idx = node->divfeat;
533        ElementType val = vec[idx];
534        DistanceType diff1 = val - node->divlow;
535        DistanceType diff2 = val - node->divhigh;
536
537        NodePtr bestChild;
538        NodePtr otherChild;
539        DistanceType cut_dist;
540        if ((diff1+diff2)<0) {
541            bestChild = node->child1;
542            otherChild = node->child2;
543            cut_dist = distance_.accum_dist(val, node->divhigh, idx);
544        }
545        else {
546            bestChild = node->child2;
547            otherChild = node->child1;
548            cut_dist = distance_.accum_dist( val, node->divlow, idx);
549        }
550
551        /* Call recursively to search next level down. */
552        searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);
553
554        DistanceType dst = dists[idx];
555        mindistsq = mindistsq + cut_dist - dst;
556        dists[idx] = cut_dist;
557        if (mindistsq*epsError<=result_set.worstDist()) {
558            searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);
559        }
560        dists[idx] = dst;
561    }
562
563private:
564
565    /**
566     * The dataset used by this index
567     */
568    const Matrix<ElementType> dataset_;
569
570    IndexParams index_params_;
571
572    int leaf_max_size_;
573    bool reorder_;
574
575
576    /**
577     *  Array of indices to vectors in the dataset.
578     */
579    std::vector<int> vind_;
580
581    Matrix<ElementType> data_;
582
583    size_t size_;
584    size_t dim_;
585
586    /**
587     * Array of k-d trees used to find neighbours.
588     */
589    NodePtr root_node_;
590
591    BoundingBox root_bbox_;
592
593    /**
594     * Pooled memory allocator.
595     *
596     * Using a pooled memory allocator is more efficient
597     * than allocating memory directly when there is a large
598     * number small of memory allocations.
599     */
600    PooledAllocator pool_;
601
602    Distance distance_;
603};   // class KDTree
604
605}
606
607#endif //FLANN_KDTREE_SINGLE_INDEX_H_
Note: See TracBrowser for help on using the repository browser.