source: branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/Algorithm.cs @ 17971

Last change on this file since 17971 was 17971, checked in by gkronber, 5 months ago

#3106: first implementation of the algorithm as described in the paper

File size: 15.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HEAL.Attic;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Data;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Random;
11
12namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
13  [Item("Continuous Fraction Regression (CFR)", "TODO")]
14  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
15  [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
16  public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
17    public override IDeepCloneable Clone(Cloner cloner) {
18      throw new NotImplementedException();
19    }
20
21    protected override void Run(CancellationToken cancellationToken) {
22      var problemData = Problem.ProblemData;
23
24      var x = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
25        problemData.TrainingIndices);
26      var nVars = x.GetLength(1);
27      var rand = new MersenneTwister(31415);
28      CFRAlgorithm(nVars, depth: 4, 0.10, x, out var best, out var bestObj, rand, numGen: 200, cancellationToken);
29    }
30
31    private void CFRAlgorithm(int nVars, int depth, double mutationRate, double[,] trainingData,
32      out ContinuedFraction best, out double bestObj,
33      IRandom rand, int numGen,
34      CancellationToken cancellation) {
35      /* Algorithm 1 */
36      /* Generate initial population by a randomized algorithm */
37      var pop = InitialPopulation(nVars, depth, rand, trainingData);
38      best = pop.pocket;
39      bestObj = pop.pocketObjValue;
40
41      for (int gen = 1; gen <= numGen && !cancellation.IsCancellationRequested; gen++) {
42        /* mutate each current solution in the population */
43        var pop_mu = Mutate(pop, mutationRate, rand);
44        /* generate new population by recombination mechanism */
45        var pop_r = RecombinePopulation(pop_mu, rand, nVars);
46
47        /* local search optimization of current solutions */
48        foreach (var agent in pop_r.IterateLevels()) {
49          LocalSearchSimplex(agent.current, trainingData, rand);
50        }
51
52        foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // Deviates from Alg1 in paper
53
54        /* TODO
55        if (stagnating(curQuality, bestQuality, numStagnatingGensForReset)) {
56          Reset(pop_r.root);
57        }
58        */
59
60        /* replace old population with evolved population */
61        pop = pop_r;
62
63        /* keep track of the best solution */
64        if (bestObj > pop.pocketObjValue) {
65          best = pop.pocket;
66          bestObj = pop.pocketObjValue;
67          Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
68        }
69      }
70    }
71
72    private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
73      /* instantiate 13 agents in the population */
74      var pop = new Agent();
75      // see Figure 2
76      for (int i = 0; i < 3; i++) {
77        pop.children.Add(new Agent());
78        for (int j = 0; j < 3; j++) {
79          pop.children[i].children.Add(new Agent());
80        }
81      }
82
83      foreach (var agent in pop.IteratePostOrder()) {
84        agent.current = new ContinuedFraction(nVars, depth, rand);
85        agent.pocket = new ContinuedFraction(nVars, depth, rand);
86
87        agent.currentObjValue = Evaluate(agent.current, trainingData);
88        agent.pocketObjValue = Evaluate(agent.pocket, trainingData);
89
90        /* within each agent, the pocket solution always holds the better value of guiding
91         * function than its current solution
92         */
93        agent.MaintainInvariant();
94      }
95      return pop;
96    }
97
98    private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars) {
99      var l = pop;
100      if (pop.children.Count > 0) {
101        var s1 = pop.children[0];
102        var s2 = pop.children[1];
103        var s3 = pop.children[2];
104        l.current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars);
105        s3.current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars);
106        s1.current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars);
107        s2.current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars);
108      }
109
110      foreach (var child in pop.children) {
111        RecombinePopulation(child, rand, nVars);
112      }
113      return pop;
114    }
115
116    private Func<bool[], bool[], bool[]> SelectRandomOp(IRandom rand) {
117      bool[] union(bool[] a, bool[] b) {
118        var res = new bool[a.Length];
119        for (int i = 0; i < a.Length; i++) res[i] = a[i] || b[i];
120        return res;
121      }
122      bool[] intersect(bool[] a, bool[] b) {
123        var res = new bool[a.Length];
124        for (int i = 0; i < a.Length; i++) res[i] = a[i] && b[i];
125        return res;
126      }
127      bool[] symmetricDifference(bool[] a, bool[] b) {
128        var res = new bool[a.Length];
129        for (int i = 0; i < a.Length; i++) res[i] = a[i] ^ b[i];
130        return res;
131      }
132      switch (rand.Next(3)) {
133        case 0: return union;
134        case 1: return intersect;
135        case 2: return symmetricDifference;
136        default: throw new ArgumentException();
137      }
138    }
139
140    private static ContinuedFraction Recombine(ContinuedFraction p1, ContinuedFraction p2, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars) {
141      ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
142      /* apply a recombination operator chosen uniformly at random on variable sof two parents into offspring */
143      ch.vars = op(p1.vars, p2.vars);
144
145      /* recombine the coefficients for each term (h) of the continued fraction */
146      for (int i = 0; i < p1.h.Length; i++) {
147        var coefa = p1.h[i].coef; var varsa = p1.h[i].vars;
148        var coefb = p2.h[i].coef; var varsb = p2.h[i].vars;
149
150        /* recombine coefficient values for variables */
151        var coefx = new double[nVars];
152        var varsx = new bool[nVars]; // TODO: deviates from paper -> check
153        for (int vi = 1; vi < nVars; vi++) {
154          if (ch.vars[vi]) {
155            if (varsa[vi] && varsb[vi]) {
156              coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
157              varsx[vi] = true;
158            } else if (varsa[vi]) {
159              coefx[vi] = coefa[vi];
160              varsx[vi] = true;
161            } else if (varsb[vi]) {
162              coefx[vi] = coefb[vi];
163              varsx[vi] = true;
164            }
165          }
166        }
167        /* update new coefficients of the term in offspring */
168        ch.h[i] = new Term() { coef = coefx, vars = varsx };
169        /* compute new value of constant (beta) for term hi in the offspring solution ch using
170         * beta of p1.hi and p2.hi */
171        ch.h[i].beta = p1.h[i].beta + (rand.NextDouble() * 5 - 1) * (p2.h[i].beta - p1.h[i].beta) / 3.0;
172      }
173      /* update current solution and apply local search */
174      // return LocalSearchSimplex(ch, trainingData); // Deviates from paper because Alg1 also has LocalSearch after Recombination
175      return ch;
176    }
177
178    private Agent Mutate(Agent pop, double mutationRate, IRandom rand) {
179      foreach (var agent in pop.IterateLevels()) {
180        if (rand.NextDouble() < mutationRate) {
181          if (agent.currentObjValue < 1.2 * agent.pocketObjValue ||
182             agent.currentObjValue > 2 * agent.pocketObjValue)
183            ToggleVariables(agent.current, rand); // major mutation
184          else ModifyVariable(agent.current, rand); // soft mutation
185        }
186      }
187      return pop;
188    }
189
190    private void ToggleVariables(ContinuedFraction cfrac, IRandom rand) {
191      double coinToss(double a, double b) {
192        return rand.NextDouble() < 0.5 ? a : b;
193      }
194
195      /* select a variable index uniformly at random */
196      int N = cfrac.vars.Length;
197      var vIdx = rand.Next(N);
198
199      /* for each depth of continued fraction, toggle the selection of variables of the term (h) */
200      foreach (var h in cfrac.h) {
201        /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or
202         * 'Remember' the coefficient value at random */
203        if (cfrac.vars[vIdx]) {
204          h.vars[vIdx] = false;
205          h.coef[vIdx] = coinToss(0, h.coef[vIdx]);
206        } else {
207          /* Case 2: term variable is turned OFF: Turn ON the variable, and either 'Remove'
208           * or 'Replace' the coefficient with a random value between -3 and 3 at random */
209          if (!h.vars[vIdx]) {
210            h.vars[vIdx] = true;
211            h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3);
212          }
213        }
214      }
215      /* toggle the randomly selected variable */
216      cfrac.vars[vIdx] = !cfrac.vars[vIdx];
217    }
218
219    private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
220      /* randomly select a variable which is turned ON */
221      var candVars = cfrac.vars.Count(vi => vi);
222      if (candVars == 0) return; // no variable active
223      var vIdx = rand.Next(candVars);
224
225      /* randomly select a term (h) of continued fraction */
226      var h = cfrac.h[rand.Next(cfrac.h.Length)];
227
228      /* modify the coefficient value*/
229      if (h.vars[vIdx]) {
230        h.coef[vIdx] = 0.0;
231      } else {
232        h.coef[vIdx] = rand.NextDouble() * 6 - 3;
233      }
234      /* Toggle the randomly selected variable */
235      h.vars[vIdx] = !h.vars[vIdx];
236    }
237
238    private static double Evaluate(ContinuedFraction cfrac, double[,] trainingData) {
239      var dataPoint = new double[trainingData.GetLength(1) - 1];
240      var yIdx = trainingData.GetLength(1) - 1;
241      double sum = 0.0;
242      for (int r = 0; r < trainingData.GetLength(0); r++) {
243        for (int c = 0; c < dataPoint.Length; c++) {
244          dataPoint[c] = trainingData[r, c];
245        }
246        var y = trainingData[r, yIdx];
247        var pred = Evaluate(cfrac, dataPoint);
248        var res = y - pred;
249        sum += res * res;
250      }
251      var delta = 0.1; // TODO
252      return sum / trainingData.GetLength(0) * (1 + delta * cfrac.vars.Count(vi => vi));
253    }
254
255    private static double Evaluate(ContinuedFraction cfrac, double[] dataPoint) {
256      var res = 0.0;
257      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
258        var hi = cfrac.h[i];
259        var hi1 = cfrac.h[i - 1];
260        var denom = hi.beta + dot(dataPoint, hi.coef) + res;
261        var numerator = hi1.beta + dot(dataPoint, hi1.coef);
262        res = numerator / denom;
263      }
264      return res;
265    }
266
267    private static double dot(double[] x, double[] y) {
268      var s = 0.0;
269      for (int i = 0; i < x.Length; i++)
270        s += x[i] * y[i];
271      return s;
272    }
273
274
275    private static ContinuedFraction LocalSearchSimplex(ContinuedFraction ch, double[,] trainingData, IRandom rand) {
276      double uniformPeturbation = 1.0;
277      double tolerance = 1e-3;
278      int maxEvals = 250;
279      int numSearches = 4;
280      var numRows = trainingData.GetLength(0);
281      int numSelectedRows = numRows / 5; // 20% of the training samples
282
283      double[] origCoeff = ExtractCoeff(ch);
284      if (origCoeff.Length == 0) return ch; // no parameters to optimize
285
286      var bestQuality = Evaluate(ch, trainingData); // get quality with origial coefficients
287      var bestCoeff = origCoeff;
288
289      var fittingData = SelectRandomRows(trainingData, numSelectedRows, rand);
290
291      double objFunc(double[] curCoeff) {
292        SetCoeff(ch, curCoeff);
293        return Evaluate(ch, fittingData);
294      }
295
296      for (int count = 0; count < numSearches; count++) {
297
298        SimplexConstant[] constants = new SimplexConstant[origCoeff.Length];
299        for (int i = 0; i < origCoeff.Length; i++) {
300          constants[i] = new SimplexConstant(origCoeff[i], uniformPeturbation);
301        }
302
303        RegressionResult result = NelderMeadSimplex.Regress(constants, tolerance, maxEvals, objFunc);
304
305        var optimizedCoeff = result.Constants;
306        SetCoeff(ch, optimizedCoeff);
307
308        var newQuality = Evaluate(ch, trainingData);
309
310        // TODO: optionally use regularization (ridge / LASSO)
311
312        if (newQuality < bestQuality) {
313          bestCoeff = optimizedCoeff;
314          bestQuality = newQuality;
315        }
316      } // reps
317
318      SetCoeff(ch, bestCoeff);
319      return ch;
320    }
321
322    private static double[,] SelectRandomRows(double[,] trainingData, int numSelectedRows, IRandom rand) {
323      var numRows = trainingData.GetLength(0);
324      var numCols = trainingData.GetLength(1);
325      var selectedRows = Enumerable.Range(0, numRows).Shuffle(rand).Take(numSelectedRows).ToArray();
326      var subset = new double[numSelectedRows, numCols];
327      var i = 0;
328      foreach (var r in selectedRows) {
329        for (int c = 0; c < numCols; c++) {
330          subset[i, c] = trainingData[r, c];
331        }
332        i++;
333      }
334      return subset;
335    }
336
337    private static double[] ExtractCoeff(ContinuedFraction ch) {
338      var coeff = new List<double>();
339      foreach (var hi in ch.h) {
340        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
341          if (hi.vars[vIdx]) coeff.Add(hi.coef[vIdx]);
342        }
343      }
344      return coeff.ToArray();
345    }
346
347    private static void SetCoeff(ContinuedFraction ch, double[] curCoeff) {
348      int k = 0;
349      foreach (var hi in ch.h) {
350        for (int vIdx = 0; vIdx < hi.vars.Length; vIdx++) {
351          if (hi.vars[vIdx]) hi.coef[vIdx] = curCoeff[k++];
352        }
353      }
354    }
355  }
356
357  public class Agent {
358    public ContinuedFraction pocket;
359    public double pocketObjValue;
360    public ContinuedFraction current;
361    public double currentObjValue;
362
363    public IList<Agent> children = new List<Agent>();
364
365    public IEnumerable<Agent> IterateLevels() {
366      var agents = new List<Agent>() { this };
367      IterateLevelsRec(this, agents);
368      return agents;
369    }
370    public IEnumerable<Agent> IteratePostOrder() {
371      var agents = new List<Agent>();
372      IteratePostOrderRec(this, agents);
373      return agents;
374    }
375
376    internal void MaintainInvariant() {
377      foreach (var child in children) {
378        MaintainInvariant(parent: this, child);
379      }
380      if (currentObjValue < pocketObjValue) {
381        Swap(ref pocket, ref current);
382        Swap(ref pocketObjValue, ref currentObjValue);
383      }
384    }
385
386
387    private static void MaintainInvariant(Agent parent, Agent child) {
388      if (child.pocketObjValue < parent.pocketObjValue) {
389        Swap(ref child.pocket, ref parent.pocket);
390        Swap(ref child.pocketObjValue, ref parent.pocketObjValue);
391      }
392    }
393
394    private void IterateLevelsRec(Agent agent, List<Agent> agents) {
395      foreach (var child in agent.children) {
396        agents.Add(child);
397      }
398      foreach (var child in agent.children) {
399        IterateLevelsRec(child, agents);
400      }
401    }
402
403    private void IteratePostOrderRec(Agent agent, List<Agent> agents) {
404      foreach (var child in agent.children) {
405        IteratePostOrderRec(child, agents);
406      }
407      agents.Add(agent);
408    }
409
410
411    private static void Swap<T>(ref T a, ref T b) {
412      var temp = a;
413      a = b;
414      b = temp;
415    }
416  }
417}
Note: See TracBrowser for help on using the repository browser.