source: branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/Algorithm.cs @ 17984

Last change on this file since 17984 was 17984, checked in by gkronber, 5 months ago

#3106 updated implementation based on the reply by Moscato

File size: 27.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HEAL.Attic;
6using HeuristicLab.Analysis;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
11using HeuristicLab.Parameters;
12using HeuristicLab.Problems.DataAnalysis;
13using HeuristicLab.Problems.DataAnalysis.Symbolic;
14using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
15using HeuristicLab.Random;
16
17namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
18  /// <summary>
19  /// Implementation of Continuous Fraction Regression (CFR) as described in
20  /// Pablo Moscato, Haoyuan Sun, Mohammad Nazmul Haque,
21  /// Analytic Continued Fractions for Regression: A Memetic Algorithm Approach,
22  /// Expert Systems with Applications, Volume 179, 2021, 115018, ISSN 0957-4174,
23  /// https://doi.org/10.1016/j.eswa.2021.115018.
24  /// </summary>
25  [Item("Continuous Fraction Regression (CFR)", "TODO")]
26  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
27  [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
28  public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // Keys under which the algorithm's parameters are registered in the Parameters collection.
    // These strings are also part of the persisted format (HEAL.Attic), so do not rename them.
    private const string MutationRateParameterName = "MutationRate";
    private const string DepthParameterName = "Depth";
    private const string NumGenerationsParameterName = "NumGenerations";
    private const string StagnationGenerationsParameterName = "StagnationGenerations";
    private const string LocalSearchIterationsParameterName = "LocalSearchIterations";
    private const string LocalSearchRestartsParameterName = "LocalSearchRestarts";
    private const string LocalSearchToleranceParameterName = "LocalSearchTolerance";
    private const string DeltaParameterName = "Delta";
    private const string ScaleDataParameterName = "ScaleData";
38
39
40    #region parameters
41    public IFixedValueParameter<PercentValue> MutationRateParameter => (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName];
42    public double MutationRate {
43      get { return MutationRateParameter.Value.Value; }
44      set { MutationRateParameter.Value.Value = value; }
45    }
46    public IFixedValueParameter<IntValue> DepthParameter => (IFixedValueParameter<IntValue>)Parameters[DepthParameterName];
47    public int Depth {
48      get { return DepthParameter.Value.Value; }
49      set { DepthParameter.Value.Value = value; }
50    }
51    public IFixedValueParameter<IntValue> NumGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName];
52    public int NumGenerations {
53      get { return NumGenerationsParameter.Value.Value; }
54      set { NumGenerationsParameter.Value.Value = value; }
55    }
56    public IFixedValueParameter<IntValue> StagnationGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[StagnationGenerationsParameterName];
57    public int StagnationGenerations {
58      get { return StagnationGenerationsParameter.Value.Value; }
59      set { StagnationGenerationsParameter.Value.Value = value; }
60    }
61    public IFixedValueParameter<IntValue> LocalSearchIterationsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchIterationsParameterName];
62    public int LocalSearchIterations {
63      get { return LocalSearchIterationsParameter.Value.Value; }
64      set { LocalSearchIterationsParameter.Value.Value = value; }
65    }
66    public IFixedValueParameter<IntValue> LocalSearchRestartsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchRestartsParameterName];
67    public int LocalSearchRestarts {
68      get { return LocalSearchRestartsParameter.Value.Value; }
69      set { LocalSearchRestartsParameter.Value.Value = value; }
70    }
71    public IFixedValueParameter<DoubleValue> LocalSearchToleranceParameter => (IFixedValueParameter<DoubleValue>)Parameters[LocalSearchToleranceParameterName];
72    public double LocalSearchTolerance {
73      get { return LocalSearchToleranceParameter.Value.Value; }
74      set { LocalSearchToleranceParameter.Value.Value = value; }
75    }
76    public IFixedValueParameter<PercentValue> DeltaParameter => (IFixedValueParameter<PercentValue>)Parameters[DeltaParameterName];
77    public double Delta {
78      get { return DeltaParameter.Value.Value; }
79      set { DeltaParameter.Value.Value = value; }
80    }
81    public IFixedValueParameter<BoolValue> ScaleDataParameter => (IFixedValueParameter<BoolValue>)Parameters[ScaleDataParameterName];
82    public bool ScaleData {
83      get { return ScaleDataParameter.Value.Value; }
84      set { ScaleDataParameter.Value.Value = value; }
85    }
86    #endregion
87
    // storable ctor (required by HEAL.Attic for deserialization)
    [StorableConstructor]
    public Algorithm(StorableConstructorFlag _) : base(_) { }

    // cloning ctor (used by Clone via the Cloner infrastructure)
    public Algorithm(Algorithm original, Cloner cloner) : base(original, cloner) { }
