source: branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/Algorithm.cs @ 17985

Last change on this file since 17985 was 17985, checked in by gkronber, 5 months ago

#3106 call localsearch after mutation and recombination as well as in the main loop for the whole population

File size: 29.2 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Threading;
6using HEAL.Attic;
7using HeuristicLab.Analysis;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
11using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
12using HeuristicLab.Parameters;
13using HeuristicLab.Problems.DataAnalysis;
14using HeuristicLab.Problems.DataAnalysis.Symbolic;
15using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
16using HeuristicLab.Random;
17
18namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
19  /// <summary>
20  /// Implementation of Continuous Fraction Regression (CFR) as described in
21  /// Pablo Moscato, Haoyuan Sun, Mohammad Nazmul Haque,
22  /// Analytic Continued Fractions for Regression: A Memetic Algorithm Approach,
23  /// Expert Systems with Applications, Volume 179, 2021, 115018, ISSN 0957-4174,
24  /// https://doi.org/10.1016/j.eswa.2021.115018.
25  /// </summary>
26  [Item("Continuous Fraction Regression (CFR)", "TODO")]
27  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
28  [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
29  public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
30    private const string MutationRateParameterName = "MutationRate";
31    private const string DepthParameterName = "Depth";
32    private const string NumGenerationsParameterName = "NumGenerations";
33    private const string StagnationGenerationsParameterName = "StagnationGenerations";
34    private const string LocalSearchIterationsParameterName = "LocalSearchIterations";
35    private const string LocalSearchRestartsParameterName = "LocalSearchRestarts";
36    private const string LocalSearchToleranceParameterName = "LocalSearchTolerance";
37    private const string DeltaParameterName = "Delta";
38    private const string ScaleDataParameterName = "ScaleData";
39
40
41    #region parameters
42    public IFixedValueParameter<PercentValue> MutationRateParameter => (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName];
43    public double MutationRate {
44      get { return MutationRateParameter.Value.Value; }
45      set { MutationRateParameter.Value.Value = value; }
46    }
47    public IFixedValueParameter<IntValue> DepthParameter => (IFixedValueParameter<IntValue>)Parameters[DepthParameterName];
48    public int Depth {
49      get { return DepthParameter.Value.Value; }
50      set { DepthParameter.Value.Value = value; }
51    }
52    public IFixedValueParameter<IntValue> NumGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName];
53    public int NumGenerations {
54      get { return NumGenerationsParameter.Value.Value; }
55      set { NumGenerationsParameter.Value.Value = value; }
56    }
57    public IFixedValueParameter<IntValue> StagnationGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[StagnationGenerationsParameterName];
58    public int StagnationGenerations {
59      get { return StagnationGenerationsParameter.Value.Value; }
60      set { StagnationGenerationsParameter.Value.Value = value; }
61    }
62    public IFixedValueParameter<IntValue> LocalSearchIterationsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchIterationsParameterName];
63    public int LocalSearchIterations {
64      get { return LocalSearchIterationsParameter.Value.Value; }
65      set { LocalSearchIterationsParameter.Value.Value = value; }
66    }
67    public IFixedValueParameter<IntValue> LocalSearchRestartsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchRestartsParameterName];
68    public int LocalSearchRestarts {
69      get { return LocalSearchRestartsParameter.Value.Value; }
70      set { LocalSearchRestartsParameter.Value.Value = value; }
71    }
72    public IFixedValueParameter<DoubleValue> LocalSearchToleranceParameter => (IFixedValueParameter<DoubleValue>)Parameters[LocalSearchToleranceParameterName];
73    public double LocalSearchTolerance {
74      get { return LocalSearchToleranceParameter.Value.Value; }
75      set { LocalSearchToleranceParameter.Value.Value = value; }
76    }
77    public IFixedValueParameter<PercentValue> DeltaParameter => (IFixedValueParameter<PercentValue>)Parameters[DeltaParameterName];
78    public double Delta {
79      get { return DeltaParameter.Value.Value; }
80      set { DeltaParameter.Value.Value = value; }
81    }
82    public IFixedValueParameter<BoolValue> ScaleDataParameter => (IFixedValueParameter<BoolValue>)Parameters[ScaleDataParameterName];
83    public bool ScaleData {
84      get { return ScaleDataParameter.Value.Value; }
85      set { ScaleDataParameter.Value.Value = value; }
86    }
87    #endregion
88
    // Storable constructor: used exclusively by the HEAL.Attic deserialization infrastructure.
    [StorableConstructor]
    public Algorithm(StorableConstructorFlag _) : base(_) { }

    // Cloning constructor: all state lives in the Parameters collection, which the base
    // class deep-copies; no additional fields need to be cloned here.
    public Algorithm(Algorithm original, Cloner cloner) : base(original, cloner) { }


    // Default constructor: registers the problem instance and all algorithm parameters.
    // Default values follow the settings reported/recommended for CFR.
    public Algorithm() : base() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1)));
      Parameters.Add(new FixedValueParameter<IntValue>(DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6)));
      Parameters.Add(new FixedValueParameter<IntValue>(NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<IntValue>(StagnationGenerationsParameterName, "Number of generations after which the population is re-initialized (default value 5)", new IntValue(5)));
      Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchIterationsParameterName, "Number of iterations for local search (simplex) (default value 250)", new IntValue(250)));
      Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchRestartsParameterName, "Number of restarts for local search (default value 4)", new IntValue(4)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(LocalSearchToleranceParameterName, "The tolerance value for local search (simplex) (default value: 1e-3)", new DoubleValue(1e-3)));
      Parameters.Add(new FixedValueParameter<PercentValue>(DeltaParameterName, "The relative weight for the number of variables term in the fitness function (default value: 10%)", new PercentValue(0.1)));
      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleDataParameterName, "Turns on/off scaling of input variable values to the range [0 .. 1] (default: false)", new BoolValue(false)));
    }

    // Deep-clone entry point required by IDeepCloneable.
    public override IDeepCloneable Clone(Cloner cloner) {
      return new Algorithm(this, cloner);
    }
