1 | /*********************************************************************** |
---|
2 | * Software License Agreement (BSD License) |
---|
3 | * |
---|
4 | * Copyright 2008-2010 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. |
---|
5 | * Copyright 2008-2010 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. |
---|
6 | * |
---|
7 | * Redistribution and use in source and binary forms, with or without |
---|
8 | * modification, are permitted provided that the following conditions |
---|
9 | * are met: |
---|
10 | * |
---|
11 | * 1. Redistributions of source code must retain the above copyright |
---|
12 | * notice, this list of conditions and the following disclaimer. |
---|
13 | * 2. Redistributions in binary form must reproduce the above copyright |
---|
14 | * notice, this list of conditions and the following disclaimer in the |
---|
15 | * documentation and/or other materials provided with the distribution. |
---|
16 | * |
---|
17 | * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR |
---|
18 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES |
---|
19 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. |
---|
20 | * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, |
---|
21 | * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT |
---|
22 | * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
---|
23 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
---|
24 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
---|
25 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF |
---|
26 | * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
---|
27 | *************************************************************************/ |
---|
28 | |
---|
29 | |
---|
30 | #ifndef FLANN_MPI_HPP_ |
---|
31 | #define FLANN_MPI_HPP_ |
---|
32 | |
---|
#include <algorithm>
#include <functional>
#include <limits>
#include <string>
#include <vector>