94
95
    // default ctor: registers all algorithm parameters with the defaults suggested
    // by the paper (see the individual parameter descriptions).
    public Algorithm() : base() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1)));
      Parameters.Add(new FixedValueParameter<IntValue>(DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6)));
      Parameters.Add(new FixedValueParameter<IntValue>(NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<IntValue>(StagnationGenerationsParameterName, "Number of generations after which the population is re-initialized (default value 5)", new IntValue(5)));
      Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchIterationsParameterName, "Number of iterations for local search (simplex) (default value 250)", new IntValue(250)));
      Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchRestartsParameterName, "Number of restarts for local search (default value 4)", new IntValue(4)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(LocalSearchToleranceParameterName, "The tolerance value for local search (simplex) (default value: 1e-3)", new DoubleValue(1e-3)));
      Parameters.Add(new FixedValueParameter<PercentValue>(DeltaParameterName, "The relative weight for the number of variables term in the fitness function (default value: 10%)", new PercentValue(0.1)));
      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleDataParameterName, "Turns on/off scaling of input variable values to the range [0 .. 1] (default: false)", new BoolValue(false)));
    }
109
110    public override IDeepCloneable Clone(Cloner cloner) {
111      return new Algorithm(this, cloner);
112    }
113
114    protected override void Run(CancellationToken cancellationToken) {
115      var problemData = Problem.ProblemData;
116      double[,] xy;
117      if (ScaleData) {
118        // Scale data to range 0 .. 1
119        //
120        // Scaling was not used for the experiments in the paper. Statement by the authors: "We did not pre-process the data."
121        var transformations = new List<Transformation<double>>();
122        foreach (var input in problemData.AllowedInputVariables) {
123          var values = problemData.Dataset.GetDoubleValues(input, problemData.TrainingIndices);
124          var linTransformation = new LinearTransformation(problemData.AllowedInputVariables);
125          var min = values.Min();
126          var max = values.Max();
127          var range = max - min;
128          linTransformation.Addend = -min / range;
129          linTransformation.Multiplier = 1.0 / range;
130          transformations.Add(linTransformation);
131        }
132        // do not scale the target
133        transformations.Add(new LinearTransformation(problemData.AllowedInputVariables) { Addend = 0.0, Multiplier = 1.0 });
134        xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
135          transformations,
136          problemData.TrainingIndices);
137      } else {
138        // no transformation
139        xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
140          problemData.TrainingIndices);
141      }
142      var nVars = xy.GetLength(1) - 1;
143      var seed = new System.Random().Next();
144      var rand = new MersenneTwister((uint)seed);
145      CFRAlgorithm(nVars, Depth, MutationRate, xy, out var bestObj, rand, NumGenerations, StagnationGenerations,
146        Delta,
147        LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, cancellationToken);
148    }
149
    /// <summary>
    /// Main memetic loop (Algorithm 1 of the paper): mutation, recombination, local search on
    /// current solutions, invariant maintenance, stagnation-triggered reset, and result reporting.
    /// </summary>
    /// <param name="trainingData">Row-major matrix; last column is the target.</param>
    /// <param name="bestObj">Receives the best (lowest) guiding-function value found.</param>
    private void CFRAlgorithm(int nVars, int depth, double mutationRate, double[,] trainingData,
      out double bestObj,
      IRandom rand, int numGen, int stagnatingGens, double evalDelta,
      int localSearchIterations, int localSearchRestarts, double localSearchTolerance,
      CancellationToken cancellationToken) {
      /* Algorithm 1 */
      /* Generate initial population by a randomized algorithm */
      var pop = InitialPopulation(nVars, depth, rand, trainingData);
      bestObj = pop.pocketObjValue;
      // the best value since the last reset
      var episodeBestObj = pop.pocketObjValue;
      var episodeBestObjGen = 0;
      for (int gen = 1; gen <= numGen && !cancellationToken.IsCancellationRequested; gen++) {
        /* mutate each current solution in the population */
        var pop_mu = Mutate(pop, mutationRate, rand);
        /* generate new population by recombination mechanism */
        var pop_r = RecombinePopulation(pop_mu, rand, nVars);

        // "We ran the Local Search after Mutation and recombination operations. We executed the local-search only on the Current solutions."
        // "We executed the MaintainInvariant() in the following steps:
        // - After generating the initial population
        // - after resetting the root
        // - after executing the local-search on the whole population.
        // We updated the pocket/ current automatically after mutation and recombination operation."

        /* local search optimization of current solutions */
        foreach (var agent in pop_r.IterateLevels()) {
          LocalSearchSimplex(localSearchIterations, localSearchRestarts, localSearchTolerance, evalDelta, agent.current, ref agent.currentObjValue, trainingData, rand);
        }
        foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // post-order to make sure that the root contains the best model


        // for detecting stagnation we track the best objective value since the last reset
        // and reset if this does not change for stagnatingGens
        if (gen > episodeBestObjGen + stagnatingGens) {
          Reset(pop_r, nVars, depth, rand, trainingData);
          episodeBestObj = double.MaxValue; // force the next improvement check to restart the episode
        }
        if (episodeBestObj > pop_r.pocketObjValue) {
          episodeBestObjGen = gen; // wait at least stagnatingGens until resetting again
          episodeBestObj = pop_r.pocketObjValue;
        }

        /* replace old population with evolved population */
        pop = pop_r;

        /* keep track of the best solution (root pocket) and publish it as a result */
        if (bestObj > pop.pocketObjValue) {
          bestObj = pop.pocketObjValue;
          Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
          Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(pop.pocket, trainingData, Problem.ProblemData.AllowedInputVariables.ToArray(), Problem.ProblemData.TargetVariable));
        }

        #region visualization and debugging
        // One pair of data rows (pocket/current quality) per agent, created lazily on the
        // first generation and appended to on every subsequent one.
        DataTable qualities;
        int i = 0;
        if (Results.TryGetValue("Qualities", out var qualitiesResult)) {
          qualities = (DataTable)qualitiesResult.Value;
        } else {
          qualities = new DataTable("Qualities", "Qualities");
          i = 0;
          foreach (var node in pop.IterateLevels()) {
            qualities.Rows.Add(new DataRow($"Quality {i} pocket", "Quality of pocket"));
            qualities.Rows.Add(new DataRow($"Quality {i} current", "Quality of current"));
            i++;
          }
          Results.AddOrUpdateResult("Qualities", qualities);
        }
        i = 0;
        foreach (var node in pop.IterateLevels()) {
          qualities.Rows[$"Quality {i} pocket"].Values.Add(node.pocketObjValue);
          qualities.Rows[$"Quality {i} current"].Values.Add(node.currentObjValue);
          i++;
        }
        #endregion
      }
    }