114
115    protected override void Run(CancellationToken cancellationToken) {
116      var problemData = Problem.ProblemData;
117      double[,] xy;
118      if (ScaleData) {
119        // Scale data to range 0 .. 1
120        //
121        // Scaling was not used for the experiments in the paper. Statement by the authors: "We did not pre-process the data."
122        var transformations = new List<Transformation<double>>();
123        foreach (var input in problemData.AllowedInputVariables) {
124          var values = problemData.Dataset.GetDoubleValues(input, problemData.TrainingIndices);
125          var linTransformation = new LinearTransformation(problemData.AllowedInputVariables);
126          var min = values.Min();
127          var max = values.Max();
128          var range = max - min;
129          linTransformation.Addend = -min / range;
130          linTransformation.Multiplier = 1.0 / range;
131          transformations.Add(linTransformation);
132        }
133        // do not scale the target
134        transformations.Add(new LinearTransformation(problemData.AllowedInputVariables) { Addend = 0.0, Multiplier = 1.0 });
135        xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
136          transformations,
137          problemData.TrainingIndices);
138      } else {
139        // no transformation
140        xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
141          problemData.TrainingIndices);
142      }
143      var nVars = xy.GetLength(1) - 1;
144      var seed = new System.Random().Next();
145      var rand = new MersenneTwister((uint)seed);
146      CFRAlgorithm(nVars, Depth, MutationRate, xy, out var bestObj, rand, NumGenerations, StagnationGenerations,
147        Delta,
148        LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, cancellationToken);
149    }
150
    /// <summary>
    /// Main loop of the memetic algorithm (Algorithm 1 in the paper): per generation it
    /// mutates current solutions, recombines the population, runs local search on all
    /// agents, resets the root on stagnation, and records the best solution found.
    /// </summary>
    /// <param name="bestObj">Best (lowest) guiding-function value found over the whole run.</param>
    /// <param name="stagnatingGens">Generations without improvement before the root is reset.</param>
    /// <param name="evalDelta">Weight of the variable-count penalty in the guiding function.</param>
    private void CFRAlgorithm(int nVars, int depth, double mutationRate, double[,] trainingData,
      out double bestObj,
      IRandom rand, int numGen, int stagnatingGens, double evalDelta,
      int localSearchIterations, int localSearchRestarts, double localSearchTolerance,
      CancellationToken cancellationToken) {
      /* Algorithm 1 */
      /* Generate initial population by a randomized algorithm */
      var pop = InitialPopulation(nVars, depth, rand, trainingData);
      bestObj = pop.pocketObjValue;
      // the best value since the last reset (used for stagnation detection)
      var episodeBestObj = pop.pocketObjValue;
      var episodeBestObjGen = 0;
      for (int gen = 1; gen <= numGen && !cancellationToken.IsCancellationRequested; gen++) {
        /* mutate each current solution in the population */
        var pop_mu = Mutate(pop, mutationRate, rand, trainingData);
        /* generate new population by recombination mechanism */
        var pop_r = RecombinePopulation(pop_mu, rand, nVars, trainingData);

        // Paper:
        // A period of individual search operation is performed every generation on all current solutions.

        // Statement by authors:
        // "We ran the Local Search after Mutation and recombination operations. We executed the local-search only on the Current solutions."
        // "We executed the MaintainInvariant() in the following steps:
        // - After generating the initial population
        // - after resetting the root
        // - after executing the local-search on the whole population.
        // We updated the pocket/ current automatically after mutation and recombination operation."

        /* local search optimization of current solutions */
        foreach (var agent in pop_r.IterateLevels()) {
          LocalSearchSimplex(localSearchIterations, localSearchRestarts, localSearchTolerance, evalDelta, agent, trainingData, rand);
          // NOTE(review): the strict '<' fires (in debug builds) when pocket and current
          // objective values are exactly equal — confirm whether '<=' is intended here.
          Debug.Assert(agent.pocketObjValue < agent.currentObjValue);
        }
        foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // post-order to make sure that the root contains the best model
        foreach (var agent in pop_r.IteratePostOrder()) agent.AssertInvariant();

        // for detecting stagnation we track the best objective value since the last reset
        // and reset if this does not change for stagnatingGens
        if (gen > episodeBestObjGen + stagnatingGens) {
          Reset(pop_r, nVars, depth, rand, trainingData);
          episodeBestObj = double.MaxValue; // force the next improvement check to restart the episode
        }
        if (episodeBestObj > pop_r.pocketObjValue) {
          episodeBestObjGen = gen; // wait at least stagnatingGens until resetting again
          episodeBestObj = pop_r.pocketObjValue;
        }

        /* replace old population with evolved population */
        pop = pop_r;

        /* keep track of the best solution (root pocket) over the entire run */
        if (bestObj > pop.pocketObjValue) {
          bestObj = pop.pocketObjValue;
          Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
          Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(pop.pocket, trainingData, Problem.ProblemData.AllowedInputVariables.ToArray(), Problem.ProblemData.TargetVariable));
        }

        #region visualization and debugging
        // Record pocket/current objective values of every agent into a "Qualities" data
        // table (one pair of rows per agent, created lazily on the first generation).
        DataTable qualities;
        int i = 0;
        if (Results.TryGetValue("Qualities", out var qualitiesResult)) {
          qualities = (DataTable)qualitiesResult.Value;
        } else {
          qualities = new DataTable("Qualities", "Qualities");
          i = 0;
          foreach (var node in pop.IterateLevels()) {
            qualities.Rows.Add(new DataRow($"Quality {i} pocket", "Quality of pocket"));
            qualities.Rows.Add(new DataRow($"Quality {i} current", "Quality of current"));
            i++;
          }
          Results.AddOrUpdateResult("Qualities", qualities);
        }
        i = 0;
        foreach (var node in pop.IterateLevels()) {
          qualities.Rows[$"Quality {i} pocket"].Values.Add(node.pocketObjValue);
          qualities.Rows[$"Quality {i} current"].Values.Add(node.currentObjValue);
          i++;
        }
        #endregion
      }
    }
