source: branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/Algorithm.cs @ 17986

Last change on this file since 17986 was 17986, checked in by gkronber, 4 months ago

#3106 Fixed a bug: training data were not resampled across restarts of the local search. Made a few minor changes to match the description in the paper (introduced two parameters for local search and changed the likelihood of variables being included).

File size: 33.0 KB
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using HEAL.Attic;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
  /// <summary>
  /// Implementation of Continued Fraction Regression (CFR) as described in
  /// Pablo Moscato, Haoyuan Sun, Mohammad Nazmul Haque,
  /// Analytic Continued Fractions for Regression: A Memetic Algorithm Approach,
  /// Expert Systems with Applications, Volume 179, 2021, 115018, ISSN 0957-4174,
  /// https://doi.org/10.1016/j.eswa.2021.115018.
  /// </summary>
  [Item("Continued Fraction Regression (CFR)", "TODO")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
  [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
  public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string MutationRateParameterName = "MutationRate";
    private const string DepthParameterName = "Depth";
    private const string NumGenerationsParameterName = "NumGenerations";
    private const string StagnationGenerationsParameterName = "StagnationGenerations";
    private const string LocalSearchIterationsParameterName = "LocalSearchIterations";
    private const string LocalSearchRestartsParameterName = "LocalSearchRestarts";
    private const string LocalSearchToleranceParameterName = "LocalSearchTolerance";
    private const string DeltaParameterName = "Delta";
    private const string ScaleDataParameterName = "ScaleData";
    private const string LocalSearchMinNumSamplesParameterName = "LocalSearchMinNumSamples";
    private const string LocalSearchSamplesFractionParameterName = "LocalSearchSamplesFraction";


    #region parameters
    public IFixedValueParameter<PercentValue> MutationRateParameter => (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName];
    public double MutationRate {
      get { return MutationRateParameter.Value.Value; }
      set { MutationRateParameter.Value.Value = value; }
    }
    public IFixedValueParameter<IntValue> DepthParameter => (IFixedValueParameter<IntValue>)Parameters[DepthParameterName];
    public int Depth {
      get { return DepthParameter.Value.Value; }
      set { DepthParameter.Value.Value = value; }
    }
    public IFixedValueParameter<IntValue> NumGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName];
    public int NumGenerations {
      get { return NumGenerationsParameter.Value.Value; }
      set { NumGenerationsParameter.Value.Value = value; }
    }
    public IFixedValueParameter<IntValue> StagnationGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[StagnationGenerationsParameterName];
    public int StagnationGenerations {
      get { return StagnationGenerationsParameter.Value.Value; }
      set { StagnationGenerationsParameter.Value.Value = value; }
    }
    public IFixedValueParameter<IntValue> LocalSearchIterationsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchIterationsParameterName];
    public int LocalSearchIterations {
      get { return LocalSearchIterationsParameter.Value.Value; }
      set { LocalSearchIterationsParameter.Value.Value = value; }
    }
    public IFixedValueParameter<IntValue> LocalSearchRestartsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchRestartsParameterName];
    public int LocalSearchRestarts {
      get { return LocalSearchRestartsParameter.Value.Value; }
      set { LocalSearchRestartsParameter.Value.Value = value; }
    }
    public IFixedValueParameter<DoubleValue> LocalSearchToleranceParameter => (IFixedValueParameter<DoubleValue>)Parameters[LocalSearchToleranceParameterName];
    public double LocalSearchTolerance {
      get { return LocalSearchToleranceParameter.Value.Value; }
      set { LocalSearchToleranceParameter.Value.Value = value; }
    }
    public IFixedValueParameter<PercentValue> DeltaParameter => (IFixedValueParameter<PercentValue>)Parameters[DeltaParameterName];
    public double Delta {
      get { return DeltaParameter.Value.Value; }
      set { DeltaParameter.Value.Value = value; }
    }
    public IFixedValueParameter<BoolValue> ScaleDataParameter => (IFixedValueParameter<BoolValue>)Parameters[ScaleDataParameterName];
    public bool ScaleData {
      get { return ScaleDataParameter.Value.Value; }
      set { ScaleDataParameter.Value.Value = value; }
    }
    public IFixedValueParameter<IntValue> LocalSearchMinNumSamplesParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchMinNumSamplesParameterName];
    public int LocalSearchMinNumSamples {
      get { return LocalSearchMinNumSamplesParameter.Value.Value; }
      set { LocalSearchMinNumSamplesParameter.Value.Value = value; }
    }
    public IFixedValueParameter<PercentValue> LocalSearchSamplesFractionParameter => (IFixedValueParameter<PercentValue>)Parameters[LocalSearchSamplesFractionParameterName];
    public double LocalSearchSamplesFraction {
      get { return LocalSearchSamplesFractionParameter.Value.Value; }
      set { LocalSearchSamplesFractionParameter.Value.Value = value; }
    }
    #endregion

    // storable ctor
    [StorableConstructor]
    public Algorithm(StorableConstructorFlag _) : base(_) { }

    // cloning ctor
    public Algorithm(Algorithm original, Cloner cloner) : base(original, cloner) { }


    // default ctor
    public Algorithm() : base() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1)));
      Parameters.Add(new FixedValueParameter<IntValue>(DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6)));
      Parameters.Add(new FixedValueParameter<IntValue>(NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<IntValue>(StagnationGenerationsParameterName, "Number of generations after which the population is re-initialized (default value 5)", new IntValue(5)));
      Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchIterationsParameterName, "Number of iterations for local search (simplex) (default value 250)", new IntValue(250)));
      Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchRestartsParameterName, "Number of restarts for local search (default value 4)", new IntValue(4)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(LocalSearchToleranceParameterName, "The tolerance value for local search (simplex) (default value: 1e-3)", new DoubleValue(1e-3))); // page 12 of the preprint
      Parameters.Add(new FixedValueParameter<PercentValue>(DeltaParameterName, "The relative weight for the number of variables term in the fitness function (default value: 10%)", new PercentValue(0.1)));
      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleDataParameterName, "Turns on/off scaling of input variable values to the range [0 .. 1] (default: false)", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchMinNumSamplesParameterName, "The minimum number of samples for the local search (default 200)", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<PercentValue>(LocalSearchSamplesFractionParameterName, "The fraction of samples used for local search. Only used when the number of samples is more than " + LocalSearchMinNumSamplesParameterName + " (default 20%)", new PercentValue(0.2)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // for backwards compatibility
      if (!Parameters.ContainsKey(LocalSearchMinNumSamplesParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchMinNumSamplesParameterName, "The minimum number of samples for the local search (default 200)", new IntValue(200)));
      if (!Parameters.ContainsKey(LocalSearchSamplesFractionParameterName))
        Parameters.Add(new FixedValueParameter<PercentValue>(LocalSearchSamplesFractionParameterName, "The fraction of samples used for local search. Only used when the number of samples is more than " + LocalSearchMinNumSamplesParameterName + " (default 20%)", new PercentValue(0.2)));
    }


    public override IDeepCloneable Clone(Cloner cloner) {
      return new Algorithm(this, cloner);
    }

    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;
      double[,] xy;
      var transformations = new List<ITransformation<double>>();
      if (ScaleData) {
        // TODO: scaling via transformations is really ugly.
        // Scale data to range 0 .. 1
        //
        // Scaling was not used for the experiments in the paper. Statement by the authors: "We did not pre-process the data."
        foreach (var input in problemData.AllowedInputVariables) {
          var values = problemData.Dataset.GetDoubleValues(input, problemData.TrainingIndices);
          var linTransformation = new LinearTransformation(problemData.AllowedInputVariables);
          var min = values.Min();
          var max = values.Max();
          var range = max - min;
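          // note: assumes max > min; for a constant input column the range is zero and the multiplier becomes non-finite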
          linTransformation.Addend = -min / range;
          linTransformation.Multiplier = 1.0 / range;
          linTransformation.ColumnParameter.ActualValue = linTransformation.ColumnParameter.ValidValues.First(sv => sv.Value == input);
          transformations.Add(linTransformation);
        }
        // do not scale the target
        var targetTransformation = new LinearTransformation(new string[] { problemData.TargetVariable }) { Addend = 0.0, Multiplier = 1.0 };
        targetTransformation.ColumnParameter.ActualValue = targetTransformation.ColumnParameter.ValidValues.First(sv => sv.Value == problemData.TargetVariable);
        transformations.Add(targetTransformation);
        xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
          transformations,
          problemData.TrainingIndices);
      } else {
        // no transformation
        xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
          problemData.TrainingIndices);
      }
      var nVars = xy.GetLength(1) - 1;
      var seed = new System.Random().Next();
      var rand = new MersenneTwister((uint)seed);

      void iterationCallback(Agent pop) {
        #region visualization and debugging
        DataTable qualities;
        int i = 0;
        if (Results.TryGetValue("Qualities", out var qualitiesResult)) {
          qualities = (DataTable)qualitiesResult.Value;
        } else {
          qualities = new DataTable("Qualities", "Qualities");
          i = 0;
          foreach (var node in pop.IterateLevels()) {
            qualities.Rows.Add(new DataRow($"Quality {i} pocket", "Quality of pocket"));
            qualities.Rows.Add(new DataRow($"Quality {i} current", "Quality of current"));
            i++;
          }
          Results.AddOrUpdateResult("Qualities", qualities);
        }
        i = 0;
        foreach (var node in pop.IterateLevels()) {
          qualities.Rows[$"Quality {i} pocket"].Values.Add(node.pocketObjValue);
          qualities.Rows[$"Quality {i} current"].Values.Add(node.currentObjValue);
          i++;
        }
        #endregion
      }
      void newBestSolutionCallback(ContinuedFraction best, double objVal) {
        Results.AddOrUpdateResult("MSE (best)", new DoubleValue(objVal));
        Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(best, problemData, transformations));
      }

      CFRAlgorithm(nVars, Depth, MutationRate, xy, out var bestObj, rand, NumGenerations, StagnationGenerations,
        Delta,
        LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, newBestSolutionCallback, iterationCallback, cancellationToken);
    }

    private void CFRAlgorithm(int nVars, int depth, double mutationRate, double[,] trainingData,
      out double bestObj,
      IRandom rand, int numGen, int stagnatingGens, double evalDelta,
      int localSearchIterations, int localSearchRestarts, double localSearchTolerance,
      Action<ContinuedFraction, double> newBestSolutionCallback,
      Action<Agent> iterationCallback,
      CancellationToken cancellationToken) {
      /* Algorithm 1 */
      /* Generate initial population by a randomized algorithm */
      var pop = InitialPopulation(nVars, depth, rand, trainingData);
      bestObj = pop.pocketObjValue;
      // the best value since the last reset
      var episodeBestObj = pop.pocketObjValue;
      var episodeBestObjGen = 0;
      for (int gen = 1; gen <= numGen && !cancellationToken.IsCancellationRequested; gen++) {
        /* mutate each current solution in the population */
        var pop_mu = Mutate(pop, mutationRate, rand, trainingData);
        /* generate new population by recombination mechanism */
        var pop_r = RecombinePopulation(pop_mu, rand, nVars, trainingData);

        // Paper:
        // A period of individual search operation is performed every generation on all current solutions.

        // Statement by authors:
        // "We ran the Local Search after Mutation and recombination operations. We executed the local-search only on the Current solutions."
        // "We executed the MaintainInvariant() in the following steps:
        // - After generating the initial population
        // - after resetting the root
        // - after executing the local-search on the whole population.
        // We updated the pocket/ current automatically after mutation and recombination operation."

        /* local search optimization of current solutions */
        foreach (var agent in pop_r.IterateLevels()) {
          LocalSearchSimplex(localSearchIterations, localSearchRestarts, localSearchTolerance, LocalSearchMinNumSamples, LocalSearchSamplesFraction, evalDelta, agent, trainingData, rand);
          Debug.Assert(agent.pocketObjValue <= agent.currentObjValue); // local search re-establishes the invariant, so the pocket is never worse than the current solution
        }
        foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // post-order to make sure that the root contains the best model
        foreach (var agent in pop_r.IteratePostOrder()) agent.AssertInvariant();

        // for detecting stagnation we track the best objective value since the last reset
        // and reset if this does not change for stagnatingGens
        if (gen > episodeBestObjGen + stagnatingGens) {
          Reset(pop_r, nVars, depth, rand, trainingData);
          episodeBestObj = double.MaxValue;
        }
        if (episodeBestObj > pop_r.pocketObjValue) {
          episodeBestObjGen = gen; // wait at least stagnatingGens until resetting again
          episodeBestObj = pop_r.pocketObjValue;
        }

        /* replace old population with evolved population */
        pop = pop_r;

        /* keep track of the best solution */
        if (bestObj > pop.pocketObjValue) {
          bestObj = pop.pocketObjValue;
          newBestSolutionCallback(pop.pocket, bestObj);
        }


        iterationCallback(pop);
      }
    }



    private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
      /* instantiate 13 agents in the population */
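      // The population is organized as a ternary tree of depth two:
      // one root leader, three mid-level agents, and nine leaves (1 + 3 + 9 = 13 agents).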
      var pop = new Agent();
      // see Figure 2
      for (int i = 0; i < 3; i++) {
        pop.children.Add(new Agent());
        for (int j = 0; j < 3; j++) {
          pop.children[i].children.Add(new Agent());
        }
      }

      // Statement by the authors: "Yes, we use post-order traversal here"
      foreach (var agent in pop.IteratePostOrder()) {
        agent.current = new ContinuedFraction(nVars, depth, rand);
        agent.pocket = new ContinuedFraction(nVars, depth, rand);

        agent.currentObjValue = Evaluate(agent.current, trainingData, Delta);
        agent.pocketObjValue = Evaluate(agent.pocket, trainingData, Delta);

        /* within each agent, the pocket solution always holds the better value of guiding
         * function than its current solution
         */
        agent.MaintainInvariant();
      }

      foreach (var agent in pop.IteratePostOrder()) agent.AssertInvariant();

      return pop;
    }

    // Our criterion for relevance has been fairly strict: if no
    // better model has been produced for five (5) straight generations,
    // then the pocket of the root agent is removed and a new solution is created at random.

    // Statement by the authors: "We only replaced the pocket solution of the root with
    // a randomly generated solution. Then we execute the maintain-invariant process.
    // It does not initialize the solutions in the entire population."
    private void Reset(Agent root, int nVars, int depth, IRandom rand, double[,] trainingData) {
      root.pocket = new ContinuedFraction(nVars, depth, rand);
      root.pocketObjValue = Evaluate(root.pocket, trainingData, Delta);

      foreach (var agent in root.IteratePreOrder()) { agent.MaintainInvariant(); } // Here we use pre-order traversal to push the newly created model down the hierarchy.

      foreach (var agent in root.IteratePostOrder()) agent.AssertInvariant();

    }



    private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars, double[,] trainingData) {
      var l = pop;

      if (pop.children.Count > 0) {
        var s1 = pop.children[0];
        var s2 = pop.children[1];
        var s3 = pop.children[2];

        // Statement by the authors: "we are using recently generated solutions.
        // For an example, in step 1 we got the new current(l), which is being used in
        // Step 2 to generate current(s3). The current(s3) from Step 2 is being used at
        // Step 4. These steps are executed sequentially from 1 to 4. Similarly, in the
        // recombination of lower-level subpopulations, we will have the new current
        // (the supporters generated at the previous level) as the leader of the subpopulation."
        Recombine(l, s1, SelectRandomOp(rand), rand, nVars, trainingData);
        Recombine(s3, l, SelectRandomOp(rand), rand, nVars, trainingData);
        Recombine(s1, s2, SelectRandomOp(rand), rand, nVars, trainingData);
        Recombine(s2, s3, SelectRandomOp(rand), rand, nVars, trainingData);

        // recombination works from top to bottom
        foreach (var child in pop.children) {
          RecombinePopulation(child, rand, nVars, trainingData);
        }

      }
      return pop;
    }

    private ContinuedFraction Recombine(Agent a, Agent b, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars, double[,] trainingData) {
      ContinuedFraction p1 = a.pocket;
      ContinuedFraction p2 = b.pocket;
      ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
      /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */
      ch.vars = op(p1.vars, p2.vars);

      /* recombine the coefficients for each term (h) of the continued fraction */
      for (int i = 0; i < p1.h.Length; i++) {
        var coefa = p1.h[i].coef; var varsa = p1.h[i].vars;
        var coefb = p2.h[i].coef; var varsb = p2.h[i].vars;

        /* recombine coefficient values for variables */
        var coefx = new double[nVars];
        var varsx = new bool[nVars]; // deviates from paper, probably forgotten in the pseudo-code
        for (int vi = 0; vi < nVars; vi++) {
          if (ch.vars[vi]) {  // CHECK: paper uses featAt()
            if (varsa[vi] && varsb[vi]) {
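              // the blend factor (5r - 1) / 3 with r ~ U[0,1) lies in [-1/3, 4/3),
              // so the child coefficient may extrapolate slightly beyond either parent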
              coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
              varsx[vi] = true;
            } else if (varsa[vi]) {
              coefx[vi] = coefa[vi];
              varsx[vi] = true;
            } else if (varsb[vi]) {
              coefx[vi] = coefb[vi];
              varsx[vi] = true;
            }
          }
        }
        /* update new coefficients of the term in offspring */
        ch.h[i] = new Term() { coef = coefx, vars = varsx };
        /* compute new value of constant (beta) for term hi in the offspring solution ch using
         * beta of p1.hi and p2.hi */
        ch.h[i].beta = p1.h[i].beta + (rand.NextDouble() * 5 - 1) * (p2.h[i].beta - p1.h[i].beta) / 3.0;
      }

      a.current = ch;
      LocalSearchSimplex(LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, LocalSearchMinNumSamples, LocalSearchSamplesFraction, Delta, a, trainingData, rand);
      return ch;
    }

    private Agent Mutate(Agent pop, double mutationRate, IRandom rand, double[,] trainingData) {
      foreach (var agent in pop.IterateLevels()) {
        if (rand.NextDouble() < mutationRate) {
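          // major mutation if the current solution is close to its pocket (objective below 120% of the pocket's)
          // or much worse than it (above 200%); otherwise a soft mutation of a single term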
          if (agent.currentObjValue < 1.2 * agent.pocketObjValue ||
              agent.currentObjValue > 2 * agent.pocketObjValue)
            ToggleVariables(agent.current, rand); // major mutation
          else
            ModifyVariable(agent.current, rand); // soft mutation

          // Finally, the local search operation is executed on the mutated solution in order to optimize
          // non-zero coefficients. We do not apply mutation on pocket solutions because we consider them as a "collective memory"
          // of good models visited in the past. They influence the search process via recombination only.
          LocalSearchSimplex(LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, LocalSearchMinNumSamples, LocalSearchSamplesFraction, Delta, agent, trainingData, rand);
        }
      }
      return pop;
    }

    private void ToggleVariables(ContinuedFraction cfrac, IRandom rand) {
      double coinToss(double a, double b) {
        return rand.NextDouble() < 0.5 ? a : b;
      }

      /* select a variable index uniformly at random */
      int N = cfrac.vars.Length;
      var vIdx = rand.Next(N);

      /* for each depth of continued fraction, toggle the selection of variables of the term (h) */
      foreach (var h in cfrac.h) {
        /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or
         * 'Remember' the coefficient value at random */
        if (cfrac.vars[vIdx]) {
          h.vars[vIdx] = false;
          h.coef[vIdx] = coinToss(0, h.coef[vIdx]);
        } else {
          /* Case 2: term variable is turned OFF: Turn ON the variable, and either 'Remove'
           * or 'Replace' the coefficient with a random value between -3 and 3 at random */
          if (!h.vars[vIdx]) {
            h.vars[vIdx] = true;
            h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3);
          }
        }
      }
      /* toggle the randomly selected variable */
      cfrac.vars[vIdx] = !cfrac.vars[vIdx];
    }

    private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
      /* randomly select a variable which is turned ON */
      var candVars = new List<int>();
      for (int i = 0; i < cfrac.vars.Length; i++) if (cfrac.vars[i]) candVars.Add(i);
      if (candVars.Count == 0) return; // no variable active
      var vIdx = candVars[rand.Next(candVars.Count)];

      /* randomly select a term (h) of continued fraction */
      var h = cfrac.h[rand.Next(cfrac.h.Length)];

      /* modify the coefficient value */
      if (h.vars[vIdx]) {
        h.coef[vIdx] = 0.0;
      } else {
        h.coef[vIdx] = rand.NextDouble() * 6 - 3;
      }
      /* Toggle the randomly selected variable */
      h.vars[vIdx] = !h.vars[vIdx];
    }

    private static double Evaluate(ContinuedFraction cfrac, double[,] trainingData, double delta) {
      var dataPoint = new double[trainingData.GetLength(1) - 1];
      var yIdx = trainingData.GetLength(1) - 1;
      double sum = 0.0;
      for (int r = 0; r < trainingData.GetLength(0); r++) {
        for (int c = 0; c < dataPoint.Length; c++) {
          dataPoint[c] = trainingData[r, c];
        }
        var y = trainingData[r, yIdx];
        var pred = Evaluate(cfrac, dataPoint);
        var res = y - pred;
        sum += res * res;
      }
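      // guiding function: MSE scaled by (1 + delta * number of selected variables);
      // delta sets the parsimony pressure towards models with fewer variables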
      return sum / trainingData.GetLength(0) * (1 + delta * cfrac.vars.Count(vi => vi));
    }

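    // Evaluates the continued fraction bottom-up:
    //   f(x) = h0(x) + h1(x) / (h2(x) + h3(x) / (h4(x) + ...))
    // where each term hi(x) = beta_i + <coef_i, x> is linear in the selected variables.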
    private static double Evaluate(ContinuedFraction cfrac, double[] dataPoint) {
      var res = 0.0;
      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
        var hi = cfrac.h[i];
        var hi1 = cfrac.h[i - 1];
        var denom = hi.beta + dot(hi.vars, hi.coef, dataPoint) + res;
        var numerator = hi1.beta + dot(hi1.vars, hi1.coef, dataPoint);
        res = numerator / denom;
      }
      var h0 = cfrac.h[0];
      res += h0.beta + dot(h0.vars, h0.coef, dataPoint);
      return res;
    }


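    // The recombination operators act on the variable-selection bit masks of the two parent solutions.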
    private Func<bool[], bool[], bool[]> SelectRandomOp(IRandom rand) {
      bool[] union(bool[] a, bool[] b) {
        var res = new bool[a.Length];
        for (int i = 0; i < a.Length; i++) res[i] = a[i] || b[i];
        return res;
      }
      bool[] intersect(bool[] a, bool[] b) {
        var res = new bool[a.Length];
        for (int i = 0; i < a.Length; i++) res[i] = a[i] && b[i];
        return res;
      }
      bool[] symmetricDifference(bool[] a, bool[] b) {
        var res = new bool[a.Length];
        for (int i = 0; i < a.Length; i++) res[i] = a[i] ^ b[i];
        return res;
      }
      switch (rand.Next(3)) {
        case 0: return union;
        case 1: return intersect;
        case 2: return symmetricDifference;
        default: throw new ArgumentException();
      }
    }

    private static double dot(bool[] filter, double[] x, double[] y) {
      var s = 0.0;
      for (int i = 0; i < x.Length; i++)
        if (filter[i])
          s += x[i] * y[i];
      return s;
    }


    // The authors used the Nelder Mead solver from https://direct.mit.edu/evco/article/25/3/351/1046/Evolving-a-Nelder-Mead-Algorithm-for-Optimization
    // Using different solvers (e.g. LevMar) is mentioned but not analysed


    private static void LocalSearchSimplex(int iterations, int restarts, double tolerance, int minNumRows, double samplesFrac, double delta, Agent a, double[,] trainingData, IRandom rand) {
      int maxEvals = iterations;
      int numSearches = restarts + 1;
      var numRows = trainingData.GetLength(0);
      int numSelectedRows = numRows;
      if (numRows > minNumRows)
        numSelectedRows = (int)(numRows * samplesFrac);

      var ch = a.current;
      var quality = Evaluate(ch, trainingData, delta); // get quality with original coefficients

      double[] origCoeff = ExtractCoeff(ch);
      if (origCoeff.Length == 0) return; // no parameters to optimize

      var bestQuality = quality;
      var bestCoeff = origCoeff;

      double[,] fittingData = null;

      double objFunc(double[] curCoeff) {
        SetCoeff(ch, curCoeff);
        return Evaluate(ch, fittingData, delta);
      }

      for (int count = 0; count < numSearches; count++) {
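        // draw a fresh random subsample of rows for every restart
        // (this resampling across restarts is the bug fix mentioned in the changeset message)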
        fittingData = SelectRandomRows(trainingData, numSelectedRows, rand);

        SimplexConstant[] constants = new SimplexConstant[origCoeff.Length];
        for (int i = 0; i < origCoeff.Length; i++) {
          constants[i] = new SimplexConstant(origCoeff[i], initialPerturbation: 1.0);
        }

        RegressionResult result = NelderMeadSimplex.Regress(constants, tolerance, maxEvals, objFunc);

        var optimizedCoeff = result.Constants;
        SetCoeff(ch, optimizedCoeff);

        // the result with the best guiding function value (on the entire dataset) is chosen.
        var newQuality = Evaluate(ch, trainingData, delta);

        if (newQuality < bestQuality) {
          bestCoeff = optimizedCoeff;
          bestQuality = newQuality;
        }
      }

      SetCoeff(a.current, bestCoeff);
      a.currentObjValue = bestQuality;

      // Unclear what the following means exactly.
      //
      // "We remind again that
      // each solution corresponds to a single model, this means that if a current model becomes better than its corresponding
      // pocket model (in terms of the guiding function of the solution), then an individual search optimization step is also
      // performed on the pocket solution/ model before we swap it with the current solution/ model. Individual search can then
      // make a current model better than the pocket model (again, according to the guiding function), and in that case they
      // switch positions within the agent that contains both of them."

      a.MaintainPocketCurrentInvariant();
    }

    private static double[,] SelectRandomRows(double[,] trainingData, int numSelectedRows, IRandom rand) {
      var numRows = trainingData.GetLength(0);
      var numCols = trainingData.GetLength(1);
      var selectedRows = Enumerable.Range(0, numRows).Shuffle(rand).Take(numSelectedRows).ToArray();
      var subset = new double[numSelectedRows, numCols];
      var i = 0;
      foreach (var r in selectedRows) {
        for (int c = 0; c < numCols; c++) {
          subset[i, c] = trainingData[r, c];
        }
        i++;
      }
      return subset;
    }

    private static double[] ExtractCoeff(ContinuedFraction ch) {
      var coeff = new List<double>();
      foreach (var hi in ch.h) {
        coeff.Add(hi.beta);
        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
          if (hi.vars[vIdx] && hi.coef[vIdx] != 0) coeff.Add(hi.coef[vIdx]); // paper states twice (for mutation and recombination) that non-zero coefficients are optimized
        }
      }
      return coeff.ToArray();
    }

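    // Note: SetCoeff must visit the coefficients in exactly the same order as ExtractCoeff;
    // the variable selection must not change between the two calls.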
    private static void SetCoeff(ContinuedFraction ch, double[] curCoeff) {
      int k = 0;
      foreach (var hi in ch.h) {
        hi.beta = curCoeff[k++];
        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
          if (hi.vars[vIdx] && hi.coef[vIdx] != 0) hi.coef[vIdx] = curCoeff[k++]; // paper states twice (for mutation and recombination) that non-zero coefficients are optimized
        }
      }
    }

    #region build a symbolic expression tree
    Symbol addSy = new Addition();
    Symbol divSy = new Division();
    Symbol startSy = new StartSymbol();
    Symbol progSy = new ProgramRootSymbol();
    Symbol constSy = new Constant();
    Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable();

    private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, IRegressionProblemData problemData, List<ITransformation<double>> transformations) {
      var variables = problemData.AllowedInputVariables.ToArray();
      ISymbolicExpressionTreeNode res = null;
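      // build the expression tree bottom-up, mirroring Evaluate(cfrac, dataPoint)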
      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
        var hi = cfrac.h[i];
        var hi1 = cfrac.h[i - 1];
        var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta);
        if (res != null) {
          denom.AddSubtree(res);
        }

        var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta);

        res = divSy.CreateTreeNode();
        res.AddSubtree(numerator);
        res.AddSubtree(denom);
      }

      var h0 = cfrac.h[0];
      var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta);
      if (res != null) h0Term.AddSubtree(res); // guard against a fraction without nested terms, analogous to the null check above

      var progRoot = progSy.CreateTreeNode();
      var start = startSy.CreateTreeNode();
      progRoot.AddSubtree(start);
      start.AddSubtree(h0Term);

      ISymbolicRegressionModel model = new SymbolicRegressionModel(problemData.TargetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter());
      if (transformations != null && transformations.Any()) {
        var backscaling = new SymbolicExpressionTreeBacktransformator(new TransformationToSymbolicTreeMapper());
        model = (ISymbolicRegressionModel)backscaling.Backtransform(model, transformations, problemData.TargetVariable);
      }
      var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
      return sol;
    }

    private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) {
      var sum = addSy.CreateTreeNode();
      for (int i = 0; i < vars.Length; i++) {
        if (vars[i]) {
          var varNode = (VariableTreeNode)varSy.CreateTreeNode();
          varNode.Weight = coef[i];
          varNode.VariableName = variables[i];
          sum.AddSubtree(varNode);
        }
      }
      sum.AddSubtree(CreateConstant(beta));
      return sum;
    }

    private ISymbolicExpressionTreeNode CreateConstant(double value) {
      var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
      constNode.Value = value;
      return constNode;
    }
    #endregion
  }
}