227
228
229
230    private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
231      /* instantiate 13 agents in the population */
232      var pop = new Agent();
233      // see Figure 2
234      for (int i = 0; i < 3; i++) {
235        pop.children.Add(new Agent());
236        for (int j = 0; j < 3; j++) {
237          pop.children[i].children.Add(new Agent());
238        }
239      }
240
241      // Statement by the authors: "Yes, we use post-order traversal here"
242      foreach (var agent in pop.IteratePostOrder()) {
243        agent.current = new ContinuedFraction(nVars, depth, rand);
244        agent.pocket = new ContinuedFraction(nVars, depth, rand);
245
246        agent.currentObjValue = Evaluate(agent.current, trainingData, Delta);
247        agent.pocketObjValue = Evaluate(agent.pocket, trainingData, Delta);
248
249        /* within each agent, the pocket solution always holds the better value of guiding
250         * function than its current solution
251         */
252        agent.MaintainInvariant();
253      }
254      return pop;
255    }
256
    // Reset is not described in detail in the paper.
    // Statement by the authors: "We only replaced the pocket solution of the root with
    // a randomly generated solution. Then we execute the maintain-invariant process.
    // It does not initialize the solutions in the entire population."
    private void Reset(Agent root, int nVars, int depth, IRandom rand, double[,] trainingData) {
      // replace both root solutions with fresh random fractions and re-evaluate them
      root.pocket = new ContinuedFraction(nVars, depth, rand);
      root.current = new ContinuedFraction(nVars, depth, rand);

      root.currentObjValue = Evaluate(root.current, trainingData, Delta);
      root.pocketObjValue = Evaluate(root.pocket, trainingData, Delta);

      foreach (var agent in root.IteratePreOrder()) { agent.MaintainInvariant(); } // Here we push the newly created model down the hierarchy.
    }