233
234
235
236    private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
237      /* instantiate 13 agents in the population */
238      var pop = new Agent();
239      // see Figure 2
240      for (int i = 0; i < 3; i++) {
241        pop.children.Add(new Agent());
242        for (int j = 0; j < 3; j++) {
243          pop.children[i].children.Add(new Agent());
244        }
245      }
246
247      // Statement by the authors: "Yes, we use post-order traversal here"
248      foreach (var agent in pop.IteratePostOrder()) {
249        agent.current = new ContinuedFraction(nVars, depth, rand);
250        agent.pocket = new ContinuedFraction(nVars, depth, rand);
251
252        agent.currentObjValue = Evaluate(agent.current, trainingData, Delta);
253        agent.pocketObjValue = Evaluate(agent.pocket, trainingData, Delta);
254
255        /* within each agent, the pocket solution always holds the better value of guiding
256         * function than its current solution
257         */
258        agent.MaintainInvariant();
259      }
260
261      foreach (var agent in pop.IteratePostOrder()) agent.AssertInvariant();
262
263      return pop;
264    }
265
266    // Reset is not described in detail in the paper.
267    // Statement by the authors: "We only replaced the pocket solution of the root with
268    // a randomly generated solution. Then we execute the maintain-invariant process.
269    // It does not initialize the solutions in the entire population."
270    private void Reset(Agent root, int nVars, int depth, IRandom rand, double[,] trainingData) {
271      root.pocket = new ContinuedFraction(nVars, depth, rand);
272      root.current = new ContinuedFraction(nVars, depth, rand);
273
274      root.currentObjValue = Evaluate(root.current, trainingData, Delta);
275      root.pocketObjValue = Evaluate(root.pocket, trainingData, Delta);
276
277      foreach (var agent in root.IteratePreOrder()) { agent.MaintainInvariant(); } // Here we use pre-order traversal push the newly created model down the hierarchy.
278
279      foreach (var agent in root.IteratePostOrder()) agent.AssertInvariant();
280
281    }
282
283
284
285    private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars, double[,] trainingData) {
286      var l = pop;
287
288      if (pop.children.Count > 0) {
289        var s1 = pop.children[0];
290        var s2 = pop.children[1];
291        var s3 = pop.children[2];
292
293        // Statement by the authors: "we are using recently generated solutions.
294        // For an example, in step 1 we got the new current(l), which is being used in
295        // Step 2 to generate current(s3). The current(s3) from Step 2 is being used at
296        // Step 4. These steps are executed sequentially from 1 to 4. Similarly, in the
297        // recombination of lower-level subpopulations, we will have the new current
298        // (the supporters generated at the previous level) as the leader of the subpopulation."
299        Recombine(l, s1, SelectRandomOp(rand), rand, nVars, trainingData);
300        Recombine(s3, l, SelectRandomOp(rand), rand, nVars, trainingData);
301        Recombine(s1, s2, SelectRandomOp(rand), rand, nVars, trainingData);
302        Recombine(s2, s3, SelectRandomOp(rand), rand, nVars, trainingData);
303
304        // recombination works from top to bottom
305        foreach (var child in pop.children) {
306          RecombinePopulation(child, rand, nVars, trainingData);
307        }
308
309      }
310      return pop;
311    }
312
    /// <summary>
    /// Recombines the pocket solutions of agents a and b into a new current solution for a.
    /// The offspring's variable set is computed by the given set operator; coefficients and
    /// beta values are blended term by term. Local search is run on the offspring afterwards.
    /// </summary>
    /// <param name="op">Set operator (union / intersection / symmetric difference) on the variable masks.</param>
    /// <returns>The newly created offspring fraction (also assigned to a.current).</returns>
    private ContinuedFraction Recombine(Agent a, Agent b, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars, double[,] trainingData) {
      ContinuedFraction p1 = a.pocket;
      ContinuedFraction p2 = b.pocket;
      ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
      /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */
      ch.vars = op(p1.vars, p2.vars);

      /* recombine the coefficients for each term (h) of the continued fraction */
      for (int i = 0; i < p1.h.Length; i++) {
        var coefa = p1.h[i].coef; var varsa = p1.h[i].vars;
        var coefb = p2.h[i].coef; var varsb = p2.h[i].vars;

        /* recombine coefficient values for variables */
        var coefx = new double[nVars];
        var varsx = new bool[nVars]; // deviates from paper, probably forgotten in the pseudo-code
        for (int vi = 0; vi < nVars; vi++) {
          if (ch.vars[vi]) {  // CHECK: paper uses featAt()
            // If both parents use the variable, blend: coefa + U(-1, 4)/3 * (coefb - coefa);
            // otherwise inherit the coefficient from whichever parent has it active.
            if (varsa[vi] && varsb[vi]) {
              coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
              varsx[vi] = true;
            } else if (varsa[vi]) {
              coefx[vi] = coefa[vi];
              varsx[vi] = true;
            } else if (varsb[vi]) {
              coefx[vi] = coefb[vi];
              varsx[vi] = true;
            }
          }
        }
        /* update new coefficients of the term in offspring */
        ch.h[i] = new Term() { coef = coefx, vars = varsx };
        /* compute new value of constant (beta) for term hi in the offspring solution ch using
         * beta of p1.hi and p2.hi (same blend formula as for the coefficients) */
        ch.h[i].beta = p1.h[i].beta + (rand.NextDouble() * 5 - 1) * (p2.h[i].beta - p1.h[i].beta) / 3.0;
      }

      // The offspring replaces a's current solution; local search tunes its coefficients
      // and re-establishes the pocket/current invariant.
      a.current = ch;
      LocalSearchSimplex(LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, Delta, a, trainingData, rand);
      return ch;
    }
