1 | /*********************************************************************** |
---|
2 | * Software License Agreement (BSD License) |
---|
3 | * |
---|
4 | * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. |
---|
5 | * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. |
---|
6 | * |
---|
7 | * THE BSD LICENSE |
---|
8 | * |
---|
9 | * Redistribution and use in source and binary forms, with or without |
---|
10 | * modification, are permitted provided that the following conditions |
---|
11 | * are met: |
---|
12 | * |
---|
13 | * 1. Redistributions of source code must retain the above copyright |
---|
14 | * notice, this list of conditions and the following disclaimer. |
---|
15 | * 2. Redistributions in binary form must reproduce the above copyright |
---|
16 | * notice, this list of conditions and the following disclaimer in the |
---|
17 | * documentation and/or other materials provided with the distribution. |
---|
18 | * |
---|
19 | * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR |
---|
20 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES |
---|
21 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. |
---|
22 | * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, |
---|
23 | * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT |
---|
24 | * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
---|
25 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
---|
26 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
---|
27 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF |
---|
28 | * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
---|
29 | *************************************************************************/ |
---|
30 | #ifndef FLANN_AUTOTUNED_INDEX_H_ |
---|
31 | #define FLANN_AUTOTUNED_INDEX_H_ |
---|
32 | |
---|
33 | #include "flann/general.h" |
---|
34 | #include "flann/algorithms/nn_index.h" |
---|
35 | #include "flann/nn/ground_truth.h" |
---|
36 | #include "flann/nn/index_testing.h" |
---|
37 | #include "flann/util/sampling.h" |
---|
38 | #include "flann/algorithms/kdtree_index.h" |
---|
39 | #include "flann/algorithms/kdtree_single_index.h" |
---|
40 | #include "flann/algorithms/kmeans_index.h" |
---|
41 | #include "flann/algorithms/composite_index.h" |
---|
42 | #include "flann/algorithms/linear_index.h" |
---|
43 | #include "flann/util/logger.h" |
---|
44 | |
---|
45 | namespace flann |
---|
46 | { |
---|
47 | |
---|
48 | template<typename Distance> |
---|
49 | NNIndex<Distance>* create_index_by_type(const Matrix<typename Distance::ElementType>& dataset, const IndexParams& params, const Distance& distance); |
---|
50 | |
---|
51 | |
---|
52 | struct AutotunedIndexParams : public IndexParams |
---|
53 | { |
---|
54 | AutotunedIndexParams(float target_precision = 0.8, float build_weight = 0.01, float memory_weight = 0, float sample_fraction = 0.1) |
---|
55 | { |
---|
56 | (*this)["algorithm"] = FLANN_INDEX_AUTOTUNED; |
---|
57 | // precision desired (used for autotuning, -1 otherwise) |
---|
58 | (*this)["target_precision"] = target_precision; |
---|
59 | // build tree time weighting factor |
---|
60 | (*this)["build_weight"] = build_weight; |
---|
61 | // index memory weighting factor |
---|
62 | (*this)["memory_weight"] = memory_weight; |
---|
63 | // what fraction of the dataset to use for autotuning |
---|
64 | (*this)["sample_fraction"] = sample_fraction; |
---|
65 | } |
---|
66 | }; |
---|
67 | |
---|
68 | |
---|
69 | template <typename Distance> |
---|
70 | class AutotunedIndex : public NNIndex<Distance> |
---|
71 | { |
---|
72 | public: |
---|
73 | typedef typename Distance::ElementType ElementType; |
---|
74 | typedef typename Distance::ResultType DistanceType; |
---|
75 | |
---|
76 | typedef bool needs_kdtree_distance; |
---|
77 | |
---|
78 | AutotunedIndex(const Matrix<ElementType>& inputData, const IndexParams& params = AutotunedIndexParams(), Distance d = Distance()) : |
---|
79 | dataset_(inputData), distance_(d) |
---|
80 | { |
---|
81 | target_precision_ = get_param(params, "target_precision",0.8f); |
---|
82 | build_weight_ = get_param(params,"build_weight", 0.01f); |
---|
83 | memory_weight_ = get_param(params, "memory_weight", 0.0f); |
---|
84 | sample_fraction_ = get_param(params,"sample_fraction", 0.1f); |
---|
85 | bestIndex_ = NULL; |
---|
86 | } |
---|
87 | |
---|
88 | AutotunedIndex(const AutotunedIndex&); |
---|
89 | AutotunedIndex& operator=(const AutotunedIndex&); |
---|
90 | |
---|
91 | virtual ~AutotunedIndex() |
---|
92 | { |
---|
93 | if (bestIndex_ != NULL) { |
---|
94 | delete bestIndex_; |
---|
95 | bestIndex_ = NULL; |
---|
96 | } |
---|
97 | } |
---|
98 | |
---|
99 | /** |
---|
100 | * Method responsible with building the index. |
---|
101 | */ |
---|
102 | virtual void buildIndex() |
---|
103 | { |
---|
104 | bestParams_ = estimateBuildParams(); |
---|
105 | Logger::info("----------------------------------------------------\n"); |
---|
106 | Logger::info("Autotuned parameters:\n"); |
---|
107 | print_params(bestParams_); |
---|
108 | Logger::info("----------------------------------------------------\n"); |
---|
109 | |
---|
110 | bestIndex_ = create_index_by_type(dataset_, bestParams_, distance_); |
---|
111 | bestIndex_->buildIndex(); |
---|
112 | speedup_ = estimateSearchParams(bestSearchParams_); |
---|
113 | Logger::info("----------------------------------------------------\n"); |
---|
114 | Logger::info("Search parameters:\n"); |
---|
115 | print_params(bestSearchParams_); |
---|
116 | Logger::info("----------------------------------------------------\n"); |
---|
117 | } |
---|
118 | |
---|
119 | /** |
---|
120 | * Saves the index to a stream |
---|
121 | */ |
---|
122 | virtual void saveIndex(FILE* stream) |
---|
123 | { |
---|
124 | save_value(stream, (int)bestIndex_->getType()); |
---|
125 | bestIndex_->saveIndex(stream); |
---|
126 | save_value(stream, bestSearchParams_.checks); |
---|
127 | } |
---|
128 | |
---|
129 | /** |
---|
130 | * Loads the index from a stream |
---|
131 | */ |
---|
132 | virtual void loadIndex(FILE* stream) |
---|
133 | { |
---|
134 | int index_type; |
---|
135 | |
---|
136 | load_value(stream, index_type); |
---|
137 | IndexParams params; |
---|
138 | params["algorithm"] = (flann_algorithm_t)index_type; |
---|
139 | bestIndex_ = create_index_by_type<Distance>(dataset_, params, distance_); |
---|
140 | bestIndex_->loadIndex(stream); |
---|
141 | load_value(stream, bestSearchParams_.checks); |
---|
142 | } |
---|
143 | |
---|
144 | /** |
---|
145 | * Method that searches for nearest-neighbors |
---|
146 | */ |
---|
147 | virtual void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) |
---|
148 | { |
---|
149 | if (searchParams.checks == FLANN_CHECKS_AUTOTUNED) { |
---|
150 | bestIndex_->findNeighbors(result, vec, bestSearchParams_); |
---|
151 | } |
---|
152 | else { |
---|
153 | bestIndex_->findNeighbors(result, vec, searchParams); |
---|
154 | } |
---|
155 | } |
---|
156 | |
---|
157 | |
---|
158 | IndexParams getParameters() const |
---|
159 | { |
---|
160 | return bestIndex_->getParameters(); |
---|
161 | } |
---|
162 | |
---|
163 | SearchParams getSearchParameters() const |
---|
164 | { |
---|
165 | return bestSearchParams_; |
---|
166 | } |
---|
167 | |
---|
168 | float getSpeedup() const |
---|
169 | { |
---|
170 | return speedup_; |
---|
171 | } |
---|
172 | |
---|
173 | |
---|
174 | /** |
---|
175 | * Number of features in this index. |
---|
176 | */ |
---|
177 | virtual size_t size() const |
---|
178 | { |
---|
179 | return bestIndex_->size(); |
---|
180 | } |
---|
181 | |
---|
182 | /** |
---|
183 | * The length of each vector in this index. |
---|
184 | */ |
---|
185 | virtual size_t veclen() const |
---|
186 | { |
---|
187 | return bestIndex_->veclen(); |
---|
188 | } |
---|
189 | |
---|
190 | /** |
---|
191 | * The amount of memory (in bytes) this index uses. |
---|
192 | */ |
---|
193 | virtual int usedMemory() const |
---|
194 | { |
---|
195 | return bestIndex_->usedMemory(); |
---|
196 | } |
---|
197 | |
---|
198 | /** |
---|
199 | * Algorithm name |
---|
200 | */ |
---|
201 | virtual flann_algorithm_t getType() const |
---|
202 | { |
---|
203 | return FLANN_INDEX_AUTOTUNED; |
---|
204 | } |
---|
205 | |
---|
206 | private: |
---|
207 | |
---|
208 | struct CostData |
---|
209 | { |
---|
210 | float searchTimeCost; |
---|
211 | float buildTimeCost; |
---|
212 | float memoryCost; |
---|
213 | float totalCost; |
---|
214 | IndexParams params; |
---|
215 | }; |
---|
216 | |
---|
217 | void evaluate_kmeans(CostData& cost) |
---|
218 | { |
---|
219 | StartStopTimer t; |
---|
220 | int checks; |
---|
221 | const int nn = 1; |
---|
222 | |
---|
223 | Logger::info("KMeansTree using params: max_iterations=%d, branching=%d\n", |
---|
224 | get_param<int>(cost.params,"iterations"), |
---|
225 | get_param<int>(cost.params,"branching")); |
---|
226 | KMeansIndex<Distance> kmeans(sampledDataset_, cost.params, distance_); |
---|
227 | // measure index build time |
---|
228 | t.start(); |
---|
229 | kmeans.buildIndex(); |
---|
230 | t.stop(); |
---|
231 | float buildTime = (float)t.value; |
---|
232 | |
---|
233 | // measure search time |
---|
234 | float searchTime = test_index_precision(kmeans, sampledDataset_, testDataset_, gt_matches_, target_precision_, checks, distance_, nn); |
---|
235 | |
---|
236 | float datasetMemory = float(sampledDataset_.rows * sampledDataset_.cols * sizeof(float)); |
---|
237 | cost.memoryCost = (kmeans.usedMemory() + datasetMemory) / datasetMemory; |
---|
238 | cost.searchTimeCost = searchTime; |
---|
239 | cost.buildTimeCost = buildTime; |
---|
240 | Logger::info("KMeansTree buildTime=%g, searchTime=%g, build_weight=%g\n", buildTime, searchTime, build_weight_); |
---|
241 | } |
---|
242 | |
---|
243 | |
---|
244 | void evaluate_kdtree(CostData& cost) |
---|
245 | { |
---|
246 | StartStopTimer t; |
---|
247 | int checks; |
---|
248 | const int nn = 1; |
---|
249 | |
---|
250 | Logger::info("KDTree using params: trees=%d\n", get_param<int>(cost.params,"trees")); |
---|
251 | KDTreeIndex<Distance> kdtree(sampledDataset_, cost.params, distance_); |
---|
252 | |
---|
253 | t.start(); |
---|
254 | kdtree.buildIndex(); |
---|
255 | t.stop(); |
---|
256 | float buildTime = (float)t.value; |
---|
257 | |
---|
258 | //measure search time |
---|
259 | float searchTime = test_index_precision(kdtree, sampledDataset_, testDataset_, gt_matches_, target_precision_, checks, distance_, nn); |
---|
260 | |
---|
261 | float datasetMemory = float(sampledDataset_.rows * sampledDataset_.cols * sizeof(float)); |
---|
262 | cost.memoryCost = (kdtree.usedMemory() + datasetMemory) / datasetMemory; |
---|
263 | cost.searchTimeCost = searchTime; |
---|
264 | cost.buildTimeCost = buildTime; |
---|
265 | Logger::info("KDTree buildTime=%g, searchTime=%g\n", buildTime, searchTime); |
---|
266 | } |
---|
267 | |
---|
268 | |
---|
269 | // struct KMeansSimpleDownhillFunctor { |
---|
270 | // |
---|
271 | // Autotune& autotuner; |
---|
272 | // KMeansSimpleDownhillFunctor(Autotune& autotuner_) : autotuner(autotuner_) {}; |
---|
273 | // |
---|
274 | // float operator()(int* params) { |
---|
275 | // |
---|
276 | // float maxFloat = numeric_limits<float>::max(); |
---|
277 | // |
---|
278 | // if (params[0]<2) return maxFloat; |
---|
279 | // if (params[1]<0) return maxFloat; |
---|
280 | // |
---|
281 | // CostData c; |
---|
282 | // c.params["algorithm"] = KMEANS; |
---|
283 | // c.params["centers-init"] = CENTERS_RANDOM; |
---|
284 | // c.params["branching"] = params[0]; |
---|
285 | // c.params["max-iterations"] = params[1]; |
---|
286 | // |
---|
287 | // autotuner.evaluate_kmeans(c); |
---|
288 | // |
---|
289 | // return c.timeCost; |
---|
290 | // |
---|
291 | // } |
---|
292 | // }; |
---|
293 | // |
---|
294 | // struct KDTreeSimpleDownhillFunctor { |
---|
295 | // |
---|
296 | // Autotune& autotuner; |
---|
297 | // KDTreeSimpleDownhillFunctor(Autotune& autotuner_) : autotuner(autotuner_) {}; |
---|
298 | // |
---|
299 | // float operator()(int* params) { |
---|
300 | // float maxFloat = numeric_limits<float>::max(); |
---|
301 | // |
---|
302 | // if (params[0]<1) return maxFloat; |
---|
303 | // |
---|
304 | // CostData c; |
---|
305 | // c.params["algorithm"] = KDTREE; |
---|
306 | // c.params["trees"] = params[0]; |
---|
307 | // |
---|
308 | // autotuner.evaluate_kdtree(c); |
---|
309 | // |
---|
310 | // return c.timeCost; |
---|
311 | // |
---|
312 | // } |
---|
313 | // }; |
---|
314 | |
---|
315 | |
---|
316 | |
---|
317 | void optimizeKMeans(std::vector<CostData>& costs) |
---|
318 | { |
---|
319 | Logger::info("KMEANS, Step 1: Exploring parameter space\n"); |
---|
320 | |
---|
321 | // explore kmeans parameters space using combinations of the parameters below |
---|
322 | int maxIterations[] = { 1, 5, 10, 15 }; |
---|
323 | int branchingFactors[] = { 16, 32, 64, 128, 256 }; |
---|
324 | |
---|
325 | int kmeansParamSpaceSize = FLANN_ARRAY_LEN(maxIterations) * FLANN_ARRAY_LEN(branchingFactors); |
---|
326 | costs.reserve(costs.size() + kmeansParamSpaceSize); |
---|
327 | |
---|
328 | // evaluate kmeans for all parameter combinations |
---|
329 | for (size_t i = 0; i < FLANN_ARRAY_LEN(maxIterations); ++i) { |
---|
330 | for (size_t j = 0; j < FLANN_ARRAY_LEN(branchingFactors); ++j) { |
---|
331 | CostData cost; |
---|
332 | cost.params["algorithm"] = FLANN_INDEX_KMEANS; |
---|
333 | cost.params["centers_init"] = FLANN_CENTERS_RANDOM; |
---|
334 | cost.params["iterations"] = maxIterations[i]; |
---|
335 | cost.params["branching"] = branchingFactors[j]; |
---|
336 | |
---|
337 | evaluate_kmeans(cost); |
---|
338 | costs.push_back(cost); |
---|
339 | } |
---|
340 | } |
---|
341 | |
---|
342 | // Logger::info("KMEANS, Step 2: simplex-downhill optimization\n"); |
---|
343 | // |
---|
344 | // const int n = 2; |
---|
345 | // // choose initial simplex points as the best parameters so far |
---|
346 | // int kmeansNMPoints[n*(n+1)]; |
---|
347 | // float kmeansVals[n+1]; |
---|
348 | // for (int i=0;i<n+1;++i) { |
---|
349 | // kmeansNMPoints[i*n] = (int)kmeansCosts[i].params["branching"]; |
---|
350 | // kmeansNMPoints[i*n+1] = (int)kmeansCosts[i].params["max-iterations"]; |
---|
351 | // kmeansVals[i] = kmeansCosts[i].timeCost; |
---|
352 | // } |
---|
353 | // KMeansSimpleDownhillFunctor kmeans_cost_func(*this); |
---|
354 | // // run optimization |
---|
355 | // optimizeSimplexDownhill(kmeansNMPoints,n,kmeans_cost_func,kmeansVals); |
---|
356 | // // store results |
---|
357 | // for (int i=0;i<n+1;++i) { |
---|
358 | // kmeansCosts[i].params["branching"] = kmeansNMPoints[i*2]; |
---|
359 | // kmeansCosts[i].params["max-iterations"] = kmeansNMPoints[i*2+1]; |
---|
360 | // kmeansCosts[i].timeCost = kmeansVals[i]; |
---|
361 | // } |
---|
362 | } |
---|
363 | |
---|
364 | |
---|
365 | void optimizeKDTree(std::vector<CostData>& costs) |
---|
366 | { |
---|
367 | Logger::info("KD-TREE, Step 1: Exploring parameter space\n"); |
---|
368 | |
---|
369 | // explore kd-tree parameters space using the parameters below |
---|
370 | int testTrees[] = { 1, 4, 8, 16, 32 }; |
---|
371 | |
---|
372 | // evaluate kdtree for all parameter combinations |
---|
373 | for (size_t i = 0; i < FLANN_ARRAY_LEN(testTrees); ++i) { |
---|
374 | CostData cost; |
---|
375 | cost.params["trees"] = testTrees[i]; |
---|
376 | |
---|
377 | evaluate_kdtree(cost); |
---|
378 | costs.push_back(cost); |
---|
379 | } |
---|
380 | |
---|
381 | // Logger::info("KD-TREE, Step 2: simplex-downhill optimization\n"); |
---|
382 | // |
---|
383 | // const int n = 1; |
---|
384 | // // choose initial simplex points as the best parameters so far |
---|
385 | // int kdtreeNMPoints[n*(n+1)]; |
---|
386 | // float kdtreeVals[n+1]; |
---|
387 | // for (int i=0;i<n+1;++i) { |
---|
388 | // kdtreeNMPoints[i] = (int)kdtreeCosts[i].params["trees"]; |
---|
389 | // kdtreeVals[i] = kdtreeCosts[i].timeCost; |
---|
390 | // } |
---|
391 | // KDTreeSimpleDownhillFunctor kdtree_cost_func(*this); |
---|
392 | // // run optimization |
---|
393 | // optimizeSimplexDownhill(kdtreeNMPoints,n,kdtree_cost_func,kdtreeVals); |
---|
394 | // // store results |
---|
395 | // for (int i=0;i<n+1;++i) { |
---|
396 | // kdtreeCosts[i].params["trees"] = kdtreeNMPoints[i]; |
---|
397 | // kdtreeCosts[i].timeCost = kdtreeVals[i]; |
---|
398 | // } |
---|
399 | } |
---|
400 | |
---|
401 | /** |
---|
402 | * Chooses the best nearest-neighbor algorithm and estimates the optimal |
---|
403 | * parameters to use when building the index (for a given precision). |
---|
404 | * Returns a dictionary with the optimal parameters. |
---|
405 | */ |
---|
406 | IndexParams estimateBuildParams() |
---|
407 | { |
---|
408 | std::vector<CostData> costs; |
---|
409 | |
---|
410 | int sampleSize = int(sample_fraction_ * dataset_.rows); |
---|
411 | int testSampleSize = std::min(sampleSize / 10, 1000); |
---|
412 | |
---|
413 | Logger::info("Entering autotuning, dataset size: %d, sampleSize: %d, testSampleSize: %d, target precision: %g\n", dataset_.rows, sampleSize, testSampleSize, target_precision_); |
---|
414 | |
---|
415 | // For a very small dataset, it makes no sense to build any fancy index, just |
---|
416 | // use linear search |
---|
417 | if (testSampleSize < 10) { |
---|
418 | Logger::info("Choosing linear, dataset too small\n"); |
---|
419 | return LinearIndexParams(); |
---|
420 | } |
---|
421 | |
---|
422 | // We use a fraction of the original dataset to speedup the autotune algorithm |
---|
423 | sampledDataset_ = random_sample(dataset_, sampleSize); |
---|
424 | // We use a cross-validation approach, first we sample a testset from the dataset |
---|
425 | testDataset_ = random_sample(sampledDataset_, testSampleSize, true); |
---|
426 | |
---|
427 | // We compute the ground truth using linear search |
---|
428 | Logger::info("Computing ground truth... \n"); |
---|
429 | gt_matches_ = Matrix<int>(new int[testDataset_.rows], testDataset_.rows, 1); |
---|
430 | StartStopTimer t; |
---|
431 | t.start(); |
---|
432 | compute_ground_truth<Distance>(sampledDataset_, testDataset_, gt_matches_, 0, distance_); |
---|
433 | t.stop(); |
---|
434 | |
---|
435 | CostData linear_cost; |
---|
436 | linear_cost.searchTimeCost = (float)t.value; |
---|
437 | linear_cost.buildTimeCost = 0; |
---|
438 | linear_cost.memoryCost = 0; |
---|
439 | linear_cost.params["algorithm"] = FLANN_INDEX_LINEAR; |
---|
440 | |
---|
441 | costs.push_back(linear_cost); |
---|
442 | |
---|
443 | // Start parameter autotune process |
---|
444 | Logger::info("Autotuning parameters...\n"); |
---|
445 | |
---|
446 | optimizeKMeans(costs); |
---|
447 | optimizeKDTree(costs); |
---|
448 | |
---|
449 | float bestTimeCost = costs[0].searchTimeCost; |
---|
450 | for (size_t i = 0; i < costs.size(); ++i) { |
---|
451 | float timeCost = costs[i].buildTimeCost * build_weight_ + costs[i].searchTimeCost; |
---|
452 | if (timeCost < bestTimeCost) { |
---|
453 | bestTimeCost = timeCost; |
---|
454 | } |
---|
455 | } |
---|
456 | |
---|
457 | float bestCost = costs[0].searchTimeCost / bestTimeCost; |
---|
458 | IndexParams bestParams = costs[0].params; |
---|
459 | if (bestTimeCost > 0) { |
---|
460 | for (size_t i = 0; i < costs.size(); ++i) { |
---|
461 | float crtCost = (costs[i].buildTimeCost * build_weight_ + costs[i].searchTimeCost) / bestTimeCost + |
---|
462 | memory_weight_ * costs[i].memoryCost; |
---|
463 | if (crtCost < bestCost) { |
---|
464 | bestCost = crtCost; |
---|
465 | bestParams = costs[i].params; |
---|
466 | } |
---|
467 | } |
---|
468 | } |
---|
469 | |
---|
470 | delete[] gt_matches_.ptr(); |
---|
471 | delete[] testDataset_.ptr(); |
---|
472 | delete[] sampledDataset_.ptr(); |
---|
473 | |
---|
474 | return bestParams; |
---|
475 | } |
---|
476 | |
---|
477 | |
---|
478 | |
---|
479 | /** |
---|
480 | * Estimates the search time parameters needed to get the desired precision. |
---|
481 | * Precondition: the index is built |
---|
482 | * Postcondition: the searchParams will have the optimum params set, also the speedup obtained over linear search. |
---|
483 | */ |
---|
484 | float estimateSearchParams(SearchParams& searchParams) |
---|
485 | { |
---|
486 | const int nn = 1; |
---|
487 | const size_t SAMPLE_COUNT = 1000; |
---|
488 | |
---|
489 | assert(bestIndex_ != NULL); // must have a valid index |
---|
490 | |
---|
491 | float speedup = 0; |
---|
492 | |
---|
493 | int samples = (int)std::min(dataset_.rows / 10, SAMPLE_COUNT); |
---|
494 | if (samples > 0) { |
---|
495 | Matrix<ElementType> testDataset = random_sample(dataset_, samples); |
---|
496 | |
---|
497 | Logger::info("Computing ground truth\n"); |
---|
498 | |
---|
499 | // we need to compute the ground truth first |
---|
500 | Matrix<int> gt_matches(new int[testDataset.rows], testDataset.rows, 1); |
---|
501 | StartStopTimer t; |
---|
502 | t.start(); |
---|
503 | compute_ground_truth<Distance>(dataset_, testDataset, gt_matches, 1, distance_); |
---|
504 | t.stop(); |
---|
505 | float linear = (float)t.value; |
---|
506 | |
---|
507 | int checks; |
---|
508 | Logger::info("Estimating number of checks\n"); |
---|
509 | |
---|
510 | float searchTime; |
---|
511 | float cb_index; |
---|
512 | if (bestIndex_->getType() == FLANN_INDEX_KMEANS) { |
---|
513 | Logger::info("KMeans algorithm, estimating cluster border factor\n"); |
---|
514 | KMeansIndex<Distance>* kmeans = (KMeansIndex<Distance>*)bestIndex_; |
---|
515 | float bestSearchTime = -1; |
---|
516 | float best_cb_index = -1; |
---|
517 | int best_checks = -1; |
---|
518 | for (cb_index = 0; cb_index < 1.1f; cb_index += 0.2f) { |
---|
519 | kmeans->set_cb_index(cb_index); |
---|
520 | searchTime = test_index_precision(*kmeans, dataset_, testDataset, gt_matches, target_precision_, checks, distance_, nn, 1); |
---|
521 | if ((searchTime < bestSearchTime) || (bestSearchTime == -1)) { |
---|
522 | bestSearchTime = searchTime; |
---|
523 | best_cb_index = cb_index; |
---|
524 | best_checks = checks; |
---|
525 | } |
---|
526 | } |
---|
527 | searchTime = bestSearchTime; |
---|
528 | cb_index = best_cb_index; |
---|
529 | checks = best_checks; |
---|
530 | |
---|
531 | kmeans->set_cb_index(best_cb_index); |
---|
532 | Logger::info("Optimum cb_index: %g\n", cb_index); |
---|
533 | bestParams_["cb_index"] = cb_index; |
---|
534 | } |
---|
535 | else { |
---|
536 | searchTime = test_index_precision(*bestIndex_, dataset_, testDataset, gt_matches, target_precision_, checks, distance_, nn, 1); |
---|
537 | } |
---|
538 | |
---|
539 | Logger::info("Required number of checks: %d \n", checks); |
---|
540 | searchParams.checks = checks; |
---|
541 | |
---|
542 | speedup = linear / searchTime; |
---|
543 | |
---|
544 | delete[] gt_matches.ptr(); |
---|
545 | delete[] testDataset.ptr(); |
---|
546 | } |
---|
547 | |
---|
548 | return speedup; |
---|
549 | } |
---|
550 | |
---|
551 | private: |
---|
552 | NNIndex<Distance>* bestIndex_; |
---|
553 | |
---|
554 | IndexParams bestParams_; |
---|
555 | SearchParams bestSearchParams_; |
---|
556 | |
---|
557 | Matrix<ElementType> sampledDataset_; |
---|
558 | Matrix<ElementType> testDataset_; |
---|
559 | Matrix<int> gt_matches_; |
---|
560 | |
---|
561 | float speedup_; |
---|
562 | |
---|
563 | /** |
---|
564 | * The dataset used by this index |
---|
565 | */ |
---|
566 | const Matrix<ElementType> dataset_; |
---|
567 | |
---|
568 | /** |
---|
569 | * Index parameters |
---|
570 | */ |
---|
571 | float target_precision_; |
---|
572 | float build_weight_; |
---|
573 | float memory_weight_; |
---|
574 | float sample_fraction_; |
---|
575 | |
---|
576 | Distance distance_; |
---|
577 | |
---|
578 | |
---|
579 | }; |
---|
580 | } |
---|
581 | |
---|
582 | #endif /* FLANN_AUTOTUNED_INDEX_H_ */ |
---|