270
271
272
    /// <summary>
    /// Recursively recombines each leader with its three supporters, top-down through the
    /// hierarchy. The four Recombine calls are order-sensitive (see authors' statement below)
    /// and must not be reordered.
    /// </summary>
    private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars) {
      var l = pop; // leader of this sub-population

      if (pop.children.Count > 0) {
        var s1 = pop.children[0];
        var s2 = pop.children[1];
        var s3 = pop.children[2];

        // Statement by the authors: "we are using recently generated solutions.
        // For an example, in step 1 we got the new current(l), which is being used in
        // Step 2 to generate current(s3). The current(s3) from Step 2 is being used at
        // Step 4. These steps are executed sequentially from 1 to 4. Similarly, in the
        // recombination of lower-level subpopulations, we will have the new current
        // (the supporters generated at the previous level) as the leader of the subpopulation.
        l.current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars);
        s3.current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars);
        s1.current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars);
        s2.current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars);

        // recombination works from top to bottom
        foreach (var child in pop.children) {
          RecombinePopulation(child, rand, nVars);
        }

      }
      return pop;
    }
300
    /// <summary>
    /// Produces an offspring continued fraction from two parents: the offspring's global
    /// variable mask is op(p1.vars, p2.vars), and per-term coefficients/betas are blended.
    /// The blend factor (rand*5-1)/3 lies in [-1/3, 4/3), i.e. it can extrapolate slightly
    /// beyond either parent.
    /// </summary>
    private static ContinuedFraction Recombine(ContinuedFraction p1, ContinuedFraction p2, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars) {
      ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
      /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */
      ch.vars = op(p1.vars, p2.vars);

      /* recombine the coefficients for each term (h) of the continued fraction */
      for (int i = 0; i < p1.h.Length; i++) {
        var coefa = p1.h[i].coef; var varsa = p1.h[i].vars;
        var coefb = p2.h[i].coef; var varsb = p2.h[i].vars;

        /* recombine coefficient values for variables */
        var coefx = new double[nVars];
        var varsx = new bool[nVars]; // deviates from paper, probably forgotten in the pseudo-code
        for (int vi = 0; vi < nVars; vi++) {
          if (ch.vars[vi]) {  // CHECK: paper uses featAt()
            if (varsa[vi] && varsb[vi]) {
              // both parents use the variable: blend the coefficients
              coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
              varsx[vi] = true;
            } else if (varsa[vi]) {
              // only one parent uses the variable: copy its coefficient unchanged
              coefx[vi] = coefa[vi];
              varsx[vi] = true;
            } else if (varsb[vi]) {
              coefx[vi] = coefb[vi];
              varsx[vi] = true;
            }
          }
        }
        /* update new coefficients of the term in offspring */
        ch.h[i] = new Term() { coef = coefx, vars = varsx };
        /* compute new value of constant (beta) for term hi in the offspring solution ch using
         * beta of p1.hi and p2.hi */
        ch.h[i].beta = p1.h[i].beta + (rand.NextDouble() * 5 - 1) * (p2.h[i].beta - p1.h[i].beta) / 3.0;
      }
      // return LocalSearchSimplex(ch, trainingData); // The paper has a local search step here.
      // The authors have stated that local search is executed after mutation and recombination
      // for the current solutions.
      // Local search and MaintainInvariant is called in the main loop (Alg 1)
      return ch;
    }