353
    /// <summary>
    /// Mutates the current solution of each agent with probability mutationRate and then
    /// re-optimizes the mutated solution's coefficients by local search. Pocket solutions
    /// are never mutated. Returns the (mutated in place) population.
    /// </summary>
    private Agent Mutate(Agent pop, double mutationRate, IRandom rand, double[,] trainingData) {
      foreach (var agent in pop.IterateLevels()) {
        if (rand.NextDouble() < mutationRate) {
          // Major mutation when the current objective is either close to (< 1.2x) or far
          // from (> 2x) the pocket objective; soft mutation in between.
          // NOTE(review): confirm the 1.2/2.0 band and comparison direction against the
          // mutation operator described in the paper.
          if (agent.currentObjValue < 1.2 * agent.pocketObjValue ||
              agent.currentObjValue > 2 * agent.pocketObjValue)
            ToggleVariables(agent.current, rand); // major mutation
          else
            ModifyVariable(agent.current, rand); // soft mutation

          // Finally, the local search operation is executed on the mutated solution in order to optimize
          // non-zero coefficients. We do not apply mutation on pocket solutions because we consider them as a "collective memory"
          // of good models visited in the past. They influence the search process via recombination only.
          LocalSearchSimplex(LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, Delta, agent, trainingData, rand);
        }
      }
      return pop;
    }
