Ignore:
Timestamp:
11/09/17 18:03:06 (4 years ago)
Author:
gkronber
Message:

#2789 worked on sbart spline

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/SBART.cs

    r15459 r15468  
    // P/Invoke binding for the native "setreg" routine exported by sbart_x64.dll.
    // All scalar arguments are passed by reference; arrays are passed as float[].
    // According to the comments at the call sites in this file, the routine
    //  - sorts x together with y and w,
    //  - combines rows with duplicate x values,
    //  - transforms x to the range [0 .. 1] (reporting min and range), and
    //  - creates a set of knots located at data points, using a heuristic for the number of knots.
    // NOTE(review): which arguments are pure outputs (nx, min, range, knot, nk) is inferred from
    // how the callers use them after the call - confirm against the native source.
    [DllImport("sbart_x64.dll", CallingConvention = CallingConvention.Cdecl, EntryPoint = "setreg")]
    public static extern void setreg_x64(float[] x, float[] y, float[] w, ref int n, float[] xw, ref int nx, ref float min, ref float range, float[] knot, ref int nk);
    /*
     * P/Invoke binding for the native "bvalue" routine exported by sbart_x64.dll.
     * The description below is taken from the header of the native routine.
     *
     * Calculates the value at  x  of the  jderiv-th  derivative of a spline given in b-representation.
     * The spline is taken to be continuous from the right, EXCEPT at the
     * rightmost knot, where it is taken to be continuous from the left.
     *
     * ******  input  ******
     *   t, bcoef, n, k ... form the b-representation of the spline  f  to
     *                      be evaluated. Specifically,
     *   t ........ knot sequence, of length  n+k, assumed nondecreasing.
     *   bcoef .... b-coefficient sequence, of length  n
     *              (this is the parameter named bcoeff in the declaration below).
     *   n ........ length of  bcoef  and dimension of spline(k,t),
     *              assumed positive.
     *   k ........ order of the spline.
     *
     *   WARNING: the restriction  k <= kmax (=20)  is imposed
     *            arbitrarily by the dimension statement for  aj, dl, dr  in the native
     *            routine, but is NOWHERE CHECKED for.
     *
     *   x ........ the point at which to evaluate.
     *   jderiv ... integer giving the order of the derivative to be evaluated,
     *              assumed to be zero or positive.
     *
     * ******  output  ******
     *   bvalue ... the value of the (jderiv)-th derivative of  f  at  x.
     *
     * ******  method  ******
     *   The nontrivial knot interval  (t(i),t(i+1))  containing  x  is
     *   located with the aid of  interv. The  k  b-coeffs of  f  relevant for
     *   this interval are then obtained from  bcoef (or taken to be zero if
     *   not explicitly available) and are then differenced  jderiv  times to
     *   obtain the b-coeffs of  (d**jderiv)f  relevant for that interval.
     *   Precisely, with  j = jderiv, we have from x.(12) of the text that
     *
     *      (d**j)f  =  sum ( bcoef(.,j)*b(.,k-j,t) )
     *
     *   where
     *                    | bcoef(.),                     ,  j .eq. 0
     *                    |
     *     bcoef(.,j)  =  | bcoef(.,j-1) - bcoef(.-1,j-1)
     *                    | ----------------------------- ,  j .gt. 0
     *                    |    (t(.+k-j) - t(.))/(k-j)
     *
     *   Then, we use repeatedly the fact that
     *
     *     sum ( a(.)*b(.,m,t)(x) )  =  sum ( a(.,x)*b(.,m-1,t)(x) )
     *
     *   with
     *                  (x - t(.))*a(.) + (t(.+m-1) - x)*a(.-1)
     *     a(.,x)  =    ---------------------------------------
     *                  (x - t(.))      + (t(.+m-1) - x)
     *
     *   to write  (d**j)f(x)  eventually as a linear combination of b-splines
     *   of order  1, and the coefficient for  b(i,1,t)(x)  must then be the
     *   desired number  (d**j)f(x). (see x.(17)-(19) of text).
     *
     * All scalar arguments are passed by reference in the managed declaration.
     */
    [DllImport("sbart_x64.dll", CallingConvention = CallingConvention.Cdecl, EntryPoint = "bvalue")]
    public static extern float bvalue(float[] t, float[] bcoeff, ref int n, ref int k, ref float x, ref int jderiv);
    114170
    115171    public class SBART_Report {
     
    119175    }
    120176
     177
    /// <summary>
    /// Fits an SBART smoothing spline using <paramref name="nKnots"/> interior knots whose
    /// positions are the cluster centers of a one-dimensional k-means clustering of <paramref name="x"/>.
    /// The actual fit is delegated to the overload that takes explicit knot positions.
    /// </summary>
    /// <param name="x">Input values.</param>
    /// <param name="y">Target values, one per element of <paramref name="x"/>.</param>
    /// <param name="w">Observation weights, one per element of <paramref name="x"/>.</param>
    /// <param name="nKnots">Number of knots (= number of k-means clusters) to place.</param>
    /// <param name="targetVariable">Name of the target variable stored in the resulting model.</param>
    /// <param name="inputVars">Names of the input variables stored in the resulting model.</param>
    /// <param name="rep">Receives the report produced by the fit.</param>
    /// <returns>The fitted regression model.</returns>
    /// <exception cref="ArgumentException">
    /// The k-means clustering did not finish successfully (for example because the data contains
    /// fewer distinct points than <paramref name="nKnots"/>), or the delegated overload rejected its arguments.
    /// </exception>
    public static IRegressionModel CalculateSBART(double[] x, double[] y, double[] w, int nKnots, string targetVariable, string[] inputVars, out SBART_Report rep) {
      // use kMeans to find knot points; alglib expects one row per point and one column per variable
      var points = new double[x.Length, 1];
      for (int row = 0; row < x.Length; row++) points[row, 0] = x[row];

      int info;
      double[,] centers;
      int[] assignment;
      alglib.kmeansgenerate(points, x.Length, 1, nKnots, 10, out info, out centers, out assignment);

      // alglib reports success with info == 1; on failure the centers must not be used
      if (info != 1)
        throw new ArgumentException("k-means clustering for knot placement failed (alglib info = " + info + ")");

      // centers are stored column-wise: centers[variable, cluster]
      var knots = new double[nKnots];
      for (int cluster = 0; cluster < knots.Length; cluster++) knots[cluster] = centers[0, cluster];

      return CalculateSBART(x, y, w, knots, targetVariable, inputVars, out rep);
    }
     190
    /// <summary>
    /// Fits an SBART smoothing spline using the knot positions supplied by the caller.
    /// The data is first pre-processed by the native setreg routine; the knot sequence it produces
    /// is then replaced by the caller's knots (scaled with the reported min/range and sorted),
    /// padded with four boundary knots at 0 and four at 1, before the native sbart routine is called.
    /// </summary>
    /// <param name="x">Input values (converted to single precision for the native code).</param>
    /// <param name="y">Target values, one per element of <paramref name="x"/>.</param>
    /// <param name="w">Observation weights, one per element of <paramref name="x"/>.</param>
    /// <param name="knots">Interior knot positions in the original scale of <paramref name="x"/>; order is irrelevant.</param>
    /// <param name="targetVariable">Name of the target variable stored in the resulting model.</param>
    /// <param name="inputVars">Names of the input variables stored in the resulting model.</param>
    /// <param name="rep">Receives the GCV criterion, the smoothing parameter and the leverage values.</param>
    /// <returns>A <c>BartRegressionModel</c> built from the knot sequence and the fitted b-spline coefficients.</returns>
    /// <exception cref="ArgumentException">
    /// Fewer than four observations, more than n+2 knots, or the native routine reported an error code &gt; 0.
    /// </exception>
    public static IRegressionModel CalculateSBART(double[] x, double[] y, double[] w, double[] knots, string targetVariable, string[] inputVars, out SBART_Report rep) {
      float[] xs = x.Select(xi => (float)xi).ToArray();
      float[] ys = y.Select(yi => (float)yi).ToArray();
      float[] ws = w.Select(wi => (float)wi).ToArray();

      int n = xs.Length;
      if (n < 4) throw new ArgumentException("n < 4");
      if (knots.Length > n + 2) throw new ArgumentException("more than n+2 knots");
      float[] xw = new float[n];
      int nx = -99;
      float min = 0.0f;
      float range = 0.0f;
      int nk = -99;
      // The buffer is used twice: first by setreg (n + 6 entries, as before) and afterwards for the
      // caller's knots plus 2 * 4 boundary knots. The guard above admits up to n + 2 knots, so the
      // buffer must be able to hold knots.Length + 8 entries as well.
      float[] regKnots = new float[Math.Max(n + 6, knots.Length + 8)];

      // sort xs together with ys and ws
      // combine rows with duplicate x values
      // transform x to range [0 .. 1]
      // create a set of knots (using a heuristic for the number of knots)
      // knots are located at data points. denser regions of x contain more knots.
      SBART.setreg_x64(xs, ys, ws,
        ref n, xw, ref nx, ref min, ref range, regKnots, ref nk);

      // in this case we want to use the knots supplied by the caller.
      // the knot values produced by setreg are overwritten with scaled knots supplied by caller.
      // knots must be ordered as well.
      int pos = 0;
      // left boundary
      for (int b = 0; b < 4; b++) regKnots[pos++] = 0.0f;
      foreach (var knot in knots.OrderBy(ki => ki)) {
        regKnots[pos++] = ((float)knot - min) / range;
      }
      // right boundary
      for (int b = 0; b < 4; b++) regKnots[pos++] = 1.0f;
      nk = knots.Length + 4;

      float criterion = -99.0f; // GCV
      int icrit = 1; // calculate GCV
      float smoothingParameter = -99.0f;
      int smoothingParameterIndicator = 0;
      float lowerSmoothingParameter = 0.5f;
      float upperSmoothingParameter = 1.0f;
      float tol = 0.01f;
      int isetup = 0; // not setup?

      // results
      float[] coeff = new float[nk];
      float[] leverage = new float[nx];
      float[] y_smoothed = new float[nx];
      int ier = -99;

      // working arrays for sbart
      float[] xwy = new float[nk];
      float[] hs0 = new float[nk];
      float[] hs1 = new float[nk];
      float[] hs2 = new float[nk];
      float[] hs3 = new float[nk];
      float[] sg0 = new float[nk];
      float[] sg1 = new float[nk];
      float[] sg2 = new float[nk];
      float[] sg3 = new float[nk];
      int ld4 = 4;
      float[,] adb = new float[ld4, nk];

      float[,] p1ip = new float[nk, ld4];
      int ldnk = nk + 4;
      float[,] p2ip = new float[nk, nx];

      // only the first nx entries of xs/ys/ws are passed on (rows were combined by setreg)
      SBART.sbart_x64(xs.Take(nx).ToArray(), ys.Take(nx).ToArray(), ws.Take(nx).ToArray(), ref nx,
        regKnots, ref nk,
        coeff, y_smoothed, leverage,
        ref criterion, ref icrit,
        ref smoothingParameter, ref smoothingParameterIndicator, ref lowerSmoothingParameter, ref upperSmoothingParameter,
        ref tol, ref isetup,
        xwy, hs0, hs1, hs2, hs3, sg0, sg1, sg2, sg3, adb, p1ip, p2ip, ref ld4, ref ldnk, ref ier);

      if (ier > 0) throw new ArgumentException(ier.ToString());

      rep = new SBART_Report();
      rep.gcv = criterion;
      rep.smoothingParameter = smoothingParameter;
      rep.leverage = leverage.Select(li => (double)li).ToArray();

      // the full knot sequence has nk + 4 entries (nk coefficients of an order-4 spline)
      return new BartRegressionModel(regKnots.Take(nk + 4).ToArray(), coeff, targetVariable, inputVars, min, range);
    }
     284
    121285    public static IRegressionModel CalculateSBART(double[] x, double[] y,
    122       string targetVariable, string[] inputVars, float smoothingParameter,
     286      string targetVariable, string[] inputVars,
    123287      out SBART_Report report) {
    124288      var w = Enumerable.Repeat(1.0, x.Length).ToArray();
     
    133297      int icrit = 1; // 0..don't calc CV,  1 .. GCV, 2 CV
    134298
    135       // float smoothingParameter = -99.0f;
     299      float smoothingParameter = -99.0f;
    136300      int smoothingParameterIndicator = 0;
    137301      float lowerSmoothingParameter = 0f;
    138302      float upperSmoothingParameter = 1.0f;
    139       float tol = 0.01f;
     303      float tol = 0.02f;
    140304      int isetup = 0; // not setup?
    141305
     
    149313        float[] ys = y.Select(yi => (float)yi).ToArray();
    150314        float[] ws = w.Select(wi => (float)wi).ToArray();
     315       
     316        // sort xs together with ys and ws
     317        // combine rows with duplicate x values
     318        // create a set of knots (using a heuristic for the number of knots)
     319        // knots are located at data points. denser regions of x contain more knots.
    151320        SBART.setreg_x64(xs, ys, ws,
    152321          ref n, xw, ref nx, ref min, ref range, knots, ref nk);
     
    163332        */
    164333
     334        /*
    165335        // use uniform grid of knots
    166336        nk = 20;
     
    176346        knots[nk + 2] = xs[nx - 1];
    177347        knots[nk + 3] = xs[nx - 1];
    178 
     348        */
    179349        if (nx < 4) {
    180350          report = new SBART_Report();
     
    220390        report.leverage = leverage.Select(li => (double)li).ToArray();
    221391
    222         return new BartRegressionModel(xs.Take(nx).ToArray(), y_smoothed.Take(nx).ToArray(), targetVariable, inputVars, min, range);
     392        return new BartRegressionModel(knots.Take(nk+4).ToArray(), coeff, targetVariable, inputVars, min, range);
    223393
    224394      } else {
     
    229399
    230400    public class BartRegressionModel : NamedItem, IRegressionModel {
    231       private float[] x;
    232       private float[] y;
     401      private float[] knots;
     402      private float[] bcoeff;
    233403      private double min;
    234404      private double range;
     
    243413
    244414      public BartRegressionModel(BartRegressionModel orig, Cloner cloner) {
    245         this.x = orig.x;
    246         this.y = orig.y;
     415        this.knots = orig.knots;
     416        this.bcoeff = orig.bcoeff;
    247417        this.min = orig.min;
    248418        this.range = orig.range;
     
    250420        this.variablesUsedForPrediction = orig.variablesUsedForPrediction;
    251421      }
    252       public BartRegressionModel(float[] x, float[] y_smoothed, string targetVariable, string[] inputVars, double min, double range) {
     422      public BartRegressionModel(float[] knots, float[] bcoeff, string targetVariable, string[] inputVars, double min, double range) {
    253423        this.variablesUsedForPrediction = inputVars;
    254424        this.TargetVariable = targetVariable;
    255         this.x = x;
    256         this.y = y_smoothed;
     425        this.knots = knots;
     426        this.bcoeff = bcoeff;
    257427        this.range = range;
    258428        this.min = min;
     
    266436
    267437      public double GetEstimatedValue(double xx) {
    268         xx = (xx - min) / range;
    269         // piecewise linear approximation
    270         int n = x.Length;
    271         if (xx <= x[0]) {
    272           double h = xx - x[0];
    273           return h * (y[1] - y[0]) / (x[1] - x[0]) + y[0];
    274         } else if (xx >= x[n-1]) {
    275           double h = xx - x[n-1];
    276           return h * (y[n-1] - y[n-2]) / (x[n-1] - x[n-2]) + y[n-1];
    277         } else {
    278           // binary search
    279           int lower = 0;
    280           int upper = n-1;
    281           while (true) {
    282             if (upper < lower) throw new InvalidProgramException();
    283             int i = lower + (upper - lower) / 2;
    284             if (x[i] <= xx && xx < x[i + 1]) {
    285               double h = xx - x[i];
    286               double k = (y[i + 1] - y[i]) / (x[i + 1] - x[i]);
    287               return h * k + y[i];
    288             } else if (xx < x[i]) {
    289               upper = i - 1;
    290             } else {
    291               lower = i + 1;
    292             }
    293           }
    294         }
    295         return 0.0;
     438        float x = (float)((xx - min) / range);
     439        int n = bcoeff.Length;
     440        int k = 4;
     441        int zero = 0;
     442        int one = 1;
     443        int two = 2;
     444
     445
     446        // linear extrapolation
     447        if (x < 0) {
     448          float x0 = 0.0f;
     449          var y0 = bvalue(knots, bcoeff, ref n, ref k, ref x0, ref zero);
     450          var y0d = bvalue(knots, bcoeff, ref n, ref k, ref x0, ref one);
     451          return y0 + x * y0d;
     452        }
     453        if (x > 1) {
     454          float x1 = 1.0f;
     455          var y1 = bvalue(knots, bcoeff, ref n, ref k, ref x1, ref zero);
     456          var y1d = bvalue(knots, bcoeff, ref n, ref k, ref x1, ref one);
     457          return y1 + (x-1) * y1d;
     458        }
     459
     460        lock (this) {
     461          return bvalue(knots, bcoeff, ref n, ref k, ref x, ref zero);
     462        }
     463
     464        // piecewise constant approximation
     465        // if (xx <= x[0]) return bcoeff[0];
     466        // if (xx >= x[n - 1]) return bcoeff[n - 1];
     467        // for(int i=1;i<n-2;i++) {
     468        //   var h1 = xx - x[i];
     469        //   var h2 = xx - x[i + 1];
     470        //   if(h1 > 0 && h2 <= 0) {
     471        //     if (h1 < h2) return bcoeff[i]; else return bcoeff[i + 1];
     472        //   }
     473        // }
     474        // return 0.0;
     475
     476        // // piecewise linear approximation
     477        // int n = x.Length;
     478        // if (xx <= x[0]) {
     479        //   double h = xx - x[0];
     480        //   return h * (y[1] - y[0]) / (x[1] - x[0]) + y[0];
     481        // } else if (xx >= x[n-1]) {
     482        //   double h = xx - x[n-1];
     483        //   return h * (y[n-1] - y[n-2]) / (x[n-1] - x[n-2]) + y[n-1];
     484        // } else {
     485        //   // binary search
     486        //   int lower = 0;
     487        //   int upper = n-1;
     488        //   while (true) {
     489        //     if (upper < lower) throw new InvalidProgramException();
     490        //     int i = lower + (upper - lower) / 2;
     491        //     if (x[i] <= xx && xx < x[i + 1]) {
     492        //       double h = xx - x[i];
     493        //       double k = (y[i + 1] - y[i]) / (x[i + 1] - x[i]);
     494        //       return h * k + y[i];
     495        //     } else if (xx < x[i]) {
     496        //       upper = i - 1;
     497        //     } else {
     498        //       lower = i + 1;
     499        //     }
     500        //   }
     501        // }
     502        // return 0.0;
    296503      }
    297504
Note: See TracChangeset for help on using the changeset viewer.