340
    /// <summary>
    /// Mutates the current solution of each agent with probability mutationRate (in place;
    /// the returned Agent is the same instance as pop).
    /// </summary>
    private Agent Mutate(Agent pop, double mutationRate, IRandom rand) {
      foreach (var agent in pop.IterateLevels()) {
        if (rand.NextDouble() < mutationRate) {
          // NOTE(review): major mutation fires when current is either close to the pocket
          // (< 1.2x) or much worse (> 2x); soft mutation only in the band in between.
          // Confirm this condition against the paper's description of hard/soft mutation.
          if (agent.currentObjValue < 1.2 * agent.pocketObjValue ||
              agent.currentObjValue > 2 * agent.pocketObjValue)
            ToggleVariables(agent.current, rand); // major mutation
          else
            ModifyVariable(agent.current, rand); // soft mutation
        }
      }
      return pop;
    }
353
    /// <summary>
    /// Major mutation: picks one variable index uniformly at random and flips its global
    /// selection state, adjusting that variable in every term of the fraction.
    /// </summary>
    private void ToggleVariables(ContinuedFraction cfrac, IRandom rand) {
      // returns a or b with equal probability
      double coinToss(double a, double b) {
        return rand.NextDouble() < 0.5 ? a : b;
      }

      /* select a variable index uniformly at random */
      int N = cfrac.vars.Length;
      var vIdx = rand.Next(N);

      /* for each depth of continued fraction, toggle the selection of variables of the term (h) */
      foreach (var h in cfrac.h) {
        /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or
         * 'Remember' the coefficient value at random */
        if (cfrac.vars[vIdx]) {  // CHECK: paper uses varAt()
          h.vars[vIdx] = false;  // CHECK: paper uses varAt()
          h.coef[vIdx] = coinToss(0, h.coef[vIdx]);
        } else {
          /* Case 2: term variable is turned OFF: Turn ON the variable, and either 'Remove'
           * or 'Replace' the coefficient with a random value between -3 and 3 at random */
          // NOTE(review): if the term already has the variable ON (h.vars[vIdx] == true)
          // while the global mask is OFF, this branch leaves the term untouched — confirm
          // whether that state is reachable / intended.
          if (!h.vars[vIdx]) {
            h.vars[vIdx] = true;  // CHECK: paper uses varAt()
            h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3);
          }
        }
      }
      /* toggle the randomly selected variable */
      cfrac.vars[vIdx] = !cfrac.vars[vIdx];  // CHECK: paper uses varAt()
    }
382
    /// <summary>
    /// Soft mutation: toggles one globally-active variable within a single randomly chosen
    /// term, zeroing the coefficient when turning it off or randomizing it in [-3, 3) when
    /// turning it on. No-op if no variable is globally active.
    /// </summary>
    private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
      /* randomly select a variable which is turned ON */
      var candVars = new List<int>();
      for (int i = 0; i < cfrac.vars.Length; i++) if (cfrac.vars[i]) candVars.Add(i);  // CHECK: paper uses varAt()
      if (candVars.Count == 0) return; // no variable active
      var vIdx = candVars[rand.Next(candVars.Count)];

      /* randomly select a term (h) of continued fraction */
      var h = cfrac.h[rand.Next(cfrac.h.Length)];

      /* modify the coefficient value */
      if (h.vars[vIdx]) {  // CHECK: paper uses varAt()
        h.coef[vIdx] = 0.0;
      } else {
        h.coef[vIdx] = rand.NextDouble() * 6 - 3;
      }
      /* Toggle the randomly selected variable */
      h.vars[vIdx] = !h.vars[vIdx]; // CHECK: paper uses varAt()
    }