371
    /// <summary>
    /// Major mutation: selects one variable index uniformly at random and toggles its
    /// fraction-level flag, updating every term's per-term flag and coefficient accordingly.
    /// </summary>
    private void ToggleVariables(ContinuedFraction cfrac, IRandom rand) {
      // Returns a or b with probability 0.5 each (consumes one random number per call).
      double coinToss(double a, double b) {
        return rand.NextDouble() < 0.5 ? a : b;
      }

      /* select a variable index uniformly at random */
      int N = cfrac.vars.Length;
      var vIdx = rand.Next(N);

      /* for each depth of continued fraction, toggle the selection of variables of the term (h) */
      foreach (var h in cfrac.h) {
        /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or
         * 'Remember' the coefficient value at random */
        // Note: the branch condition tests the fraction-level flag (cfrac.vars), while the
        // assignments below modify the per-term flags (h.vars), which may differ.
        if (cfrac.vars[vIdx]) {
          h.vars[vIdx] = false;
          h.coef[vIdx] = coinToss(0, h.coef[vIdx]);
        } else {
          /* Case 2: term variable is turned OFF: Turn ON the variable, and either 'Remove'
           * or 'Replace' the coefficient with a random value between -3 and 3 at random */
          // Terms that already have the variable ON are left untouched (no random draw).
          if (!h.vars[vIdx]) {
            h.vars[vIdx] = true;
            h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3);
          }
        }
      }
      /* toggle the randomly selected variable */
      cfrac.vars[vIdx] = !cfrac.vars[vIdx];
    }
