source: branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/Algorithm.cs @ 17983

Last change on this file since 17983 was 17983, checked in by gkronber, 5 months ago

#3106 add parameters

File size: 21.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HEAL.Attic;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Data;
9using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
10using HeuristicLab.Parameters;
11using HeuristicLab.Problems.DataAnalysis;
12using HeuristicLab.Problems.DataAnalysis.Symbolic;
13using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
14using HeuristicLab.Random;
15
16namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
17  [Item("Continuous Fraction Regression (CFR)", "TODO")]
18  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
19  [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
20  public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
21    private const string MutationRateParameterName = "MutationRate";
22    private const string DepthParameterName = "Depth";
23    private const string NumGenerationsParameterName = "Depth";
24
25    #region parameters
26    public IFixedValueParameter<PercentValue> MutationRateParameter => (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName];
27    public double MutationRate {
28      get { return MutationRateParameter.Value.Value; }
29      set { MutationRateParameter.Value.Value = value; }
30    }
31    public IFixedValueParameter<IntValue> DepthParameter => (IFixedValueParameter<IntValue>)Parameters[DepthParameterName];
32    public int Depth {
33      get { return DepthParameter.Value.Value; }
34      set { DepthParameter.Value.Value = value; }
35    }
36    public IFixedValueParameter<IntValue> NumGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName];
37    public int NumGenerations {
38      get { return NumGenerationsParameter.Value.Value; }
39      set { NumGenerationsParameter.Value.Value = value; }
40    }
41    #endregion
42
43    // storable ctor
44    [StorableConstructor]
45    public Algorithm(StorableConstructorFlag _) : base(_) { }
46
47    // cloning ctor
48    public Algorithm(Algorithm original, Cloner cloner) : base(original, cloner) { }
49
50
51    // default ctor
52    public Algorithm() : base() {
53      Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1)));
54      Parameters.Add(new FixedValueParameter<IntValue>(DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6)));
55      Parameters.Add(new FixedValueParameter<IntValue>(NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200)));
56    }
57
58    public override IDeepCloneable Clone(Cloner cloner) {
59      throw new NotImplementedException();
60    }
61
62    protected override void Run(CancellationToken cancellationToken) {
63      var problemData = Problem.ProblemData;
64
65      var x = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
66        problemData.TrainingIndices);
67      var nVars = x.GetLength(1) - 1;
68      var seed = new System.Random().Next();
69      var rand = new MersenneTwister((uint)seed);
70      CFRAlgorithm(nVars, Depth, MutationRate, x, out var best, out var bestObj, rand, NumGenerations, stagnatingGens: 5, cancellationToken);
71    }
72
73    private void CFRAlgorithm(int nVars, int depth, double mutationRate, double[,] trainingData,
74      out ContinuedFraction best, out double bestObj,
75      IRandom rand, int numGen, int stagnatingGens,
76      CancellationToken cancellationToken) {
77      /* Algorithm 1 */
78      /* Generate initial population by a randomized algorithm */
79      var pop = InitialPopulation(nVars, depth, rand, trainingData);
80      best = pop.pocket;
81      bestObj = pop.pocketObjValue;
82      var bestObjGen = 0;
83      for (int gen = 1; gen <= numGen && !cancellationToken.IsCancellationRequested; gen++) {
84        /* mutate each current solution in the population */
85        var pop_mu = Mutate(pop, mutationRate, rand);
86        /* generate new population by recombination mechanism */
87        var pop_r = RecombinePopulation(pop_mu, rand, nVars);
88
89        /* local search optimization of current solutions */
90        foreach (var agent in pop_r.IterateLevels()) {
91          LocalSearchSimplex(agent.current, ref agent.currentObjValue, trainingData, rand); // CHECK paper states that pocket might also be optimized. Unclear how / when invariants are maintained.
92        }
93
94        foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // CHECK deviates from Alg1 in paper
95
96        /* replace old population with evolved population */
97        pop = pop_r;
98
99        /* keep track of the best solution */
100        if (bestObj > pop.pocketObjValue) { // CHECK: comparison obviously wrong in the paper
101          best = pop.pocket;
102          bestObj = pop.pocketObjValue;
103          bestObjGen = gen;
104          // Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
105          // Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(best, Problem.ProblemData));
106        }
107
108
109        if (gen > bestObjGen + stagnatingGens) {
110          bestObjGen = gen; // CHECK: unspecified in the paper: wait at least stagnatingGens until resetting again
111          Reset(pop, nVars, depth, rand, trainingData);
112          // InitialPopulation(nVars, depth, rand, trainingData); CHECK reset is not specified in the paper
113        }
114      }
115    }
116
117
118
119    private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
120      /* instantiate 13 agents in the population */
121      var pop = new Agent();
122      // see Figure 2
123      for (int i = 0; i < 3; i++) {
124        pop.children.Add(new Agent());
125        for (int j = 0; j < 3; j++) {
126          pop.children[i].children.Add(new Agent());
127        }
128      }
129
130      foreach (var agent in pop.IteratePostOrder()) {
131        agent.current = new ContinuedFraction(nVars, depth, rand);
132        agent.pocket = new ContinuedFraction(nVars, depth, rand);
133
134        agent.currentObjValue = Evaluate(agent.current, trainingData);
135        agent.pocketObjValue = Evaluate(agent.pocket, trainingData);
136
137        /* within each agent, the pocket solution always holds the better value of guiding
138         * function than its current solution
139         */
140        agent.MaintainInvariant();
141      }
142      return pop;
143    }
144
145    // TODO: reset is not described in the paper
146    private void Reset(Agent root, int nVars, int depth, IRandom rand, double[,] trainingData) {
147      root.pocket = new ContinuedFraction(nVars, depth, rand);
148      root.current = new ContinuedFraction(nVars, depth, rand);
149
150      root.currentObjValue = Evaluate(root.current, trainingData);
151      root.pocketObjValue = Evaluate(root.pocket, trainingData);
152
153      /* within each agent, the pocket solution always holds the better value of guiding
154       * function than its current solution
155       */
156      root.MaintainInvariant();
157    }
158
159
160
161    private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars) {
162      var l = pop;
163
164      if (pop.children.Count > 0) {
165        var s1 = pop.children[0];
166        var s2 = pop.children[1];
167        var s3 = pop.children[2];
168
169        // CHECK Deviates from paper (recombine all models in the current pop before updating the population)
170        var l_current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars);
171        var s3_current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars);
172        var s1_current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars);
173        var s2_current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars);
174
175        // recombination works from top to bottom
176        // CHECK do we use the new current solutions (s1_current .. s3_current) already in the next levels?
177        foreach (var child in pop.children) {
178          RecombinePopulation(child, rand, nVars);
179        }
180
181        l.current = l_current;
182        s3.current = s3_current;
183        s1.current = s1_current;
184        s2.current = s2_current;
185      }
186      return pop;
187    }
188
189    private Func<bool[], bool[], bool[]> SelectRandomOp(IRandom rand) {
190      bool[] union(bool[] a, bool[] b) {
191        var res = new bool[a.Length];
192        for (int i = 0; i < a.Length; i++) res[i] = a[i] || b[i];
193        return res;
194      }
195      bool[] intersect(bool[] a, bool[] b) {
196        var res = new bool[a.Length];
197        for (int i = 0; i < a.Length; i++) res[i] = a[i] && b[i];
198        return res;
199      }
200      bool[] symmetricDifference(bool[] a, bool[] b) {
201        var res = new bool[a.Length];
202        for (int i = 0; i < a.Length; i++) res[i] = a[i] ^ b[i];
203        return res;
204      }
205      switch (rand.Next(3)) {
206        case 0: return union;
207        case 1: return intersect;
208        case 2: return symmetricDifference;
209        default: throw new ArgumentException();
210      }
211    }
212
213    private static ContinuedFraction Recombine(ContinuedFraction p1, ContinuedFraction p2, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars) {
214      ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
215      /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */
216      ch.vars = op(p1.vars, p2.vars);
217
218      /* recombine the coefficients for each term (h) of the continued fraction */
219      for (int i = 0; i < p1.h.Length; i++) {
220        var coefa = p1.h[i].coef; var varsa = p1.h[i].vars;
221        var coefb = p2.h[i].coef; var varsb = p2.h[i].vars;
222
223        /* recombine coefficient values for variables */
224        var coefx = new double[nVars];
225        var varsx = new bool[nVars]; // CHECK: deviates from paper, probably forgotten in the pseudo-code
226        for (int vi = 1; vi < nVars; vi++) {
227          if (ch.vars[vi]) {  // CHECK: paper uses featAt()
228            if (varsa[vi] && varsb[vi]) {
229              coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
230              varsx[vi] = true;
231            } else if (varsa[vi]) {
232              coefx[vi] = coefa[vi];
233              varsx[vi] = true;
234            } else if (varsb[vi]) {
235              coefx[vi] = coefb[vi];
236              varsx[vi] = true;
237            }
238          }
239        }
240        /* update new coefficients of the term in offspring */
241        ch.h[i] = new Term() { coef = coefx, vars = varsx };
242        /* compute new value of constant (beta) for term hi in the offspring solution ch using
243         * beta of p1.hi and p2.hi */
244        ch.h[i].beta = p1.h[i].beta + (rand.NextDouble() * 5 - 1) * (p2.h[i].beta - p1.h[i].beta) / 3.0;
245      }
246      /* update current solution and apply local search */
247      // return LocalSearchSimplex(ch, trainingData); // CHECK: Deviates from paper because Alg1 also has LocalSearch after Recombination
248      return ch;
249    }
250
251    private Agent Mutate(Agent pop, double mutationRate, IRandom rand) {
252      foreach (var agent in pop.IterateLevels()) {
253        if (rand.NextDouble() < mutationRate) {
254          if (agent.currentObjValue < 1.2 * agent.pocketObjValue ||
255             agent.currentObjValue > 2 * agent.pocketObjValue)
256            ToggleVariables(agent.current, rand); // major mutation
257          else
258            ModifyVariable(agent.current, rand); // soft mutation
259        }
260      }
261      return pop;
262    }
263
264    private void ToggleVariables(ContinuedFraction cfrac, IRandom rand) {
265      double coinToss(double a, double b) {
266        return rand.NextDouble() < 0.5 ? a : b;
267      }
268
269      /* select a variable index uniformly at random */
270      int N = cfrac.vars.Length;
271      var vIdx = rand.Next(N);
272
273      /* for each depth of continued fraction, toggle the selection of variables of the term (h) */
274      foreach (var h in cfrac.h) {
275        /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or
276         * 'Remember' the coefficient value at random */
277        if (cfrac.vars[vIdx]) {  // CHECK: paper uses varAt()
278          h.vars[vIdx] = false;  // CHECK: paper uses varAt()
279          h.coef[vIdx] = coinToss(0, h.coef[vIdx]);
280        } else {
281          /* Case 2: term variable is turned OFF: Turn ON the variable, and either 'Remove'
282           * or 'Replace' the coefficient with a random value between -3 and 3 at random */
283          if (!h.vars[vIdx]) {
284            h.vars[vIdx] = true;  // CHECK: paper uses varAt()
285            h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3);
286          }
287        }
288      }
289      /* toggle the randomly selected variable */
290      cfrac.vars[vIdx] = !cfrac.vars[vIdx];  // CHECK: paper uses varAt()
291    }
292
293    private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
294      /* randomly select a variable which is turned ON */
295      var candVars = new List<int>();
296      for (int i = 0; i < cfrac.vars.Length; i++) if (cfrac.vars[i]) candVars.Add(i);  // CHECK: paper uses varAt()
297      if (candVars.Count == 0) return; // no variable active
298      var vIdx = candVars[rand.Next(candVars.Count)];
299
300      /* randomly select a term (h) of continued fraction */
301      var h = cfrac.h[rand.Next(cfrac.h.Length)];
302
303      /* modify the coefficient value*/
304      if (h.vars[vIdx]) {  // CHECK: paper uses varAt()
305        h.coef[vIdx] = 0.0;
306      } else {
307        h.coef[vIdx] = rand.NextDouble() * 6 - 3;
308      }
309      /* Toggle the randomly selected variable */
310      h.vars[vIdx] = !h.vars[vIdx]; // CHECK: paper uses varAt()
311    }
312
313    private static double Evaluate(ContinuedFraction cfrac, double[,] trainingData) {
314      var dataPoint = new double[trainingData.GetLength(1) - 1];
315      var yIdx = trainingData.GetLength(1) - 1;
316      double sum = 0.0;
317      for (int r = 0; r < trainingData.GetLength(0); r++) {
318        for (int c = 0; c < dataPoint.Length; c++) {
319          dataPoint[c] = trainingData[r, c];
320        }
321        var y = trainingData[r, yIdx];
322        var pred = Evaluate(cfrac, dataPoint);
323        var res = y - pred;
324        sum += res * res;
325      }
326      var delta = 0.1;
327      return sum / trainingData.GetLength(0) * (1 + delta * cfrac.vars.Count(vi => vi));
328    }
329
330    private static double Evaluate(ContinuedFraction cfrac, double[] dataPoint) {
331      var res = 0.0;
332      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
333        var hi = cfrac.h[i];
334        var hi1 = cfrac.h[i - 1];
335        var denom = hi.beta + dot(hi.vars, hi.coef, dataPoint) + res;
336        var numerator = hi1.beta + dot(hi1.vars, hi1.coef, dataPoint);
337        res = numerator / denom;
338      }
339      var h0 = cfrac.h[0];
340      res += h0.beta + dot(h0.vars, h0.coef, dataPoint);
341      return res;
342    }
343
344    private static double dot(bool[] filter, double[] x, double[] y) {
345      var s = 0.0;
346      for (int i = 0; i < x.Length; i++)
347        if (filter[i])
348          s += x[i] * y[i];
349      return s;
350    }
351
352
353    private static void LocalSearchSimplex(ContinuedFraction ch, ref double quality, double[,] trainingData, IRandom rand) {
354      double uniformPeturbation = 1.0;
355      double tolerance = 1e-3;
356      int maxEvals = 250;
357      int numSearches = 4;
358      var numRows = trainingData.GetLength(0);
359      int numSelectedRows = numRows / 5; // 20% of the training samples
360
361      quality = Evaluate(ch, trainingData); // get quality with origial coefficients
362
363      double[] origCoeff = ExtractCoeff(ch);
364      if (origCoeff.Length == 0) return; // no parameters to optimize
365
366      var bestQuality = quality;
367      var bestCoeff = origCoeff;
368
369      var fittingData = SelectRandomRows(trainingData, numSelectedRows, rand);
370
371      double objFunc(double[] curCoeff) {
372        SetCoeff(ch, curCoeff);
373        return Evaluate(ch, fittingData);
374      }
375
376      for (int count = 0; count < numSearches; count++) {
377
378        SimplexConstant[] constants = new SimplexConstant[origCoeff.Length];
379        for (int i = 0; i < origCoeff.Length; i++) {
380          constants[i] = new SimplexConstant(origCoeff[i], uniformPeturbation);
381        }
382
383        RegressionResult result = NelderMeadSimplex.Regress(constants, tolerance, maxEvals, objFunc);
384
385        var optimizedCoeff = result.Constants;
386        SetCoeff(ch, optimizedCoeff);
387
388        var newQuality = Evaluate(ch, trainingData);
389
390        if (newQuality < bestQuality) {
391          bestCoeff = optimizedCoeff;
392          bestQuality = newQuality;
393        }
394      } // reps
395
396      SetCoeff(ch, bestCoeff);
397      quality = bestQuality;
398    }
399
400    private static double[,] SelectRandomRows(double[,] trainingData, int numSelectedRows, IRandom rand) {
401      var numRows = trainingData.GetLength(0);
402      var numCols = trainingData.GetLength(1);
403      var selectedRows = Enumerable.Range(0, numRows).Shuffle(rand).Take(numSelectedRows).ToArray();
404      var subset = new double[numSelectedRows, numCols];
405      var i = 0;
406      foreach (var r in selectedRows) {
407        for (int c = 0; c < numCols; c++) {
408          subset[i, c] = trainingData[r, c];
409        }
410        i++;
411      }
412      return subset;
413    }
414
415    private static double[] ExtractCoeff(ContinuedFraction ch) {
416      var coeff = new List<double>();
417      foreach (var hi in ch.h) {
418        coeff.Add(hi.beta);
419        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
420          if (hi.vars[vIdx]) coeff.Add(hi.coef[vIdx]);
421        }
422      }
423      return coeff.ToArray();
424    }
425
426    private static void SetCoeff(ContinuedFraction ch, double[] curCoeff) {
427      int k = 0;
428      foreach (var hi in ch.h) {
429        hi.beta = curCoeff[k++];
430        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
431          if (hi.vars[vIdx]) hi.coef[vIdx] = curCoeff[k++];
432        }
433      }
434    }
435
436    Symbol addSy = new Addition();
437    Symbol mulSy = new Multiplication();
438    Symbol divSy = new Division();
439    Symbol startSy = new StartSymbol();
440    Symbol progSy = new ProgramRootSymbol();
441    Symbol constSy = new Constant();
442    Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable();
443
444    private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, IRegressionProblemData problemData) {
445      var variables = problemData.AllowedInputVariables.ToArray();
446      ISymbolicExpressionTreeNode res = null;
447      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
448        var hi = cfrac.h[i];
449        var hi1 = cfrac.h[i - 1];
450        var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta);
451        if (res != null) {
452          denom.AddSubtree(res);
453        }
454
455        var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta);
456
457        res = divSy.CreateTreeNode();
458        res.AddSubtree(numerator);
459        res.AddSubtree(denom);
460      }
461
462      var h0 = cfrac.h[0];
463      var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta);
464      h0Term.AddSubtree(res);
465
466      var progRoot = progSy.CreateTreeNode();
467      var start = startSy.CreateTreeNode();
468      progRoot.AddSubtree(start);
469      start.AddSubtree(h0Term);
470
471      var model = new SymbolicRegressionModel(problemData.TargetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter());
472      var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
473      return sol;
474    }
475
476    private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) {
477      var sum = addSy.CreateTreeNode();
478      for (int i = 0; i < vars.Length; i++) {
479        if (vars[i]) {
480          var varNode = (VariableTreeNode)varSy.CreateTreeNode();
481          varNode.Weight = coef[i];
482          varNode.VariableName = variables[i];
483          sum.AddSubtree(varNode);
484        }
485      }
486      sum.AddSubtree(CreateConstant(beta));
487      return sum;
488    }
489
490    private ISymbolicExpressionTreeNode CreateConstant(double value) {
491      var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
492      constNode.Value = value;
493      return constNode;
494    }
495  }
496
497  public class Agent {
498    public ContinuedFraction pocket;
499    public double pocketObjValue;
500    public ContinuedFraction current;
501    public double currentObjValue;
502
503    public IList<Agent> children = new List<Agent>();
504
505    public IEnumerable<Agent> IterateLevels() {
506      var agents = new List<Agent>() { this };
507      IterateLevelsRec(this, agents);
508      return agents;
509    }
510    public IEnumerable<Agent> IteratePostOrder() {
511      var agents = new List<Agent>();
512      IteratePostOrderRec(this, agents);
513      return agents;
514    }
515
516    internal void MaintainInvariant() {
517      foreach (var child in children) {
518        MaintainInvariant(parent: this, child);
519      }
520      if (currentObjValue < pocketObjValue) {
521        Swap(ref pocket, ref current);
522        Swap(ref pocketObjValue, ref currentObjValue);
523      }
524    }
525
526
527    private static void MaintainInvariant(Agent parent, Agent child) {
528      if (child.pocketObjValue < parent.pocketObjValue) {
529        Swap(ref child.pocket, ref parent.pocket);
530        Swap(ref child.pocketObjValue, ref parent.pocketObjValue);
531      }
532    }
533
534    private void IterateLevelsRec(Agent agent, List<Agent> agents) {
535      foreach (var child in agent.children) {
536        agents.Add(child);
537      }
538      foreach (var child in agent.children) {
539        IterateLevelsRec(child, agents);
540      }
541    }
542
543    private void IteratePostOrderRec(Agent agent, List<Agent> agents) {
544      foreach (var child in agent.children) {
545        IteratePostOrderRec(child, agents);
546      }
547      agents.Add(agent);
548    }
549
550
551    private static void Swap<T>(ref T a, ref T b) {
552      var temp = a;
553      a = b;
554      b = temp;
555    }
556  }
557}
Note: See TracBrowser for help on using the repository browser.