402
403    private static double Evaluate(ContinuedFraction cfrac, double[,] trainingData, double delta) {
404      var dataPoint = new double[trainingData.GetLength(1) - 1];
405      var yIdx = trainingData.GetLength(1) - 1;
406      double sum = 0.0;
407      for (int r = 0; r < trainingData.GetLength(0); r++) {
408        for (int c = 0; c < dataPoint.Length; c++) {
409          dataPoint[c] = trainingData[r, c];
410        }
411        var y = trainingData[r, yIdx];
412        var pred = Evaluate(cfrac, dataPoint);
413        var res = y - pred;
414        sum += res * res;
415      }
416      return sum / trainingData.GetLength(0) * (1 + delta * cfrac.vars.Count(vi => vi));
417    }
418
    /// <summary>
    /// Evaluates the continued fraction for one input vector, bottom-up:
    /// res = h[i-1](x) / (h[i](x) + res) for i = len-1, len-3, ..., 3, then adds h[0](x).
    /// Each h[i](x) is the linear term beta + coef·x over the term's active variables.
    /// Assumes cfrac.h.Length is odd (h[0] plus numerator/denominator pairs); no guard
    /// against a zero denominator (may yield Infinity/NaN).
    /// </summary>
    private static double Evaluate(ContinuedFraction cfrac, double[] dataPoint) {
      var res = 0.0;
      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
        var hi = cfrac.h[i];
        var hi1 = cfrac.h[i - 1];
        var denom = hi.beta + dot(hi.vars, hi.coef, dataPoint) + res;
        var numerator = hi1.beta + dot(hi1.vars, hi1.coef, dataPoint);
        res = numerator / denom;
      }
      var h0 = cfrac.h[0];
      res += h0.beta + dot(h0.vars, h0.coef, dataPoint);
      return res;
    }
432
433
434    private Func<bool[], bool[], bool[]> SelectRandomOp(IRandom rand) {
435      bool[] union(bool[] a, bool[] b) {
436        var res = new bool[a.Length];
437        for (int i = 0; i < a.Length; i++) res[i] = a[i] || b[i];
438        return res;
439      }
440      bool[] intersect(bool[] a, bool[] b) {
441        var res = new bool[a.Length];
442        for (int i = 0; i < a.Length; i++) res[i] = a[i] && b[i];
443        return res;
444      }
445      bool[] symmetricDifference(bool[] a, bool[] b) {
446        var res = new bool[a.Length];
447        for (int i = 0; i < a.Length; i++) res[i] = a[i] ^ b[i];
448        return res;
449      }
450      switch (rand.Next(3)) {
451        case 0: return union;
452        case 1: return intersect;
453        case 2: return symmetricDifference;
454        default: throw new ArgumentException();
455      }
456    }
457
458    private static double dot(bool[] filter, double[] x, double[] y) {
459      var s = 0.0;
460      for (int i = 0; i < x.Length; i++)
461        if (filter[i])
462          s += x[i] * y[i];
463      return s;
464    }
465
466
    /// <summary>
    /// Optimizes the numeric coefficients of ch with repeated Nelder-Mead simplex runs.
    /// Fitting uses a random 20% subset of the training rows, but candidate quality is
    /// always measured on the full training data; the best coefficient vector found is
    /// written back into ch and returned via quality.
    /// NOTE(review): the incoming value of the ref parameter is immediately overwritten,
    /// so it is effectively out-only. Each restart perturbs the ORIGINAL coefficients,
    /// not the best-so-far ones.
    /// </summary>
    private static void LocalSearchSimplex(int iterations, int restarts, double tolerance, double delta, ContinuedFraction ch, ref double quality, double[,] trainingData, IRandom rand) {
      double uniformPeturbation = 1.0; // simplex perturbation size per coefficient
      int maxEvals = iterations;
      int numSearches = restarts + 1;
      var numRows = trainingData.GetLength(0);
      int numSelectedRows = numRows / 5; // 20% of the training samples

      quality = Evaluate(ch, trainingData, delta); // get quality with original coefficients

      double[] origCoeff = ExtractCoeff(ch);
      if (origCoeff.Length == 0) return; // no parameters to optimize

      var bestQuality = quality;
      var bestCoeff = origCoeff;

      var fittingData = SelectRandomRows(trainingData, numSelectedRows, rand);

      // objective for the simplex: writes the candidate coefficients into ch and
      // evaluates on the (fixed) random subset
      double objFunc(double[] curCoeff) {
        SetCoeff(ch, curCoeff);
        return Evaluate(ch, fittingData, delta);
      }

      for (int count = 0; count < numSearches; count++) {

        SimplexConstant[] constants = new SimplexConstant[origCoeff.Length];
        for (int i = 0; i < origCoeff.Length; i++) {
          constants[i] = new SimplexConstant(origCoeff[i], uniformPeturbation);
        }

        RegressionResult result = NelderMeadSimplex.Regress(constants, tolerance, maxEvals, objFunc);

        var optimizedCoeff = result.Constants;
        SetCoeff(ch, optimizedCoeff);

        // re-evaluate on the FULL training data to decide whether to keep this restart's result
        var newQuality = Evaluate(ch, trainingData, delta);

        if (newQuality < bestQuality) {
          bestCoeff = optimizedCoeff;
          bestQuality = newQuality;
        }
      } // reps

      SetCoeff(ch, bestCoeff);
      quality = bestQuality;
    }
