Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/ExpressionClustering/flann/include/flann/mpi/index.h @ 15840

Last change on this file since 15840 was 15840, checked in by gkronber, 6 years ago

#2886 added utility console program for clustering of expressions

File size: 9.1 KB
Line 
1/***********************************************************************
2 * Software License Agreement (BSD License)
3 *
4 * Copyright 2008-2010  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2010  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 *
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
18 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
19 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
20 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
21 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
22 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
26 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 *************************************************************************/
28
29
30#ifndef FLANN_MPI_HPP_
31#define FLANN_MPI_HPP_
32
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

#include <boost/mpi.hpp>
#include <boost/serialization/array.hpp>
#include <flann/flann.hpp>
#include <flann/io/hdf5.h>
37
38namespace flann
39{
40namespace mpi
41{
42
43template<typename DistanceType>
44struct SearchResults
45{
46    flann::Matrix<int> indices;
47    flann::Matrix<DistanceType> dists;
48
49    template<typename Archive>
50    void serialize(Archive& ar, const unsigned int version)
51    {
52        ar& indices.rows;
53        ar& indices.cols;
54        if (Archive::is_loading::value) {
55            indices = Matrix<int>(new int[indices.rows*indices.cols], indices.rows, indices.cols);
56        }
57        ar& boost::serialization::make_array(indices.ptr(), indices.rows*indices.cols);
58        if (Archive::is_saving::value) {
59            delete[] indices.ptr();
60        }
61        ar& dists.rows;
62        ar& dists.cols;
63        if (Archive::is_loading::value) {
64            dists = Matrix<DistanceType>(new DistanceType[dists.rows*dists.cols], dists.rows, dists.cols);
65        }
66        ar& boost::serialization::make_array(dists.ptr(), dists.rows*dists.cols);
67        if (Archive::is_saving::value) {
68            delete[] dists.ptr();
69        }
70    }
71};
72
73template<typename DistanceType>
74struct ResultsMerger
75{
76    SearchResults<DistanceType> operator()(SearchResults<DistanceType> a, SearchResults<DistanceType> b)
77    {
78        SearchResults<DistanceType> results;
79        results.indices = flann::Matrix<int>(new int[a.indices.rows*a.indices.cols],a.indices.rows,a.indices.cols);
80        results.dists = flann::Matrix<DistanceType>(new DistanceType[a.dists.rows*a.dists.cols],a.dists.rows,a.dists.cols);
81
82
83        for (size_t i = 0; i < results.dists.rows; ++i) {
84            size_t idx = 0;
85            size_t a_idx = 0;
86            size_t b_idx = 0;
87            while (idx < results.dists.cols) {
88                if (a.dists[i][a_idx] <= b.dists[i][b_idx]) {
89                    results.dists[i][idx] = a.dists[i][a_idx];
90                    results.indices[i][idx] = a.indices[i][a_idx];
91                    idx++;
92                    a_idx++;
93                }
94                else {
95                    results.dists[i][idx] = b.dists[i][b_idx];
96                    results.indices[i][idx] = b.indices[i][b_idx];
97                    idx++;
98                    b_idx++;
99                }
100            }
101        }
102        delete[] a.indices.ptr();
103        delete[] a.dists.ptr();
104        delete[] b.indices.ptr();
105        delete[] b.dists.ptr();
106        return results;
107    }
108};
109
110
111
112template<typename Distance>
113class Index
114{
115    typedef typename Distance::ElementType ElementType;
116    typedef typename Distance::ResultType DistanceType;
117
118    flann::Index<Distance>* flann_index;
119    flann::Matrix<ElementType> dataset;
120    int size_;
121    int offset_;
122
123public:
124    Index(const std::string& file_name,
125          const std::string& dataset_name,
126          const IndexParams& params);
127
128    ~Index();
129
130    void buildIndex()
131    {
132        flann_index->buildIndex();
133    }
134
135    void knnSearch(const flann::Matrix<ElementType>& queries,
136                   flann::Matrix<int>& indices,
137                   flann::Matrix<DistanceType>& dists,
138                   int knn, const
139                   SearchParams& params);
140
141    int radiusSearch(const flann::Matrix<ElementType>& query,
142                     flann::Matrix<int>& indices,
143                     flann::Matrix<DistanceType>& dists,
144                     float radius,
145                     const SearchParams& params);
146
147    // void save(std::string filename);
148
149    int veclen() const
150    {
151        return flann_index->veclen();
152    }
153
154    int size() const
155    {
156        return size_;
157    }
158
159    IndexParams getIndexParameters()
160    {
161        return flann_index->getParameters();
162    }
163};
164
165
166template<typename Distance>
167Index<Distance>::Index(const std::string& file_name, const std::string& dataset_name, const IndexParams& params)
168{
169    boost::mpi::communicator world;
170    flann_algorithm_t index_type = get_param<flann_algorithm_t>(params,"algorithm");
171    if (index_type == SAVED) {
172        throw FLANNException("Saving/loading of MPI indexes is not currently supported.");
173    }
174    flann::mpi::load_from_file(dataset, file_name, dataset_name);
175    flann_index = new flann::Index<Distance>(dataset, params);
176
177    std::vector<int> sizes;
178    // get the sizes of all MPI indices
179    all_gather(world, (int)flann_index->size(), sizes);
180    size_ = 0;
181    offset_ = 0;
182    for (size_t i = 0; i < sizes.size(); ++i) {
183        if ((int)i < world.rank()) offset_ += sizes[i];
184        size_ += sizes[i];
185    }
186}
187
188template<typename Distance>
189Index<Distance>::~Index()
190{
191    delete flann_index;
192    delete[] dataset.ptr();
193}
194
195template<typename Distance>
196void Index<Distance>::knnSearch(const flann::Matrix<ElementType>& queries, flann::Matrix<int>& indices, flann::Matrix<DistanceType>& dists, int knn, const SearchParams& params)
197{
198    boost::mpi::communicator world;
199    flann::Matrix<int> local_indices(new int[queries.rows*knn], queries.rows, knn);
200    flann::Matrix<DistanceType> local_dists(new DistanceType[queries.rows*knn], queries.rows, knn);
201
202    flann_index->knnSearch(queries, local_indices, local_dists, knn, params);
203    for (size_t i = 0; i < local_indices.rows; ++i) {
204        for (size_t j = 0; j < local_indices.cols; ++j) {
205            local_indices[i][j] += offset_;
206        }
207    }
208    SearchResults<DistanceType> local_results;
209    local_results.indices = local_indices;
210    local_results.dists = local_dists;
211    SearchResults<DistanceType> results;
212
213    // perform MPI reduce
214    reduce(world, local_results, results, ResultsMerger<DistanceType>(), 0);
215
216    if (world.rank() == 0) {
217        for (size_t i = 0; i < results.indices.rows; ++i) {
218            for (size_t j = 0; j < results.indices.cols; ++j) {
219                indices[i][j] = results.indices[i][j];
220                dists[i][j] = results.dists[i][j];
221            }
222        }
223        delete[] results.indices.ptr();
224        delete[] results.dists.ptr();
225    }
226}
227
228template<typename Distance>
229int Index<Distance>::radiusSearch(const flann::Matrix<ElementType>& query, flann::Matrix<int>& indices, flann::Matrix<DistanceType>& dists, float radius, const SearchParams& params)
230{
231    boost::mpi::communicator world;
232    flann::Matrix<int> local_indices(new int[indices.rows*indices.cols], indices.rows, indices.cols);
233    flann::Matrix<DistanceType> local_dists(new DistanceType[dists.rows*dists.cols], dists.rows, dists.cols);
234
235    flann_index->radiusSearch(query, local_indices, local_dists, radius, params);
236    for (size_t i = 0; i < local_indices.rows; ++i) {
237        for (size_t j = 0; j < local_indices.cols; ++j) {
238            local_indices[i][j] += offset_;
239        }
240    }
241    SearchResults<DistanceType> local_results;
242    local_results.indices = local_indices;
243    local_results.dists = local_dists;
244    SearchResults<DistanceType> results;
245
246    // perform MPI reduce
247    reduce(world, local_results, results, ResultsMerger<DistanceType>(), 0);
248
249    if (world.rank() == 0) {
250        for (int i = 0; i < std::min(results.indices.rows, indices.rows); ++i) {
251            for (int j = 0; j < std::min(results.indices.cols, indices.cols); ++j) {
252                indices[i][j] = results.indices[i][j];
253                dists[i][j] = results.dists[i][j];
254            }
255        }
256        delete[] results.indices.ptr();
257        delete[] results.dists.ptr();
258    }
259    return 0;
260}
261
262}
263} //namespace flann::mpi
264
265namespace boost { namespace mpi {
266template<>
267template<typename DistanceType>
268struct is_commutative<flann::mpi::ResultsMerger<DistanceType>, flann::mpi::SearchResults<DistanceType> > : mpl::true_ { };
269} } // end namespace boost::mpi
270
271
272#endif /* FLANN_MPI_HPP_ */
Note: See TracBrowser for help on using the repository browser.