#include <boost/mpi.hpp>
#include <boost/serialization/array.hpp>
#include <flann/flann.hpp>
#include <flann/io/hdf5.h>
---|
37 | |
---|
38 | namespace flann |
---|
39 | { |
---|
40 | namespace mpi |
---|
41 | { |
---|
42 | |
---|
43 | template<typename DistanceType> |
---|
44 | struct SearchResults |
---|
45 | { |
---|
46 | flann::Matrix<int> indices; |
---|
47 | flann::Matrix<DistanceType> dists; |
---|
48 | |
---|
49 | template<typename Archive> |
---|
50 | void serialize(Archive& ar, const unsigned int version) |
---|
51 | { |
---|
52 | ar& indices.rows; |
---|
53 | ar& indices.cols; |
---|
54 | if (Archive::is_loading::value) { |
---|
55 | indices = Matrix<int>(new int[indices.rows*indices.cols], indices.rows, indices.cols); |
---|
56 | } |
---|
57 | ar& boost::serialization::make_array(indices.ptr(), indices.rows*indices.cols); |
---|
58 | if (Archive::is_saving::value) { |
---|
59 | delete[] indices.ptr(); |
---|
60 | } |
---|
61 | ar& dists.rows; |
---|
62 | ar& dists.cols; |
---|
63 | if (Archive::is_loading::value) { |
---|
64 | dists = Matrix<DistanceType>(new DistanceType[dists.rows*dists.cols], dists.rows, dists.cols); |
---|
65 | } |
---|
66 | ar& boost::serialization::make_array(dists.ptr(), dists.rows*dists.cols); |
---|
67 | if (Archive::is_saving::value) { |
---|
68 | delete[] dists.ptr(); |
---|
69 | } |
---|
70 | } |
---|
71 | }; |
---|
72 | |
---|
73 | template<typename DistanceType> |
---|
74 | struct ResultsMerger |
---|
75 | { |
---|
76 | SearchResults<DistanceType> operator()(SearchResults<DistanceType> a, SearchResults<DistanceType> b) |
---|
77 | { |
---|
78 | SearchResults<DistanceType> results; |
---|
79 | results.indices = flann::Matrix<int>(new int[a.indices.rows*a.indices.cols],a.indices.rows,a.indices.cols); |
---|
80 | results.dists = flann::Matrix<DistanceType>(new DistanceType[a.dists.rows*a.dists.cols],a.dists.rows,a.dists.cols); |
---|
81 | |
---|
82 | |
---|
83 | for (size_t i = 0; i < results.dists.rows; ++i) { |
---|
84 | size_t idx = 0; |
---|
85 | size_t a_idx = 0; |
---|
86 | size_t b_idx = 0; |
---|
87 | while (idx < results.dists.cols) { |
---|
88 | if (a.dists[i][a_idx] <= b.dists[i][b_idx]) { |
---|
89 | results.dists[i][idx] = a.dists[i][a_idx]; |
---|
90 | results.indices[i][idx] = a.indices[i][a_idx]; |
---|
91 | idx++; |
---|
92 | a_idx++; |
---|
93 | } |
---|
94 | else { |
---|
95 | results.dists[i][idx] = b.dists[i][b_idx]; |
---|
96 | results.indices[i][idx] = b.indices[i][b_idx]; |
---|
97 | idx++; |
---|
98 | b_idx++; |
---|
99 | } |
---|
100 | } |
---|
101 | } |
---|
102 | delete[] a.indices.ptr(); |
---|
103 | delete[] a.dists.ptr(); |
---|
104 | delete[] b.indices.ptr(); |
---|
105 | delete[] b.dists.ptr(); |
---|
106 | return results; |
---|
107 | } |
---|
108 | }; |
---|
109 | |
---|
110 | |
---|
111 | |
---|
112 | template<typename Distance> |
---|
113 | class Index |
---|
114 | { |
---|
115 | typedef typename Distance::ElementType ElementType; |
---|
116 | typedef typename Distance::ResultType DistanceType; |
---|
117 | |
---|
118 | flann::Index<Distance>* flann_index; |
---|
119 | flann::Matrix<ElementType> dataset; |
---|
120 | int size_; |
---|
121 | int offset_; |
---|
122 | |
---|
123 | public: |
---|
124 | Index(const std::string& file_name, |
---|
125 | const std::string& dataset_name, |
---|
126 | const IndexParams& params); |
---|
127 | |
---|
128 | ~Index(); |
---|
129 | |
---|
130 | void buildIndex() |
---|
131 | { |
---|
132 | flann_index->buildIndex(); |
---|
133 | } |
---|
134 | |
---|
135 | void knnSearch(const flann::Matrix<ElementType>& queries, |
---|
136 | flann::Matrix<int>& indices, |
---|
137 | flann::Matrix<DistanceType>& dists, |
---|
138 | int knn, const |
---|
139 | SearchParams& params); |
---|
140 | |
---|
141 | int radiusSearch(const flann::Matrix<ElementType>& query, |
---|
142 | flann::Matrix<int>& indices, |
---|
143 | flann::Matrix<DistanceType>& dists, |
---|
144 | float radius, |
---|
145 | const SearchParams& params); |
---|
146 | |
---|
147 | // void save(std::string filename); |
---|
148 | |
---|
149 | int veclen() const |
---|
150 | { |
---|
151 | return flann_index->veclen(); |
---|
152 | } |
---|
153 | |
---|
154 | int size() const |
---|
155 | { |
---|
156 | return size_; |
---|
157 | } |
---|
158 | |
---|
159 | IndexParams getIndexParameters() |
---|
160 | { |
---|
161 | return flann_index->getParameters(); |
---|
162 | } |
---|
163 | }; |
---|
164 | |
---|
165 | |
---|
166 | template<typename Distance> |
---|
167 | Index<Distance>::Index(const std::string& file_name, const std::string& dataset_name, const IndexParams& params) |
---|
168 | { |
---|
169 | boost::mpi::communicator world; |
---|
170 | flann_algorithm_t index_type = get_param<flann_algorithm_t>(params,"algorithm"); |
---|
171 | if (index_type == SAVED) { |
---|
172 | throw FLANNException("Saving/loading of MPI indexes is not currently supported."); |
---|
173 | } |
---|
174 | flann::mpi::load_from_file(dataset, file_name, dataset_name); |
---|
175 | flann_index = new flann::Index<Distance>(dataset, params); |
---|
176 | |
---|
177 | std::vector<int> sizes; |
---|
178 | // get the sizes of all MPI indices |
---|
179 | all_gather(world, (int)flann_index->size(), sizes); |
---|
180 | size_ = 0; |
---|
181 | offset_ = 0; |
---|
182 | for (size_t i = 0; i < sizes.size(); ++i) { |
---|
183 | if ((int)i < world.rank()) offset_ += sizes[i]; |
---|
184 | size_ += sizes[i]; |
---|
185 | } |
---|
186 | } |
---|
187 | |
---|
188 | template<typename Distance> |
---|
189 | Index<Distance>::~Index() |
---|
190 | { |
---|
191 | delete flann_index; |
---|
192 | delete[] dataset.ptr(); |
---|
193 | } |
---|
194 | |
---|
195 | template<typename Distance> |
---|
196 | void Index<Distance>::knnSearch(const flann::Matrix<ElementType>& queries, flann::Matrix<int>& indices, flann::Matrix<DistanceType>& dists, int knn, const SearchParams& params) |
---|
197 | { |
---|
198 | boost::mpi::communicator world; |
---|
199 | flann::Matrix<int> local_indices(new int[queries.rows*knn], queries.rows, knn); |
---|
200 | flann::Matrix<DistanceType> local_dists(new DistanceType[queries.rows*knn], queries.rows, knn); |
---|
201 | |
---|
202 | flann_index->knnSearch(queries, local_indices, local_dists, knn, params); |
---|
203 | for (size_t i = 0; i < local_indices.rows; ++i) { |
---|
204 | for (size_t j = 0; j < local_indices.cols; ++j) { |
---|
205 | local_indices[i][j] += offset_; |
---|
206 | } |
---|
207 | } |
---|
208 | SearchResults<DistanceType> local_results; |
---|
209 | local_results.indices = local_indices; |
---|
210 | local_results.dists = local_dists; |
---|
211 | SearchResults<DistanceType> results; |
---|
212 | |
---|
213 | // perform MPI reduce |
---|
214 | reduce(world, local_results, results, ResultsMerger<DistanceType>(), 0); |
---|
215 | |
---|
216 | if (world.rank() == 0) { |
---|
217 | for (size_t i = 0; i < results.indices.rows; ++i) { |
---|
218 | for (size_t j = 0; j < results.indices.cols; ++j) { |
---|
219 | indices[i][j] = results.indices[i][j]; |
---|
220 | dists[i][j] = results.dists[i][j]; |
---|
221 | } |
---|
222 | } |
---|
223 | delete[] results.indices.ptr(); |
---|
224 | delete[] results.dists.ptr(); |
---|
225 | } |
---|
226 | } |
---|
227 | |
---|
228 | template<typename Distance> |
---|
229 | int Index<Distance>::radiusSearch(const flann::Matrix<ElementType>& query, flann::Matrix<int>& indices, flann::Matrix<DistanceType>& dists, float radius, const SearchParams& params) |
---|
230 | { |
---|
231 | boost::mpi::communicator world; |
---|
232 | flann::Matrix<int> local_indices(new int[indices.rows*indices.cols], indices.rows, indices.cols); |
---|
233 | flann::Matrix<DistanceType> local_dists(new DistanceType[dists.rows*dists.cols], dists.rows, dists.cols); |
---|
234 | |
---|
235 | flann_index->radiusSearch(query, local_indices, local_dists, radius, params); |
---|
236 | for (size_t i = 0; i < local_indices.rows; ++i) { |
---|
237 | for (size_t j = 0; j < local_indices.cols; ++j) { |
---|
238 | local_indices[i][j] += offset_; |
---|
239 | } |
---|
240 | } |
---|
241 | SearchResults<DistanceType> local_results; |
---|
242 | local_results.indices = local_indices; |
---|
243 | local_results.dists = local_dists; |
---|
244 | SearchResults<DistanceType> results; |
---|
245 | |
---|
246 | // perform MPI reduce |
---|
247 | reduce(world, local_results, results, ResultsMerger<DistanceType>(), 0); |
---|
248 | |
---|
249 | if (world.rank() == 0) { |
---|
250 | for (int i = 0; i < std::min(results.indices.rows, indices.rows); ++i) { |
---|
251 | for (int j = 0; j < std::min(results.indices.cols, indices.cols); ++j) { |
---|
252 | indices[i][j] = results.indices[i][j]; |
---|
253 | dists[i][j] = results.dists[i][j]; |
---|
254 | } |
---|
255 | } |
---|
256 | delete[] results.indices.ptr(); |
---|
257 | delete[] results.dists.ptr(); |
---|
258 | } |
---|
259 | return 0; |
---|
260 | } |
---|
261 | |
---|
262 | } |
---|
263 | } //namespace flann::mpi |
---|
264 | |
---|
265 | namespace boost { namespace mpi { |
---|
266 | template<> |
---|
267 | template<typename DistanceType> |
---|
268 | struct is_commutative<flann::mpi::ResultsMerger<DistanceType>, flann::mpi::SearchResults<DistanceType> > : mpl::true_ { }; |
---|
269 | } } // end namespace boost::mpi |
---|
270 | |
---|
271 | |
---|
272 | #endif /* FLANN_MPI_HPP_ */ |
---|