source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 13645

Last change on this file since 13645 was 13645, checked in by gkronber, 3 years ago

#2581: added an MCTS for symbolic regression models

File size: 18.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Collections.Generic;
25using System.Diagnostics.Contracts;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
  public static class MctsSymbolicRegressionStatic {
    // TODO: SGD with adagrad instead of lbfgs?
    // TODO: check Taylor expansion capabilities (ln(x), sqrt(x), exp(x)) in combination with GBT
    // TODO: optimize for 3 targets concurrently (y, 1/y, exp(y), and log(y))? Would simplify the number of possible expressions again
    #region static API

    /// <summary>
    /// Handle for a running Monte-Carlo tree search over symbolic regression expressions.
    /// Instances are created with <see cref="CreateState"/> and advanced with <see cref="MakeStep"/>.
    /// </summary>
    public interface IState {
      /// <summary>True once the whole search tree has been explored; <see cref="MakeStep"/> throws afterwards.</summary>
      bool Done { get; }
      /// <summary>A symbolic regression model equivalent to the best expression found so far.</summary>
      ISymbolicRegressionModel BestModel { get; }
      /// <summary>Squared Pearson correlation (R²) of the best expression on the training partition.</summary>
      double BestSolutionTrainingQuality { get; }
      /// <summary>Squared Pearson correlation (R²) of the best expression on the test partition.</summary>
      double BestSolutionTestQuality { get; }
    }

    // created through factory method (CreateState)
    private class State : IState {
      // capacity of the constant and gradient buffers (constsBuf, ones, gradBuf)
      private const int MaxParams = 100;

      // state variables used by MCTS
      internal readonly Automaton automaton;
      internal IRandom random { get; private set; }
      internal readonly double c;
      internal readonly Tree tree;
      internal readonly List<Tree> bestChildrenBuf;
      internal readonly Func<byte[], int, double> evalFun;


      // state variables used only internally (for eval function)
      private readonly IRegressionProblemData problemData;
      private readonly double[][] x;
      private readonly double[] y;
      private readonly double[][] testX;
      private readonly double[] testY;
      private readonly double[] scalingFactor;
      private readonly double[] scalingOffset;
      private readonly int constOptIterations;
      private readonly double lowerEstimationLimit, upperEstimationLimit;

      private readonly ExpressionEvaluator evaluator, testEvaluator;

      // values for best solution
      private double bestRSq;
      private byte[] bestCode;
      private int bestNParams;
      private double[] bestConsts;

      // buffers
      private readonly double[] ones; // vector of ones (as default params)
      private readonly double[] constsBuf;
      private readonly double[] predBuf, testPredBuf;
      private readonly double[][] gradBuf;

      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, double c, bool scaleVariables, int constOptIterations,
        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
        bool allowProdOfVars = true,
        bool allowExp = true,
        bool allowLog = true,
        bool allowInv = true,
        bool allowMultipleTerms = false) {

        this.problemData = problemData;
        this.c = c;
        this.constOptIterations = constOptIterations;
        this.evalFun = this.Eval;
        this.lowerEstimationLimit = lowerEstimationLimit;
        this.upperEstimationLimit = upperEstimationLimit;

        random = new MersenneTwister(randSeed);

        // prepare data for evaluation
        double[][] x;
        double[] y;
        double[][] testX;
        double[] testY;
        double[] scalingFactor;
        double[] scalingOffset;
        // get training and test datasets (scale linearly based on training set if required)
        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
        this.x = x;
        this.y = y;
        this.testX = testX;
        this.testY = testY;
        this.scalingFactor = scalingFactor;
        this.scalingOffset = scalingOffset;
        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
        // we need a separate evaluator because the vector length for the test dataset might differ
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
        this.tree = new Tree() { state = automaton.CurrentState };

        // reset best solution
        this.bestRSq = 0;
        // code for default solution (constant model)
        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
        this.bestNParams = 0;
        this.bestConsts = null;

        // init buffers
        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
        constsBuf = new double[MaxParams];
        this.bestChildrenBuf = new List<Tree>(2 * x.Length); // the number of follow states in the automaton is O(number of variables) 2 * number of variables should be sufficient (capacity is increased if necessary anyway)
        this.predBuf = new double[y.Length];
        this.testPredBuf = new double[testY.Length];

        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
      }

      #region IState interface
      public bool Done { get { return tree != null && tree.done; } }

      // re-evaluates the best code on the training rows (overwrites predBuf)
      public double BestSolutionTrainingQuality {
        get {
          evaluator.Exec(bestCode, x, bestConsts, predBuf);
          return RSq(y, predBuf);
        }
      }

      // re-evaluates the best code on the test rows (overwrites testPredBuf)
      public double BestSolutionTestQuality {
        get {
          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
          return RSq(testY, testPredBuf);
        }
      }

      // takes the code of the best solution and creates an equivalent symbolic regression model
      public ISymbolicRegressionModel BestModel {
        get {
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
          var simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();

          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
          var simpleT = simplifier.Simplify(t);
          var model = new SymbolicRegressionModel(simpleT, interpreter, lowerEstimationLimit, upperEstimationLimit);

          // model has already been scaled linearly in Eval
          return model;
        }
      }
      #endregion

      // Evaluation callback used by the tree search: returns the training R² of the code
      // and records code + constants as the new best solution if R² improved.
      private double Eval(byte[] code, int nParams) {
        double[] optConsts;
        double q;
        Eval(code, nParams, out q, out optConsts);

        if (q > bestRSq) {
          bestRSq = q;
          bestNParams = nParams;
          this.bestCode = new byte[code.Length];
          this.bestConsts = new double[bestNParams];

          Array.Copy(code, bestCode, code.Length);
          Array.Copy(optConsts, bestConsts, bestNParams);
        }

        return q;
      }

      private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
        // we make a first pass to determine a valid starting configuration for all constants
        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
        // scale and offset are set to optimal starting configuration
        // assumes scale is the first param and offset is the last param
        double alpha;
        double beta;

        // reset constants
        Array.Copy(ones, constsBuf, nParams);
        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);

        // calc opt scaling (beta * f(x) + alpha): beta is folded into the scale (first param),
        // alpha into the offset (last param)
        OnlineCalculatorError error;
        OnlineLinearScalingParameterCalculator.Calculate(predBuf, y, out alpha, out beta, out error);
        // the scale/offset layout needs two distinct parameters; with fewer params
        // constsBuf[nParams - 1] would be out of range (0 params) or alias the scale (1 param)
        if (error == OnlineCalculatorError.None && nParams >= 2) {
          constsBuf[0] *= beta;
          constsBuf[nParams - 1] = constsBuf[nParams - 1] * beta + alpha;
        }
        if (nParams <= 2 || constOptIterations <= 0) {
          // if we don't need to optimize parameters then we are done
          // changing scale and offset does not influence r²
          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        } else {
          // optimize constants using the starting point calculated above
          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
          evaluator.Exec(code, x, constsBuf, predBuf);
          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        }
      }



      #region helpers
      // squared Pearson correlation; 0 if the calculator reports an error
      private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
        return error == OnlineCalculatorError.None ? r * r : 0.0;
      }


      // Optimizes the first nParams entries of consts in place with alglib's Levenberg-Marquardt (minlm),
      // using the residuals from Func / FuncAndJacobian. If alglib reports failure (negative termination
      // type) consts keeps the starting point, so a single ill-conditioned expression does not abort the search.
      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt
        Array.Copy(consts, optConsts, nParams);

        alglib.minlmstate state;
        alglib.minlmreport rep;
        alglib.minlmcreatevj(y.Length, optConsts, out state);
        alglib.minlmsetcond(state, 0.0, epsF, 0.0, nIters);
        //alglib.minlmsetgradientcheck(state, 0.000001);
        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
        alglib.minlmresults(state, out optConsts, out rep);

        // only use optimized constants if successful
        if (rep.terminationtype >= 0) {
          Array.Copy(optConsts, consts, optConsts.Length);
        }
      }

      // residual callback for minlm: fi[r] = prediction - target for every training row
      private void Func(double[] arg, double[] fi, object obj) {
        var code = (byte[])obj;
        evaluator.Exec(code, x, arg, predBuf);
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;
        }
      }

      // residual + Jacobian callback for minlm: jac[r, k] is taken from gradBuf[k][r]
      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
        int nParams = arg.Length;
        var code = (byte[])obj;
        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;

          for (int k = 0; k < nParams; k++) {
            jac[r, k] = gradBuf[k][r];
          }
        }
      }
      #endregion
    }

    /// <summary>Creates a new search state for the given problem.</summary>
    /// <param name="problemData">Dataset with training/test partitions, input variables and target.</param>
    /// <param name="randSeed">Seed of the MersenneTwister used for random tree decisions.</param>
    /// <param name="maxVariables">Passed to the automaton (presumably limits the number of variable references — confirm in Automaton).</param>
    /// <param name="c">Exploration weight of the UCT score in child selection.</param>
    /// <param name="scaleVariables">If true, inputs are scaled linearly to [0, 1] based on training min/max.</param>
    /// <param name="constOptIterations">Levenberg-Marquardt iterations for constant optimization; values &lt;= 0 disable it.</param>
    /// <param name="lowerEstimationLimit">Lower prediction limit passed to the evaluators and the resulting model.</param>
    /// <param name="upperEstimationLimit">Upper prediction limit passed to the evaluators and the resulting model.</param>
    /// <param name="allowProdOfVars">Passed to the automaton (grammar option).</param>
    /// <param name="allowExp">Passed to the automaton (grammar option).</param>
    /// <param name="allowLog">Passed to the automaton (grammar option).</param>
    /// <param name="allowInv">Passed to the automaton (grammar option).</param>
    /// <param name="allowMultipleTerms">Passed to the automaton (grammar option).</param>
    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, double c = 1.0,
      bool scaleVariables = true, int constOptIterations = 0, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, c, scaleVariables, constOptIterations,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }

    /// <summary>Performs one MCTS iteration (selection, expansion, evaluation, back-propagation).</summary>
    /// <param name="state">A state created by <see cref="CreateState"/>.</param>
    /// <returns>The quality of the solution evaluated in this iteration.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="state"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="state"/> was not created by <see cref="CreateState"/>.</exception>
    /// <exception cref="NotSupportedException">The search tree has been fully enumerated.</exception>
    public static double MakeStep(IState state) {
      if (state == null) throw new ArgumentNullException("state");
      var mctsState = state as State;
      if (mctsState == null) throw new ArgumentException("The state must be created by CreateState.", "state");
      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

      return TreeSearch(mctsState);
    }
    #endregion

    // resets the automaton and runs one descent from the root of the search tree
    private static double TreeSearch(State mctsState) {
      var automaton = mctsState.automaton;
      var tree = mctsState.tree;
      var eval = mctsState.evalFun;
      var bestChildrenBuf = mctsState.bestChildrenBuf;
      var rand = mctsState.random;
      double c = mctsState.c;

      automaton.Reset();
      return TreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf);
    }

    // Recursive descent: evaluates final states, expands leaves, selects children via UCT otherwise,
    // and back-propagates visits/quality. Fully explored sub-trees are marked done and cut off.
    private static double TreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf) {
      Tree selectedChild = null;
      double q;
      Contract.Assert(tree.state == automaton.CurrentState);
      Contract.Assert(!tree.done);
      if (tree.children == null) {
        if (automaton.IsFinalState(tree.state)) {
          // final state
          tree.done = true;

          // EVALUATE
          byte[] code; int nParams;
          automaton.GetCode(out code, out nParams);
          q = eval(code, nParams);
          tree.visits++;
          tree.sumQuality += q;
          return q;
        } else {
          // EXPAND
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);

          if (nFs == 0) {
            // dead end: a non-final state without follow states cannot produce an expression.
            // Mark it as explored (quality 0) instead of selecting a random child of an empty array.
            tree.done = true;
            tree.visits++;
            return 0.0;
          }

          tree.children = new Tree[nFs];
          for (int i = 0; i < tree.children.Length; i++) tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };

          selectedChild = SelectFinalOrRandom(automaton, tree, rand);
        }
      } else {
        // tree.children != null
        // UCT selection within tree
        selectedChild = SelectUct(tree, rand, c, bestChildrenBuf);
      }
      // make selected step and recurse
      automaton.Goto(selectedChild.state);
      q = TreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf);

      tree.sumQuality += q;
      tree.visits++;

      // tree.done = tree.children.All(ch => ch.done);
      tree.done = true; for (int i = 0; i < tree.children.Length && tree.done; i++) tree.done = tree.children[i].done;
      if (tree.done) {
        tree.children = null; // cut off the sub-branch if it has been fully explored
        // TODO: update all qualities and visits to remove the information gained from this whole branch
      }
      return q;
    }

    // Picks a random unvisited active child if there is one; otherwise the active child with the
    // highest UCT score (average quality + c * sqrt(ln(totalTries) / visits)), ties broken randomly.
    private static Tree SelectUct(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
      // determine total tries of still active children
      int totalTries = 0;
      bestChildrenBuf.Clear();
      for (int i = 0; i < tree.children.Length; i++) {
        var ch = tree.children[i];
        if (ch.done) continue;
        if (ch.visits == 0) bestChildrenBuf.Add(ch);
        else totalTries += ch.visits;
      }
      // if there are unvisited children select a random child
      if (bestChildrenBuf.Count > 0) {
        return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
      }
      Contract.Assert(totalTries > 0); // the tree is not done yet so there is at least one child that is not done
      double logTotalTries = Math.Log(totalTries);
      var bestQ = double.NegativeInfinity;
      for (int i = 0; i < tree.children.Length; i++) {
        var ch = tree.children[i];
        if (ch.done) continue;
        var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits);
        if (childQ > bestQ) {
          bestChildrenBuf.Clear();
          bestChildrenBuf.Add(ch);
          bestQ = childQ;
        } else if (childQ >= bestQ) {
          // exact tie with the current best: keep both and break the tie randomly below
          bestChildrenBuf.Add(ch);
        }
      }
      if (bestChildrenBuf.Count == 0) {
        // no score compared >= -infinity, i.e. every score was NaN (e.g. c = infinity with ln(1) = 0);
        // fall back to a uniformly random active child instead of indexing an empty list
        for (int i = 0; i < tree.children.Length; i++) {
          if (!tree.children[i].done) bestChildrenBuf.Add(tree.children[i]);
        }
      }
      return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    }

    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
      // if one of the new children leads to a final state then go there
      // otherwise choose a random child
      int selectedChildIdx = -1;
      // find first final state if there is one
      for (int i = 0; i < tree.children.Length; i++) {
        if (automaton.IsFinalState(tree.children[i].state)) {
          selectedChildIdx = i;
          break;
        }
      }
      // no final state -> select a random child
      if (selectedChildIdx == -1) {
        selectedChildIdx = rand.Next(tree.children.Length);
      }
      return tree.children[selectedChildIdx];
    }

    // Scales data and extracts values from dataset into arrays.
    // With scaleVariables each input is mapped linearly to [0, 1] using the min/max over the given rows;
    // constant (or NaN-containing) variables are left unscaled instead of dividing by a zero/NaN range.
    // Without scaling, scalingFactor and scalingOffset are null.
    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
      scalingFactor = null;
      scalingOffset = null;
      if (scaleVariables) {
        var nVars = problemData.AllowedInputVariables.Count();
        scalingFactor = new double[nVars];
        scalingOffset = new double[nVars];

        var i = 0;
        foreach (var variableName in problemData.AllowedInputVariables) {
          // materialize once instead of enumerating the column separately for Min and Max
          var values = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
          var minX = values.Min();
          var maxX = values.Max();
          var range = maxX - minX;

          if (range > 0.0 && !double.IsInfinity(range)) {
            // scaledX = (x - min) / range
            scalingFactor[i] = 1.0 / range;
            scalingOffset[i] = -minX / range;
          } else {
            // degenerate range (constant column, NaN or infinite values): identity scaling
            scalingFactor[i] = 1.0;
            scalingOffset[i] = 0.0;
          }
          i++;
        }
      }

      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
    }

    // extract values from dataset into arrays (x * scalingFactor + scalingOffset; identity if the arrays are null)
    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
     out double[][] xs, out double[] y) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      int i = 0;
      foreach (var var in problemData.AllowedInputVariables) {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingOffset == null ? 0.0 : scalingOffset[i];
        xs[i++] =
          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
      }

      y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.