400
401    private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
402      /* randomly select a variable which is turned ON */
403      var candVars = new List<int>();
404      for (int i = 0; i < cfrac.vars.Length; i++) if (cfrac.vars[i]) candVars.Add(i);
405      if (candVars.Count == 0) return; // no variable active
406      var vIdx = candVars[rand.Next(candVars.Count)];
407
408      /* randomly select a term (h) of continued fraction */
409      var h = cfrac.h[rand.Next(cfrac.h.Length)];
410
411      /* modify the coefficient value */
412      if (h.vars[vIdx]) {
413        h.coef[vIdx] = 0.0;
414      } else {
415        h.coef[vIdx] = rand.NextDouble() * 6 - 3;
416      }
417      /* Toggle the randomly selected variable */
418      h.vars[vIdx] = !h.vars[vIdx];
419    }
420
421    private static double Evaluate(ContinuedFraction cfrac, double[,] trainingData, double delta) {
422      var dataPoint = new double[trainingData.GetLength(1) - 1];
423      var yIdx = trainingData.GetLength(1) - 1;
424      double sum = 0.0;
425      for (int r = 0; r < trainingData.GetLength(0); r++) {
426        for (int c = 0; c < dataPoint.Length; c++) {
427          dataPoint[c] = trainingData[r, c];
428        }
429        var y = trainingData[r, yIdx];
430        var pred = Evaluate(cfrac, dataPoint);
431        var res = y - pred;
432        sum += res * res;
433      }
434      return sum / trainingData.GetLength(0) * (1 + delta * cfrac.vars.Count(vi => vi));
435    }
436
    /// <summary>
    /// Evaluates the continued fraction for a single data point, working from the
    /// innermost level outward: res = h0 + h1 / (h2 + h3 / (h4 + ...)).
    /// </summary>
    private static double Evaluate(ContinuedFraction cfrac, double[] dataPoint) {
      var res = 0.0;
      // Process terms in pairs (numerator h[i-1], denominator h[i]).
      // NOTE(review): assumes cfrac.h.Length is odd (e.g. 2*depth+1); for an even length
      // the term h[1] would never be used — confirm against ContinuedFraction's layout.
      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
        var hi = cfrac.h[i];
        var hi1 = cfrac.h[i - 1];
        var denom = hi.beta + dot(hi.vars, hi.coef, dataPoint) + res;
        var numerator = hi1.beta + dot(hi1.vars, hi1.coef, dataPoint);
        res = numerator / denom;
      }
      // Outermost linear term h0 is added last.
      var h0 = cfrac.h[0];
      res += h0.beta + dot(h0.vars, h0.coef, dataPoint);
      return res;
    }