512
513    private static double[,] SelectRandomRows(double[,] trainingData, int numSelectedRows, IRandom rand) {
514      var numRows = trainingData.GetLength(0);
515      var numCols = trainingData.GetLength(1);
516      var selectedRows = Enumerable.Range(0, numRows).Shuffle(rand).Take(numSelectedRows).ToArray();
517      var subset = new double[numSelectedRows, numCols];
518      var i = 0;
519      foreach (var r in selectedRows) {
520        for (int c = 0; c < numCols; c++) {
521          subset[i, c] = trainingData[r, c];
522        }
523        i++;
524      }
525      return subset;
526    }
527
528    private static double[] ExtractCoeff(ContinuedFraction ch) {
529      var coeff = new List<double>();
530      foreach (var hi in ch.h) {
531        coeff.Add(hi.beta);
532        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
533          if (hi.vars[vIdx]) coeff.Add(hi.coef[vIdx]);
534        }
535      }
536      return coeff.ToArray();
537    }
538
539    private static void SetCoeff(ContinuedFraction ch, double[] curCoeff) {
540      int k = 0;
541      foreach (var hi in ch.h) {
542        hi.beta = curCoeff[k++];
543        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
544          if (hi.vars[vIdx]) hi.coef[vIdx] = curCoeff[k++];
545        }
546      }
547    }
548
    #region build a symbolic expression tree
    // Shared symbol instances used to create tree nodes when converting a continued
    // fraction into a HeuristicLab symbolic expression tree.
    Symbol addSy = new Addition();
    Symbol divSy = new Division();
    Symbol startSy = new StartSymbol();
    Symbol progSy = new ProgramRootSymbol();
    Symbol constSy = new Constant();
    Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable();
556
557    private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, double[,] trainingData, string[] variables, string targetVariable) {
558      ISymbolicExpressionTreeNode res = null;
559      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
560        var hi = cfrac.h[i];
561        var hi1 = cfrac.h[i - 1];
562        var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta);
563        if (res != null) {
564          denom.AddSubtree(res);
565        }
566
567        var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta);
568
569        res = divSy.CreateTreeNode();
570        res.AddSubtree(numerator);
571        res.AddSubtree(denom);
572      }
573
574      var h0 = cfrac.h[0];
575      var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta);
576      h0Term.AddSubtree(res);
577
578      var progRoot = progSy.CreateTreeNode();
579      var start = startSy.CreateTreeNode();
580      progRoot.AddSubtree(start);
581      start.AddSubtree(h0Term);
582
583      var model = new SymbolicRegressionModel(targetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter());
584      var ds = new Dataset(variables.Concat(new[] { targetVariable }), trainingData);
585      var problemData = new RegressionProblemData(ds, variables, targetVariable);
586      var sol = new SymbolicRegressionSolution(model, problemData);
587      return sol;
588    }
589
    /// <summary>
    /// Builds an Addition node summing one weighted variable node per active variable
    /// plus a constant node for beta, i.e. the tree form of one linear term h_i(x).
    /// </summary>
    private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) {
      var sum = addSy.CreateTreeNode();
      for (int i = 0; i < vars.Length; i++) {
        if (vars[i]) {
          var varNode = (VariableTreeNode)varSy.CreateTreeNode();
          varNode.Weight = coef[i]; // coefficient becomes the variable node's weight
          varNode.VariableName = variables[i];
          sum.AddSubtree(varNode);
        }
      }
      sum.AddSubtree(CreateConstant(beta));
      return sum;
    }
603
    /// <summary>Creates a constant tree node holding the given value.</summary>
    private ISymbolicExpressionTreeNode CreateConstant(double value) {
      var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
      constNode.Value = value;
      return constNode;
    }
609  }
610  #endregion
611}
Note: See TracBrowser for help on using the repository browser.