1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Diagnostics;
|
---|
4 | using System.Linq;
|
---|
5 | using System.Runtime.InteropServices;
|
---|
6 | using System.Text;
|
---|
7 | using System.Threading.Tasks;
|
---|
8 | using HeuristicLab.Common;
|
---|
9 | using HeuristicLab.Core;
|
---|
10 | using HeuristicLab.Problems.DataAnalysis;
|
---|
11 |
|
---|
12 | namespace HeuristicLab.Algorithms.DataAnalysis.Experimental {
|
---|
13 | public static class SBART {
|
---|
14 | /*
|
---|
15 | # A Cubic B-spline Smoothing routine.
|
---|
16 |
|
---|
17 | #
|
---|
18 | # The algorithm minimises:
|
---|
19 | #
|
---|
20 | # (1/n) * sum ws(i)**2 * (ys(i)-sz(i))**2 + lambda* int ( sz"(xs) )**2 dxs
|
---|
21 | #
|
---|
22 | # lambda is a function of the spar which is assumed to be between
|
---|
23 | # 0 and 1
|
---|
24 |
|
---|
25 |
|
---|
26 | # Input
|
---|
27 |
|
---|
28 | # n number of data points
|
---|
29 | # ys(n) vector of length n containing the observations
|
---|
30 | # ws(n) vector containing the weights given to each data point
|
---|
31 | # xs(n) vector containing the abscissae (x-values) of the observations
|
---|
32 |
|
---|
33 |
|
---|
34 | # nk number of b-spline coefficients to be estimated
|
---|
35 | # nk <= n+2
|
---|
36 | # knot(nk+4) vector of knot points defining the cubic b-spline basis.
|
---|
37 |
|
---|
38 |
|
---|
39 | # spar penalised likelihood smoothing parameter
|
---|
40 | # ispar indicator saying if spar is supplied or to be estimated
|
---|
41 | # lspar, uspar lower and upper values for spar 0.,1. are good values
|
---|
42 | # tol used in Golden Search routine
|
---|
43 |
|
---|
44 | # isetup setup indicator
|
---|
45 |
|
---|
46 | # icrit indicator saying which cross validation score
|
---|
47 | # is to be computed
|
---|
48 |
|
---|
49 | # ld4 the leading dimension of abd (ie ld4=4)
|
---|
50 | # ldnk the leading dimension of p2ip (not referenced)
|
---|
51 |
|
---|
52 |
|
---|
53 | # Output
|
---|
54 |
|
---|
55 | # coef(nk) vector of spline coefficients
|
---|
56 | # sz(n) vector of smoothed z-values
|
---|
57 | # lev(n) vector of leverages
|
---|
58 | # crit either ordinary or generalized CV score
|
---|
59 | # ier error indicator
|
---|
60 | # ier = 0 ___ everything fine
|
---|
61 | # ier = 1 ___ spar too small or too big
|
---|
62 | # or a problem in the Cholesky decomposition
|
---|
63 |
|
---|
64 |
|
---|
65 |
|
---|
66 | # Working arrays/matrix
|
---|
67 | # xwy X'Wy
|
---|
68 | # hs0,hs1,hs2,hs3 the diagonals of the X'WX matrix
|
---|
69 | # sg0,sg1,sg2,sg3 the diagonals of the Gram matrix
|
---|
70 | # abd(ld4,nk) [ X'WX+lambda*SIGMA] in diagonal form
|
---|
71 | # p1ip(ld4,nk) inner products between columns of L inverse
|
---|
72 | # p2ip(ldnk,nk) all inner products between columns of L inverse
|
---|
73 | # L'L = [X'WX+lambdaSIGMA] NOT REFERENCED
|
---|
74 |
|
---|
75 | */
|
---|
76 |
|
---|
77 | /*
|
---|
78 | * sbart(xs,ys,ws,n,knot,nk,
|
---|
79 | coef,sz,lev,
|
---|
80 | crit,icrit,spar,ispar,lspar,uspar,tol,
|
---|
81 | isetup,
|
---|
82 | xwy,
|
---|
83 | hs0,hs1,hs2,hs3,
|
---|
84 | sg0,sg1,sg2,sg3,
|
---|
85 | abd,p1ip,p2ip,ld4,ldnk,ier)
|
---|
86 |
|
---|
87 | */
|
---|
88 |
|
---|
89 |
|
---|
90 |
|
---|
91 | // To build the fortran library (x64) use:
|
---|
92 | // > ifort /dll /Qm64 /libs:static /winapp sbart.f interv.f bsplvb.f spbfa.f spbsl.f bvalue.f scopy.f ssort.f sdot.f saxpy.f bsplvd.f /Fesbart_x64.dll
|
---|
93 | // check dumpbin /EXPORTS sbart_x64.dll
|
---|
94 | // and dumpbin /IMPORTS sbart_x64.dll
|
---|
// P/Invoke bindings for the Fortran routines compiled into sbart_x64.dll (cdecl calling convention).
// Scalar arguments are declared `ref` because Fortran receives every argument by address.
// NOTE(review): float/int assume 4-byte REAL and INTEGER in the Fortran build - confirm the compiler flags.

/// <summary>
/// Fortran <c>sbart</c>: cubic B-spline smoothing of (xs, ys) with weights ws for the knot sequence
/// <paramref name="knot"/> (nk + 4 entries). Outputs are the spline coefficients, the smoothed values,
/// the leverages, the requested criterion and the error indicator ier.
/// </summary>
[DllImport("sbart_x64.dll", EntryPoint = "sbart", CallingConvention = CallingConvention.Cdecl)]
public static extern void sbart_x64(
  float[] xs, float[] ys, float[] ws, ref int n,
  float[] knot, ref int nk,
  float[] coeff, float[] sz, float[] lev,
  ref float crit, ref int icrit,
  ref float spar, ref int ispar, ref float lspar, ref float uspar,
  ref float tol, ref int isetup,
  float[] xwy,
  float[] hs0, float[] hs1, float[] hs2, float[] hs3,
  float[] sg0, float[] sg1, float[] sg2, float[] sg3,
  float[,] abd, float[,] p1ip, float[,] p2ip,
  ref int ld4, ref int ldnk, ref int ier);

/// <summary>
/// Placeholder binding for the 32-bit library.
/// WARNING: it binds the same "sbart" entry point as <see cref="sbart_x64"/> but declares no parameters,
/// so it must not be called as is. No caller exists in the code visible here.
/// </summary>
[DllImport("sbart_x86.dll", EntryPoint = "sbart", CallingConvention = CallingConvention.Cdecl)]
public static extern void sbart_x86();

/// <summary>
/// Fortran <c>sknotl</c>: presumably chooses a knot sequence for the n sorted values in x and returns
/// the count in k. NOTE(review): confirm against sbart.f.
/// </summary>
[DllImport("sbart_x64.dll", EntryPoint = "sknotl", CallingConvention = CallingConvention.Cdecl)]
public static extern void sknotl_x64(float[] x, ref int n, float[] knot, ref int k);

/// <summary>
/// Fortran <c>setreg</c>: sorts x together with y and w, combines rows with duplicate x values,
/// transforms x to [0, 1] (returning min and range) and writes a heuristic knot set to knot / nk.
/// nx receives the number of rows left after combining duplicates.
/// </summary>
[DllImport("sbart_x64.dll", EntryPoint = "setreg", CallingConvention = CallingConvention.Cdecl)]
public static extern void setreg_x64(
  float[] x, float[] y, float[] w, ref int n,
  float[] xw, ref int nx,
  ref float min, ref float range,
  float[] knot, ref int nk);
|
---|
114 | /*
|
---|
115 | * calculates value at x of jderiv-th derivative of spline from b-repr.
|
---|
116 | c the spline is taken to be continuous from the right, EXCEPT at the
|
---|
117 | c rightmost knot, where it is taken to be continuous from the left.
|
---|
118 | c
|
---|
119 | c****** i n p u t ******
|
---|
120 | c t, bcoef, n, k......forms the b-representation of the spline f to
|
---|
121 | c be evaluated. specifically,
|
---|
122 | c t.....knot sequence, of length n+k, assumed nondecreasing.
|
---|
123 | c bcoef.....b-coefficient sequence, of length n .
|
---|
124 | c n.....length of bcoef and dimension of spline(k,t),
|
---|
125 | c a s s u m e d positive .
|
---|
126 | c k.....order of the spline .
|
---|
127 | c
|
---|
128 | c w a r n i n g . . . the restriction k .le. kmax (=20) is imposed
|
---|
129 | c arbitrarily by the dimension statement for aj, dl, dr below,
|
---|
130 | c but is n o w h e r e c h e c k e d for.
|
---|
131 | c
|
---|
132 | c x.....the point at which to evaluate .
|
---|
133 | c jderiv.....integer giving the order of the derivative to be evaluated
|
---|
134 | c a s s u m e d to be zero or positive.
|
---|
135 | c
|
---|
136 | c****** o u t p u t ******
|
---|
137 | c bvalue.....the value of the (jderiv)-th derivative of f at x .
|
---|
138 | c
|
---|
139 | c****** m e t h o d ******
|
---|
140 | c The nontrivial knot interval (t(i),t(i+1)) containing x is lo-
|
---|
141 | c cated with the aid of interv . The k b-coeffs of f relevant for
|
---|
142 | c this interval are then obtained from bcoef (or taken to be zero if
|
---|
143 | c not explicitly available) and are then differenced jderiv times to
|
---|
144 | c obtain the b-coeffs of (d**jderiv)f relevant for that interval.
|
---|
145 | c Precisely, with j = jderiv, we have from x.(12) of the text that
|
---|
146 | c
|
---|
147 | c (d**j)f = sum ( bcoef(.,j)*b(.,k-j,t) )
|
---|
148 | c
|
---|
149 | c where
|
---|
150 | c / bcoef(.), , j .eq. 0
|
---|
151 | c /
|
---|
152 | c bcoef(.,j) = / bcoef(.,j-1) - bcoef(.-1,j-1)
|
---|
153 | c / ----------------------------- , j .gt. 0
|
---|
154 | c / (t(.+k-j) - t(.))/(k-j)
|
---|
155 | c
|
---|
156 | c Then, we use repeatedly the fact that
|
---|
157 | c
|
---|
158 | c sum ( a(.)*b(.,m,t)(x) ) = sum ( a(.,x)*b(.,m-1,t)(x) )
|
---|
159 | c with
|
---|
160 | c (x - t(.))*a(.) + (t(.+m-1) - x)*a(.-1)
|
---|
161 | c a(.,x) = ---------------------------------------
|
---|
162 | c (x - t(.)) + (t(.+m-1) - x)
|
---|
163 | c
|
---|
164 | c to write (d**j)f(x) eventually as a linear combination of b-splines
|
---|
165 | c of order 1 , and the coefficient for b(i,1,t)(x) must then be the
|
---|
166 | c desired number (d**j)f(x). (see x.(17)-(19) of text).
|
---|
167 | */
|
---|
/// <summary>
/// Fortran <c>bvalue</c> (de Boor): value at x of the jderiv-th derivative of the spline given in
/// B-representation by the knot sequence t (n + k entries), the coefficients bcoeff (n entries) and the order k.
/// </summary>
/// <remarks>
/// NOTE(review): de Boor's reference implementation keeps state between calls (SAVE), so this routine may
/// not be reentrant; callers should serialize access.
/// </remarks>
[DllImport("sbart_x64.dll", EntryPoint = "bvalue", CallingConvention = CallingConvention.Cdecl)]
public static extern float bvalue(
  float[] t, float[] bcoeff,
  ref int n, ref int k,
  ref float x, ref int jderiv);
|
---|
170 |
|
---|
/// <summary>Diagnostics returned alongside a fitted smoothing-spline model.</summary>
public class SBART_Report {
  /// <summary>Smoothing parameter (spar) reported by sbart.</summary>
  public double smoothingParameter;

  /// <summary>Criterion value reported by sbart (the callers in this file request GCV).</summary>
  public double gcv;

  /// <summary>Leverage per fitted observation; the constant-model fallback sets it to an empty array.</summary>
  public double[] leverage;
}
|
---|
176 |
|
---|
177 |
|
---|
/// <summary>
/// Fits a cubic smoothing spline to (x, y) with weights w using <paramref name="nKnots"/> interior knots.
/// The knots are seeded with k-means on x: one knot per cluster, placed at the median x of its members.
/// They are then passed on to the knot-array overload.
/// </summary>
/// <param name="x">Input values.</param>
/// <param name="y">Target values.</param>
/// <param name="w">Observation weights.</param>
/// <param name="nKnots">Number of interior knots (number of k-means clusters).</param>
/// <param name="targetVariable">Target variable name of the returned model.</param>
/// <param name="inputVars">Input variable names of the returned model.</param>
/// <param name="rep">Fit diagnostics (smoothing parameter, GCV, leverages).</param>
/// <returns>The fitted regression model.</returns>
/// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
/// <exception cref="ArgumentException">
/// k-means did not report success, e.g. because <paramref name="nKnots"/> is not a valid cluster count for x.
/// </exception>
public static IRegressionModel CalculateSBART(double[] x, double[] y, double[] w, int nKnots, string targetVariable, string[] inputVars, out SBART_Report rep) {
  if (x == null) throw new ArgumentNullException("x");

  // k-means works on a data matrix; x becomes its single column
  double[,] xy = new double[x.Length, 1];
  for (int i = 0; i < x.Length; i++) xy[i, 0] = x[i];

  double[,] centers; // required by the alglib signature, not used
  int[] clusterOfRow;
  int info;
  alglib.kmeansgenerate(xy, x.Length, 1, nKnots, 10, out info, out centers, out clusterOfRow);
  // alglib signals success with info == 1; otherwise the cluster assignment must not be used
  if (info != 1)
    throw new ArgumentException(string.Format("k-means knot placement failed (alglib info = {0}) for {1} knots.", info, nKnots), "nKnots");

  // one knot per non-empty cluster (the median x of its members), in ascending order;
  // no zero-padding, so no artificial knots are introduced
  double[] knots = x.Zip(clusterOfRow, (double xi, int ci) => Tuple.Create(xi, ci))
    .GroupBy(p => p.Item2)
    .Select(gr => HeuristicLab.Common.EnumerableStatisticExtensions.Median(gr.Select(p => p.Item1)))
    .OrderBy(knot => knot)
    .ToArray();

  return CalculateSBART(x, y, w, knots, targetVariable, inputVars, out rep);
}
|
---|
192 |
|
---|
/// <summary>
/// Fits a cubic smoothing spline to (x, y) with weights w by calling the native sbart routine.
/// sbart estimates the smoothing parameter using the GCV criterion.
/// </summary>
/// <remarks>
/// Only <c>knots.Length</c> is used. The interior knots are placed at evenly spaced ranks of the x values
/// as prepared by setreg (sorted, duplicates combined, scaled to [0, 1]); the knot values themselves are
/// ignored. When sbart reports an error, the fit is retried with a doubled tolerance, capped at 1.0.
/// If the fit still fails at tolerance 1.0, an exception is thrown instead of retrying forever.
/// </remarks>
/// <param name="x">Input values.</param>
/// <param name="y">Target values (same length as x).</param>
/// <param name="w">Observation weights (same length as x).</param>
/// <param name="knots">Interior knots; only their count is used.</param>
/// <param name="targetVariable">Target variable name of the returned model.</param>
/// <param name="inputVars">Input variable names of the returned model.</param>
/// <param name="rep">Fit diagnostics (smoothing parameter, GCV, leverages).</param>
/// <returns>The fitted regression model.</returns>
/// <exception cref="ArgumentNullException">An array argument is null.</exception>
/// <exception cref="ArgumentException">
/// Fewer than 4 observations, more than n + 2 knots, mismatched array lengths,
/// or sbart failed for every tolerance tried.
/// </exception>
public static IRegressionModel CalculateSBART(double[] x, double[] y, double[] w, double[] knots, string targetVariable, string[] inputVars, out SBART_Report rep) {
  if (x == null) throw new ArgumentNullException("x");
  if (y == null) throw new ArgumentNullException("y");
  if (w == null) throw new ArgumentNullException("w");
  if (knots == null) throw new ArgumentNullException("knots");
  int nObs = x.Length;
  if (nObs < 4) throw new ArgumentException("n < 4");
  if (knots.Length > nObs + 2) throw new ArgumentException("more than n+2 knots");
  // setreg reads nObs entries from y and w; shorter arrays would be read out of bounds by native code
  if (y.Length != nObs || w.Length != nObs) throw new ArgumentException("x, y and w must have the same length");

  const float maxTol = 1.0f;
  float tol = 0.01f;
  int ier = 99;
  int tries = 0;

  while (true) {
    tries++;
    // setreg sorts, merges and rescales these buffers in place, so every attempt starts from fresh copies
    float[] xs = x.Select(xi => (float)xi).ToArray();
    float[] ys = y.Select(yi => (float)yi).ToArray();
    float[] ws = w.Select(wi => (float)wi).ToArray();

    int n = nObs;
    float[] xw = new float[n];
    int nx = -99;
    float min = 0.0f;
    float range = 0.0f;
    int nk = -99;
    // setreg writes its own knot set into this buffer (at least n + 6 entries, as before). The layout built
    // below needs 4 boundary knots on each side plus one entry per requested knot.
    float[] regKnots = new float[Math.Max(n + 6, knots.Length + 8)];

    // sort xs together with ys and ws, combine rows with duplicate x values, transform x to [0 .. 1]
    // and create a heuristic set of knots (overwritten below)
    SBART.setreg_x64(xs, ys, ws, ref n, xw, ref nx, ref min, ref range, regKnots, ref nk);

    // replace setreg's knots: fourfold boundary knots at 0 and 1, interior knots at evenly spaced
    // ranks of the nx prepared x values
    int pos = 0;
    for (int b = 0; b < 4; b++) regKnots[pos++] = 0.0f;
    for (int j = 1; j <= knots.Length; j++) regKnots[pos++] = xs[j * nx / (knots.Length + 1)];
    for (int b = 0; b < 4; b++) regKnots[pos++] = 1.0f;
    nk = pos - 4; // number of B-spline coefficients

    float criterion = -99.0f; // GCV
    int icrit = 1; // calculate GCV
    float smoothingParameter = -99.0f;
    int smoothingParameterIndicator = 0;
    float lowerSmoothingParameter = 0.0f;
    float upperSmoothingParameter = 1.0f;
    int isetup = 0;

    // results
    float[] coeff = new float[nk];
    float[] leverage = new float[nx];
    float[] ySmoothed = new float[nx];

    // working arrays for sbart; the 2-D arrays provide (at least) the element counts of the
    // Fortran declarations abd(ld4,nk), p1ip(ld4,nk) and p2ip(ldnk,nk)
    float[] xwy = new float[nk];
    float[] hs0 = new float[nk];
    float[] hs1 = new float[nk];
    float[] hs2 = new float[nk];
    float[] hs3 = new float[nk];
    float[] sg0 = new float[nk];
    float[] sg1 = new float[nk];
    float[] sg2 = new float[nk];
    float[] sg3 = new float[nk];
    int ld4 = 4;
    int ldnk = nk + 4;
    float[,] abd = new float[ld4, nk];
    float[,] p1ip = new float[nk, ld4];
    float[,] p2ip = new float[nk, ldnk];

    SBART.sbart_x64(xs.Take(nx).ToArray(), ys.Take(nx).ToArray(), ws.Take(nx).ToArray(), ref nx,
      regKnots, ref nk,
      coeff, ySmoothed, leverage,
      ref criterion, ref icrit,
      ref smoothingParameter, ref smoothingParameterIndicator, ref lowerSmoothingParameter, ref upperSmoothingParameter,
      ref tol, ref isetup,
      xwy, hs0, hs1, hs2, hs3, sg0, sg1, sg2, sg3, abd, p1ip, p2ip, ref ld4, ref ldnk, ref ier);

    if (ier <= 0) {
      if (tries > 1) {
        Console.WriteLine("Success {0} smooth {1} criterion {2}", ier, smoothingParameter, criterion);
      }
      rep = new SBART_Report();
      rep.gcv = criterion;
      rep.smoothingParameter = smoothingParameter;
      rep.leverage = leverage.Select(li => (double)li).ToArray();
      return new BartRegressionModel(regKnots.Take(nk + 4).ToArray(), coeff, targetVariable, inputVars, min, range);
    }

    Console.WriteLine("ERROR {0} smooth {1} criterion {2}", ier, smoothingParameter, criterion);
    // give up once the largest tolerance has failed (the negated test also stops on a NaN tolerance)
    if (!(tol < maxTol)) break;
    tol = Math.Min(tol * 2, maxTol);
  }

  throw new ArgumentException(string.Format("sbart failed with ier = {0} after {1} attempts (tol = {2}).", ier, tries, tol));
}
|
---|
300 |
|
---|
/// <summary>
/// Fits a cubic smoothing spline to (x, y) with unit weights.
/// Knot placement is left to the native setreg routine, and sbart estimates the smoothing parameter using GCV.
/// </summary>
/// <remarks>
/// If fewer than four rows remain after setreg combines duplicate x values, a ConstantModel is returned.
/// Its value is the average of the first nx entries of the y buffer after setreg.
/// </remarks>
/// <param name="x">Input values.</param>
/// <param name="y">Target values (same length as x).</param>
/// <param name="targetVariable">Target variable name of the returned model.</param>
/// <param name="inputVars">Input variable names of the returned model.</param>
/// <param name="report">Fit diagnostics (smoothing parameter, GCV, leverages).</param>
/// <returns>The fitted spline model, or a constant model for too few distinct x values.</returns>
/// <exception cref="ArgumentNullException">x or y is null.</exception>
/// <exception cref="NotSupportedException">The process is not 64-bit (only sbart_x64.dll is bound).</exception>
/// <exception cref="ArgumentException">x is empty, the lengths differ, or sbart reports an error.</exception>
public static IRegressionModel CalculateSBART(double[] x, double[] y,
  string targetVariable, string[] inputVars,
  out SBART_Report report) {
  if (x == null) throw new ArgumentNullException("x");
  if (y == null) throw new ArgumentNullException("y");
  if (!Environment.Is64BitProcess) throw new NotSupportedException("SBART requires a 64-bit process (sbart_x64.dll).");

  int n = x.Length;
  if (n == 0) throw new ArgumentException("x must not be empty", "x");
  // setreg reads n entries from y; a shorter array would be read out of bounds by native code
  if (y.Length != n) throw new ArgumentException("x and y must have the same length", "y");

  float[] xs = x.Select(xi => (float)xi).ToArray();
  float[] ys = y.Select(yi => (float)yi).ToArray();
  float[] ws = Enumerable.Repeat(1.0f, n).ToArray(); // unit weights

  float[] xw = new float[n];
  int nx = -99;
  float min = -99.0f;
  float range = -99.0f;
  int nk = n; // passed in as n; setreg returns the number of B-spline coefficients
  float[] knots = new float[nk + 6];

  // sort xs together with ys and ws, combine rows with duplicate x values
  // and create a set of knots (heuristic number of knots; denser regions of x contain more knots)
  SBART.setreg_x64(xs, ys, ws, ref n, xw, ref nx, ref min, ref range, knots, ref nk);

  // too few distinct x values for a cubic spline: fall back to a constant model
  if (nx < 4) {
    report = new SBART_Report();
    report.leverage = new double[0];
    return new ConstantModel(ys.Take(nx).Average(), targetVariable);
  }

  int ier = -99;
  float crit = -99.0f;
  int icrit = 1; // 0..don't calc CV, 1 .. GCV, 2 CV
  float smoothingParameter = -99.0f;
  int smoothingParameterIndicator = 0;
  float lowerSmoothingParameter = 0f;
  float upperSmoothingParameter = 1.0f;
  float tol = 0.02f;
  int isetup = 0;

  // results
  float[] coeff = new float[nk];
  float[] leverage = new float[nx];
  float[] ySmoothed = new float[nx];

  // working arrays for sbart; the 2-D arrays provide (at least) the element counts of the
  // Fortran declarations abd(ld4,nk), p1ip(ld4,nk) and p2ip(ldnk,nk)
  float[] xwy = new float[nk];
  float[] hs0 = new float[nk];
  float[] hs1 = new float[nk];
  float[] hs2 = new float[nk];
  float[] hs3 = new float[nk];
  float[] sg0 = new float[nk];
  float[] sg1 = new float[nk];
  float[] sg2 = new float[nk];
  float[] sg3 = new float[nk];
  int ld4 = 4;
  int ldnk = nk + 4;
  float[,] abd = new float[ld4, nk];
  float[,] p1ip = new float[nk, ld4];
  float[,] p2ip = new float[nk, ldnk];

  SBART.sbart_x64(xs.Take(nx).ToArray(), ys.Take(nx).ToArray(), ws.Take(nx).ToArray(), ref nx,
    knots, ref nk,
    coeff, ySmoothed, leverage,
    ref crit, ref icrit,
    ref smoothingParameter, ref smoothingParameterIndicator, ref lowerSmoothingParameter, ref upperSmoothingParameter,
    ref tol, ref isetup,
    xwy, hs0, hs1, hs2, hs3, sg0, sg1, sg2, sg3, abd, p1ip, p2ip, ref ld4, ref ldnk, ref ier);

  if (ier > 0) throw new ArgumentException(ier.ToString());

  report = new SBART_Report();
  report.gcv = crit;
  report.smoothingParameter = smoothingParameter;
  report.leverage = leverage.Select(li => (double)li).ToArray();

  return new BartRegressionModel(knots.Take(nk + 4).ToArray(), coeff, targetVariable, inputVars, min, range);
}
|
---|
415 |
|
---|
/// <summary>
/// Regression model for a single input variable. It evaluates a cubic B-spline (knots and coefficients
/// as produced by sbart) through the native bvalue routine.
/// </summary>
/// <remarks>
/// Inputs are mapped to the spline domain with (x - min) / range. Below 0 and above 1 the spline is
/// extended linearly, using its value and first derivative at the nearest boundary.
/// </remarks>
public class BartRegressionModel : NamedItem, IRegressionModel {
  // bvalue is one native routine shared by every instance. De Boor's reference implementation keeps state
  // between calls (SAVE), so evaluations are serialized process-wide rather than per instance.
  // NOTE(review): confirm against the bvalue.f compiled into sbart_x64.dll.
  private static readonly object bvalueLock = new object();

  private readonly float[] knots;
  private readonly float[] bcoeff;
  private readonly double min;
  private readonly double range;

  private string targetVariable;
  /// <summary>Name of the predicted variable; raises <see cref="TargetVariableChanged"/> when it changes.</summary>
  public string TargetVariable {
    get { return targetVariable; }
    set {
      if (targetVariable == value) return;
      targetVariable = value;
      OnTargetVariableChanged();
    }
  }

  private readonly string[] variablesUsedForPrediction;
  /// <summary>Input variables; only the first one is read by <see cref="GetEstimatedValues"/>.</summary>
  public IEnumerable<string> VariablesUsedForPrediction {
    get { return variablesUsedForPrediction; }
  }

  /// <summary>
  /// Cloning constructor. It chains to the NamedItem cloning constructor so base-class state is cloned too,
  /// then copies the spline data.
  /// </summary>
  public BartRegressionModel(BartRegressionModel orig, Cloner cloner)
    : base(orig, cloner) {
    this.knots = CloneArray(orig.knots);
    this.bcoeff = CloneArray(orig.bcoeff);
    this.min = orig.min;
    this.range = orig.range;
    this.targetVariable = orig.targetVariable;
    this.variablesUsedForPrediction = CloneArray(orig.variablesUsedForPrediction);
  }

  /// <param name="knots">Knot sequence of the spline (in the scaled [0, 1] domain).</param>
  /// <param name="bcoeff">B-spline coefficients.</param>
  /// <param name="targetVariable">Name of the predicted variable.</param>
  /// <param name="inputVars">Input variables; the first one is used for prediction.</param>
  /// <param name="min">Offset used to scale inputs.</param>
  /// <param name="range">Divisor used to scale inputs.</param>
  public BartRegressionModel(float[] knots, float[] bcoeff, string targetVariable, string[] inputVars, double min, double range) {
    this.variablesUsedForPrediction = inputVars;
    this.targetVariable = targetVariable;
    this.knots = knots;
    this.bcoeff = bcoeff;
    this.range = range;
    this.min = min;
  }

  public event EventHandler TargetVariableChanged;

  protected virtual void OnTargetVariableChanged() {
    var handler = TargetVariableChanged;
    if (handler != null) handler(this, EventArgs.Empty);
  }

  public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
  }

  /// <summary>Evaluates the spline at the unscaled input <paramref name="xx"/>.</summary>
  /// <returns>The spline value, a linear extrapolation outside [0, 1], or NaN for a NaN scaled input.</returns>
  public double GetEstimatedValue(double xx) {
    float x = (float)((xx - min) / range);
    // NaN (a missing input, or 0/0 for a zero range) cannot be located in the knot sequence
    if (float.IsNaN(x)) return double.NaN;

    int n = bcoeff.Length;
    int order = 4;  // cubic spline
    int deriv0 = 0; // jderiv for the function value
    int deriv1 = 1; // jderiv for the first derivative

    lock (bvalueLock) {
      // linear extrapolation to the left of the domain
      if (x < 0) {
        float x0 = 0.0f;
        var y0 = bvalue(knots, bcoeff, ref n, ref order, ref x0, ref deriv0);
        var y0d = bvalue(knots, bcoeff, ref n, ref order, ref x0, ref deriv1);
        return y0 + x * y0d;
      }
      // linear extrapolation to the right of the domain
      if (x > 1) {
        float x1 = 1.0f;
        var y1 = bvalue(knots, bcoeff, ref n, ref order, ref x1, ref deriv0);
        var y1d = bvalue(knots, bcoeff, ref n, ref order, ref x1, ref deriv1);
        return y1 + (x - 1) * y1d;
      }
      return bvalue(knots, bcoeff, ref n, ref order, ref x, ref deriv0);
    }
  }

  public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    string inputVariable = VariablesUsedForPrediction.First();
    foreach (var x in dataset.GetDoubleValues(inputVariable, rows)) {
      yield return GetEstimatedValue(x);
    }
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new BartRegressionModel(this, cloner);
  }

  // shallow copy of an array that tolerates null
  private static T[] CloneArray<T>(T[] source) {
    return source == null ? null : (T[])source.Clone();
  }
}
|
---|
531 |
|
---|
532 | }
|
---|
533 | }
|
---|