450
451
452    private Func<bool[], bool[], bool[]> SelectRandomOp(IRandom rand) {
453      bool[] union(bool[] a, bool[] b) {
454        var res = new bool[a.Length];
455        for (int i = 0; i < a.Length; i++) res[i] = a[i] || b[i];
456        return res;
457      }
458      bool[] intersect(bool[] a, bool[] b) {
459        var res = new bool[a.Length];
460        for (int i = 0; i < a.Length; i++) res[i] = a[i] && b[i];
461        return res;
462      }
463      bool[] symmetricDifference(bool[] a, bool[] b) {
464        var res = new bool[a.Length];
465        for (int i = 0; i < a.Length; i++) res[i] = a[i] ^ b[i];
466        return res;
467      }
468      switch (rand.Next(3)) {
469        case 0: return union;
470        case 1: return intersect;
471        case 2: return symmetricDifference;
472        default: throw new ArgumentException();
473      }
474    }
475
476    private static double dot(bool[] filter, double[] x, double[] y) {
477      var s = 0.0;
478      for (int i = 0; i < x.Length; i++)
479        if (filter[i])
480          s += x[i] * y[i];
481      return s;
482    }
483
484
    /// <summary>
    /// Optimizes the non-zero coefficients of a's current solution with (restarts + 1)
    /// Nelder-Mead simplex runs. Fitting is performed on a random 20% subset of the
    /// training rows; candidate quality is then measured on the full training data.
    /// Finally re-establishes the pocket/current invariant of the agent.
    /// </summary>
    private static void LocalSearchSimplex(int iterations, int restarts, double tolerance, double delta, Agent a, double[,] trainingData, IRandom rand) {
      double uniformPeturbation = 1.0; // initial simplex perturbation applied to every coefficient
      int maxEvals = iterations;
      int numSearches = restarts + 1;
      var numRows = trainingData.GetLength(0);
      int numSelectedRows = numRows / 5; // 20% of the training samples

      var ch = a.current;
      var quality = Evaluate(ch, trainingData, delta); // get quality with original coefficients

      double[] origCoeff = ExtractCoeff(ch);
      if (origCoeff.Length == 0) return; // no parameters to optimize

      var bestQuality = quality;
      var bestCoeff = origCoeff;

      // One fixed random subset is used as fitting data for all restarts.
      var fittingData = SelectRandomRows(trainingData, numSelectedRows, rand);

      // Simplex objective: writes the candidate coefficients into ch (side effect!) and
      // evaluates the fraction on the fitting subset.
      double objFunc(double[] curCoeff) {
        SetCoeff(ch, curCoeff);
        return Evaluate(ch, fittingData, delta);
      }

      for (int count = 0; count < numSearches; count++) {

        // Every restart begins from the original coefficients with the same perturbation.
        SimplexConstant[] constants = new SimplexConstant[origCoeff.Length];
        for (int i = 0; i < origCoeff.Length; i++) {
          constants[i] = new SimplexConstant(origCoeff[i], uniformPeturbation);
        }

        RegressionResult result = NelderMeadSimplex.Regress(constants, tolerance, maxEvals, objFunc);

        var optimizedCoeff = result.Constants;
        SetCoeff(ch, optimizedCoeff);

        // Quality of the optimized candidate is measured on the full training data.
        var newQuality = Evaluate(ch, trainingData, delta);

        if (newQuality < bestQuality) {
          bestCoeff = optimizedCoeff;
          bestQuality = newQuality;
        }
      } // reps

      // Keep the best coefficient vector found across all restarts.
      SetCoeff(a.current, bestCoeff);
      a.currentObjValue = bestQuality;

      // Unclear what the following means exactly.
      //
      // "We remind again that
      // each solution corresponds to a single model, this means that if a current model becomes better than its corresponding
      // pocket model (in terms of the guiding function of the solution), then an individual search optimization step is also
      // performed on the pocket solution/ model before we swap it with the current solution/ model. Individual search can then
      // make a current model better than the pocket model (again, according to the guiding function), and in that case they
      // switch positions within the agent that contains both of them."

      a.MaintainPocketCurrentInvariant();
    }
