Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/ExpressionClustering/flann/include/flann/nn/simplex_downhill.h @ 15840

Last change on this file since 15840 was 15840, checked in by gkronber, 6 years ago

#2886 added utility console program for clustering of expressions

File size: 5.6 KB
Line 
1/***********************************************************************
2 * Software License Agreement (BSD License)
3 *
4 * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6 *
7 * THE BSD LICENSE
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
30
#ifndef FLANN_SIMPLEX_DOWNHILL_H_
#define FLANN_SIMPLEX_DOWNHILL_H_

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>
33
34namespace flann
35{
36
/**
    Stores a cost value and its parameter vector in slot `pos` of the arrays
    (vals, points), then bubbles the new entry towards the front until
    vals[0..pos] is sorted ascending again (entries after pos are untouched).

    @param pos    slot to overwrite; vals[0..pos-1] is expected to be sorted
                  ascending already
    @param val    cost value of the new point
    @param vals   cost values, one per row of points
    @param point  the n parameters of the new point; may be points+pos*n itself
    @param points row-major array of points, n parameters per row
    @param n      number of parameters per point
 */
template <typename T>
void addValue(int pos, float val, float* vals, T* point, T* points, int n)
{
    // swap() on the float entries below is a non-dependent call, so a swap
    // must be visible at this point; `using std::swap` provides it while still
    // letting argument-dependent lookup pick a custom swap for T.
    using std::swap;

    vals[pos] = val;
    for (int i=0; i<n; ++i) {
        points[pos*n+i] = point[i];
    }

    // bubble down: move the new entry left while it is strictly smaller than
    // its predecessor (strict comparison keeps existing entries first on ties)
    int j=pos;
    while (j>0 && vals[j]<vals[j-1]) {
        swap(vals[j],vals[j-1]);
        for (int i=0; i<n; ++i) {
            swap(points[j*n+i],points[(j-1)*n+i]);
        }
        --j;
    }
}
58
59
60/**
61    Simplex downhill optimization function.
62    Preconditions: points is a 2D mattrix of size (n+1) x n
63                    func is the cost function taking n an array of n params and returning float
64                    vals is the cost function in the n+1 simplex points, if NULL it will be computed
65
66    Postcondition: returns optimum value and points[0..n] are the optimum parameters
67 */
68template <typename T, typename F>
69float optimizeSimplexDownhill(T* points, int n, F func, float* vals = NULL )
70{
71    const int MAX_ITERATIONS = 10;
72
73    assert(n>0);
74
75    T* p_o = new T[n];
76    T* p_r = new T[n];
77    T* p_e = new T[n];
78
79    int alpha = 1;
80
81    int iterations = 0;
82
83    bool ownVals = false;
84    if (vals == NULL) {
85        ownVals = true;
86        vals = new float[n+1];
87        for (int i=0; i<n+1; ++i) {
88            float val = func(points+i*n);
89            addValue(i, val, vals, points+i*n, points, n);
90        }
91    }
92    int nn = n*n;
93
94    while (true) {
95
96        if (iterations++ > MAX_ITERATIONS) break;
97
98        // compute average of simplex points (except the highest point)
99        for (int j=0; j<n; ++j) {
100            p_o[j] = 0;
101            for (int i=0; i<n; ++i) {
102                p_o[i] += points[j*n+i];
103            }
104        }
105        for (int i=0; i<n; ++i) {
106            p_o[i] /= n;
107        }
108
109        bool converged = true;
110        for (int i=0; i<n; ++i) {
111            if (p_o[i] != points[nn+i]) {
112                converged = false;
113            }
114        }
115        if (converged) break;
116
117        // trying a reflection
118        for (int i=0; i<n; ++i) {
119            p_r[i] = p_o[i] + alpha*(p_o[i]-points[nn+i]);
120        }
121        float val_r = func(p_r);
122
123        if ((val_r>=vals[0])&&(val_r<vals[n])) {
124            // reflection between second highest and lowest
125            // add it to the simplex
126            Logger::info("Choosing reflection\n");
127            addValue(n, val_r,vals, p_r, points, n);
128            continue;
129        }
130
131        if (val_r<vals[0]) {
132            // value is smaller than smalest in simplex
133
134            // expand some more to see if it drops further
135            for (int i=0; i<n; ++i) {
136                p_e[i] = 2*p_r[i]-p_o[i];
137            }
138            float val_e = func(p_e);
139
140            if (val_e<val_r) {
141                Logger::info("Choosing reflection and expansion\n");
142                addValue(n, val_e,vals,p_e,points,n);
143            }
144            else {
145                Logger::info("Choosing reflection\n");
146                addValue(n, val_r,vals,p_r,points,n);
147            }
148            continue;
149        }
150        if (val_r>=vals[n]) {
151            for (int i=0; i<n; ++i) {
152                p_e[i] = (p_o[i]+points[nn+i])/2;
153            }
154            float val_e = func(p_e);
155
156            if (val_e<vals[n]) {
157                Logger::info("Choosing contraction\n");
158                addValue(n,val_e,vals,p_e,points,n);
159                continue;
160            }
161        }
162        {
163            Logger::info("Full contraction\n");
164            for (int j=1; j<=n; ++j) {
165                for (int i=0; i<n; ++i) {
166                    points[j*n+i] = (points[j*n+i]+points[i])/2;
167                }
168                float val = func(points+j*n);
169                addValue(j,val,vals,points+j*n,points,n);
170            }
171        }
172    }
173
174    float bestVal = vals[0];
175
176    delete[] p_r;
177    delete[] p_o;
178    delete[] p_e;
179    if (ownVals) delete[] vals;
180
181    return bestVal;
182}
183
184}
185
186#endif //FLANN_SIMPLEX_DOWNHILL_H_
Note: See TracBrowser for help on using the repository browser.