542
543    private static double[,] SelectRandomRows(double[,] trainingData, int numSelectedRows, IRandom rand) {
544      var numRows = trainingData.GetLength(0);
545      var numCols = trainingData.GetLength(1);
546      var selectedRows = Enumerable.Range(0, numRows).Shuffle(rand).Take(numSelectedRows).ToArray();
547      var subset = new double[numSelectedRows, numCols];
548      var i = 0;
549      foreach (var r in selectedRows) {
550        for (int c = 0; c < numCols; c++) {
551          subset[i, c] = trainingData[r, c];
552        }
553        i++;
554      }
555      return subset;
556    }
557
558    private static double[] ExtractCoeff(ContinuedFraction ch) {
559      var coeff = new List<double>();
560      foreach (var hi in ch.h) {
561        coeff.Add(hi.beta);
562        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
563          if (hi.vars[vIdx] && hi.coef[vIdx] != 0) coeff.Add(hi.coef[vIdx]); // paper states twice (for mutation and recombination) that non-zero coefficients are optimized
564        }
565      }
566      return coeff.ToArray();
567    }
568
569    private static void SetCoeff(ContinuedFraction ch, double[] curCoeff) {
570      int k = 0;
571      foreach (var hi in ch.h) {
572        hi.beta = curCoeff[k++];
573        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
574          if (hi.vars[vIdx] && hi.coef[vIdx] != 0) hi.coef[vIdx] = curCoeff[k++]; // paper states twice (for mutation and recombination) that non-zero coefficients are optimized
575        }
576      }
577    }
578
    #region build a symbolic expression tree
    // Shared symbol prototypes used to create tree nodes when converting a continued
    // fraction into a HeuristicLab symbolic expression tree (see CreateSymbolicRegressionSolution).
    Symbol addSy = new Addition();
    Symbol divSy = new Division();
    Symbol startSy = new StartSymbol();
    Symbol progSy = new ProgramRootSymbol();
    Symbol constSy = new Constant();
    Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable();
586
587    private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, double[,] trainingData, string[] variables, string targetVariable) {
588      ISymbolicExpressionTreeNode res = null;
589      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
590        var hi = cfrac.h[i];
591        var hi1 = cfrac.h[i - 1];
592        var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta);
593        if (res != null) {
594          denom.AddSubtree(res);
595        }
596
597        var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta);
598
599        res = divSy.CreateTreeNode();
600        res.AddSubtree(numerator);
601        res.AddSubtree(denom);
602      }
603
604      var h0 = cfrac.h[0];
605      var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta);
606      h0Term.AddSubtree(res);
607
608      var progRoot = progSy.CreateTreeNode();
609      var start = startSy.CreateTreeNode();
610      progRoot.AddSubtree(start);
611      start.AddSubtree(h0Term);
612
613      var model = new SymbolicRegressionModel(targetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter());
614      var ds = new Dataset(variables.Concat(new[] { targetVariable }), trainingData);
615      var problemData = new RegressionProblemData(ds, variables, targetVariable);
616      var sol = new SymbolicRegressionSolution(model, problemData);
617      return sol;
618    }
619
620    private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) {
621      var sum = addSy.CreateTreeNode();
622      for (int i = 0; i < vars.Length; i++) {
623        if (vars[i]) {
624          var varNode = (VariableTreeNode)varSy.CreateTreeNode();
625          varNode.Weight = coef[i];
626          varNode.VariableName = variables[i];
627          sum.AddSubtree(varNode);
628        }
629      }
630      sum.AddSubtree(CreateConstant(beta));
631      return sum;
632    }
633
634    private ISymbolicExpressionTreeNode CreateConstant(double value) {
635      var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
636      constNode.Value = value;
637      return constNode;
638    }
639  }
640  #endregion
641}
Note: See TracBrowser for help on using the repository browser.