source listing (Trac repository browser export)

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 13650

Last change on this file since 13650 was 13650, checked in by gkronber, 8 years ago

#2581: fixed License header

File size: 18.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics.Contracts;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
32using HeuristicLab.Random;
33
34namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
35  public static class MctsSymbolicRegressionStatic {
36    // TODO: SGD with adagrad instead of lbfgs?
37    // TODO: check Taylor expansion capabilities (ln(x), sqrt(x), exp(x)) in combination with GBT
38    // TODO: optimize for 3 targets concurrently (y, 1/y, exp(y), and log(y))? Would simplify the number of possible expressions again
39    #region static API
40
    /// <summary>
    /// Read-only view of an MCTS symbolic regression run, created via <c>CreateState</c>
    /// and advanced via <c>MakeStep</c>.
    /// </summary>
    public interface IState {
      // true when the search tree has been fully enumerated (no more steps possible)
      bool Done { get; }
      // symbolic regression model equivalent to the best code found so far
      ISymbolicRegressionModel BestModel { get; }
      // R² of the best solution on the training partition
      double BestSolutionTrainingQuality { get; }
      // R² of the best solution on the test partition
      double BestSolutionTestQuality { get; }
    }
47
    // created through factory method
    // Holds all mutable state of one MCTS run: the search tree, the automaton that
    // enumerates expression codes, the (optionally scaled) training/test data, the
    // incumbent best solution, and reusable evaluation buffers.
    private class State : IState {
      // upper bound for the number of optimizable constants per expression (buffer sizes)
      private const int MaxParams = 100;

      // state variables used by MCTS
      internal readonly Automaton automaton;
      internal IRandom random { get; private set; }
      internal readonly double c;                          // exploration weight used by UCT selection
      internal readonly Tree tree;                         // root of the search tree
      internal readonly List<Tree> bestChildrenBuf;        // reusable buffer for child selection
      internal readonly Func<byte[], int, double> evalFun; // bound to this.Eval


      // state variables used only internally (for eval function)
      private readonly IRegressionProblemData problemData;
      private readonly double[][] x;            // training inputs, one array per variable (scaled if requested)
      private readonly double[] y;              // training target
      private readonly double[][] testX;        // test inputs (scaled with the training scaling parameters)
      private readonly double[] testY;          // test target
      private readonly double[] scalingFactor;  // null when variables are not scaled
      private readonly double[] scalingOffset;  // null when variables are not scaled
      private readonly int constOptIterations;  // LM iterations for constant optimization; <= 0 disables it
      private readonly double lowerEstimationLimit, upperEstimationLimit;

      // separate evaluators because the vector lengths of training and test sets may differ
      private readonly ExpressionEvaluator evaluator, testEvaluator;

      // values for best solution (incumbent); updated in Eval
      private double bestRSq;
      private byte[] bestCode;
      private int bestNParams;
      private double[] bestConsts;

      // buffers
      private readonly double[] ones; // vector of ones (as default params)
      private readonly double[] constsBuf;
      private readonly double[] predBuf, testPredBuf; // prediction buffers (training / test length)
      private readonly double[][] gradBuf;            // gradient buffers, one training-length vector per parameter

      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, double c, bool scaleVariables, int constOptIterations,
        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
        bool allowProdOfVars = true,
        bool allowExp = true,
        bool allowLog = true,
        bool allowInv = true,
        bool allowMultipleTerms = false) {

        this.problemData = problemData;
        this.c = c;
        this.constOptIterations = constOptIterations;
        this.evalFun = this.Eval;
        this.lowerEstimationLimit = lowerEstimationLimit;
        this.upperEstimationLimit = upperEstimationLimit;

        random = new MersenneTwister(randSeed);

        // prepare data for evaluation
        double[][] x;
        double[] y;
        double[][] testX;
        double[] testY;
        double[] scalingFactor;
        double[] scalingOffset;
        // get training and test datasets (scale linearly based on training set if required)
        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
        this.x = x;
        this.y = y;
        this.testX = testX;
        this.testY = testY;
        this.scalingFactor = scalingFactor;
        this.scalingOffset = scalingOffset;
        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
        // we need a separate evaluator because the vector length for the test dataset might differ
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
        this.tree = new Tree() { state = automaton.CurrentState };

        // reset best solution
        this.bestRSq = 0;
        // code for default solution (constant model)
        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
        this.bestNParams = 0;
        this.bestConsts = null;

        // init buffers
        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
        constsBuf = new double[MaxParams];
        this.bestChildrenBuf = new List<Tree>(2 * x.Length); // the number of follow states in the automaton is O(number of variables) 2 * number of variables should be sufficient (capacity is increased if necessary anyway)
        this.predBuf = new double[y.Length];
        this.testPredBuf = new double[testY.Length];

        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
      }

      #region IState inferface
      public bool Done { get { return tree != null && tree.done; } }

      // R² of the incumbent code on the training set (re-evaluated on each access)
      public double BestSolutionTrainingQuality {
        get {
          evaluator.Exec(bestCode, x, bestConsts, predBuf);
          return RSq(y, predBuf);
        }
      }

      // R² of the incumbent code on the test set (re-evaluated on each access)
      public double BestSolutionTestQuality {
        get {
          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
          return RSq(testY, testPredBuf);
        }
      }

      // takes the code of the best solution and creates an equivalent symbolic regression model
      public ISymbolicRegressionModel BestModel {
        get {
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
          var simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();

          // scalingFactor/scalingOffset undo the input scaling inside the generated tree
          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
          var simpleT = simplifier.Simplify(t);
          var model = new SymbolicRegressionModel(simpleT, interpreter, lowerEstimationLimit, upperEstimationLimit);

          // model has already been scaled linearly in Eval
          return model;
        }
      }
      #endregion

      // Evaluates the given code and updates the incumbent if the quality improved.
      // Copies code/constants because the underlying buffers are reused for later evaluations.
      private double Eval(byte[] code, int nParams) {
        double[] optConsts;
        double q;
        Eval(code, nParams, out q, out optConsts);

        if (q > bestRSq) {
          bestRSq = q;
          bestNParams = nParams;
          this.bestCode = new byte[code.Length];
          this.bestConsts = new double[bestNParams];

          Array.Copy(code, bestCode, code.Length);
          Array.Copy(optConsts, bestConsts, bestNParams);
        }

        return q;
      }

      // Evaluates code with default constants, determines an optimal linear scaling
      // (alpha * f(x) + beta) as starting point, and optionally refines all constants with LM.
      // NOTE: optConsts aliases the shared constsBuf; callers must copy values they keep.
      private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
        // we make a first pass to determine a valid starting configuration for all constants
        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
        // scale and offset are set to optimal starting configuration
        // assumes scale is the first param and offset is the last param
        double alpha;
        double beta;

        // reset constants
        Array.Copy(ones, constsBuf, nParams);
        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);

        // calc opt scaling (alpha*f(x) + beta)
        OnlineCalculatorError error;
        OnlineLinearScalingParameterCalculator.Calculate(predBuf, y, out alpha, out beta, out error);
        if (error == OnlineCalculatorError.None) {
          constsBuf[0] *= beta;
          constsBuf[nParams - 1] = constsBuf[nParams - 1] * beta + alpha;
        }
        if (nParams <= 2 || constOptIterations <= 0) {
          // if we don't need to optimize parameters then we are done
          // changing scale and offset does not influence r²
          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        } else {
          // optimize constants using the starting point calculated above
          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
          evaluator.Exec(code, x, constsBuf, predBuf);
          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        }
      }



      #region helpers
      // squared Pearson correlation; 0.0 when the calculator reports an error
      // (e.g. zero variance in one of the sequences)
      private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
        return error == OnlineCalculatorError.None ? r * r : 0.0;
      }


      // Optimizes the first nParams constants in-place using Levenberg-Marquardt (alglib)
      // with analytic gradients supplied by FuncAndJacobian.
      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt
        Array.Copy(consts, optConsts, nParams);

        alglib.minlmstate state;
        alglib.minlmreport rep = null;
        alglib.minlmcreatevj(y.Length, optConsts, out state);
        alglib.minlmsetcond(state, 0.0, epsF, 0.0, nIters);
        //alglib.minlmsetgradientcheck(state, 0.000001);
        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
        alglib.minlmresults(state, out optConsts, out rep);

        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

        // only use optimized constants if successful
        // NOTE(review): this condition is always true here because terminationtype < 0
        // already threw above -- one of the two checks is redundant; confirm intent.
        if (rep.terminationtype >= 0) {
          Array.Copy(optConsts, consts, optConsts.Length);
        }
      }

      // alglib callback: residuals only (fi[r] = prediction - target)
      private void Func(double[] arg, double[] fi, object obj) {
        // 0.5 * MSE and gradient
        var code = (byte[])obj;
        evaluator.Exec(code, x, arg, predBuf); // gradients are nParams x vLen
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;
        }
      }
      // alglib callback: residuals and Jacobian (jac[r, k] = d f(x_r) / d c_k)
      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
        int nParams = arg.Length;
        var code = (byte[])obj;
        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;

          for (int k = 0; k < nParams; k++) {
            jac[r, k] = gradBuf[k][r];
          }
        }
      }
      #endregion
    }
282
    /// <summary>
    /// Factory method: creates a fresh search state for MCTS symbolic regression.
    /// </summary>
    /// <param name="problemData">Regression problem providing dataset, training/test partitions and input variables.</param>
    /// <param name="randSeed">Seed for the internal MersenneTwister PRNG (deterministic runs).</param>
    /// <param name="maxVariables">Passed to the expression automaton; presumably limits variables per expression — confirm in Automaton.</param>
    /// <param name="c">Exploration weight used in the UCT selection formula.</param>
    /// <param name="scaleVariables">If true, input variables are scaled linearly based on the training partition.</param>
    /// <param name="constOptIterations">Number of Levenberg-Marquardt iterations for constant optimization; &lt;= 0 disables it.</param>
    /// <param name="lowerEstimationLimit">Lower clipping limit for model output.</param>
    /// <param name="upperEstimationLimit">Upper clipping limit for model output.</param>
    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, double c = 1.0,
      bool scaleVariables = true, int constOptIterations = 0, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, c, scaleVariables, constOptIterations,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }
295
296    // returns the quality of the evaluated solution
297    public static double MakeStep(IState state) {
298      var mctsState = state as State;
299      if (mctsState == null) throw new ArgumentException("state");
300      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");
301
302      return TreeSearch(mctsState);
303    }
304    #endregion
305
306    private static double TreeSearch(State mctsState) {
307      var automaton = mctsState.automaton;
308      var tree = mctsState.tree;
309      var eval = mctsState.evalFun;
310      var bestChildrenBuf = mctsState.bestChildrenBuf;
311      var rand = mctsState.random;
312      double c = mctsState.c;
313
314      automaton.Reset();
315      return TreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf);
316    }
317
    // One recursive MCTS descent: walks from the given node to a final automaton state
    // (expanding the tree on the way), evaluates the produced expression code at the
    // leaf, and backpropagates quality and visit counts on the way back up.
    // The automaton is stepped in lock-step with the tree descent (see asserts).
    private static double TreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf) {
      Tree selectedChild = null;
      double q;
      Contract.Assert(tree.state == automaton.CurrentState);
      Contract.Assert(!tree.done);
      if (tree.children == null) {
        if (automaton.IsFinalState(tree.state)) {
          // final state
          tree.done = true;

          // EVALUATE: the automaton has produced a complete expression code
          byte[] code; int nParams;
          automaton.GetCode(out code, out nParams);
          q = eval(code, nParams);
          tree.visits++;
          tree.sumQuality += q;
          return q;
        } else {
          // EXPAND: create one child node per possible follow state of the automaton
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);

          tree.children = new Tree[nFs];
          for (int i = 0; i < tree.children.Length; i++) tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };

          selectedChild = SelectFinalOrRandom(automaton, tree, rand);
        }
      } else {
        // tree.children != null
        // UCT selection within tree
        selectedChild = SelectUct(tree, rand, c, bestChildrenBuf);
      }
      // make selected step and recurse
      automaton.Goto(selectedChild.state);
      q = TreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf);

      // backpropagate quality of the evaluated leaf into this node
      tree.sumQuality += q;
      tree.visits++;

      // tree.done = tree.children.All(ch => ch.done);
      tree.done = true; for (int i = 0; i < tree.children.Length && tree.done; i++) tree.done = tree.children[i].done;
      if (tree.done) {
        tree.children = null; // cut off the sub-branch if it has been fully explored
        // TODO: update all qualities and visits to remove the information gained from this whole branch
      }
      return q;
    }
366
367    private static Tree SelectUct(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
368      // determine total tries of still active children
369      int totalTries = 0;
370      bestChildrenBuf.Clear();
371      for (int i = 0; i < tree.children.Length; i++) {
372        var ch = tree.children[i];
373        if (ch.done) continue;
374        if (ch.visits == 0) bestChildrenBuf.Add(ch);
375        else totalTries += tree.children[i].visits;
376      }
377      // if there are unvisited children select a random child
378      if (bestChildrenBuf.Any()) {
379        return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
380      }
381      Contract.Assert(totalTries > 0); // the tree is not done yet so there is at least on child that is not done
382      double logTotalTries = Math.Log(totalTries);
383      var bestQ = double.NegativeInfinity;
384      for (int i = 0; i < tree.children.Length; i++) {
385        var ch = tree.children[i];
386        if (ch.done) continue;
387        var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits);
388        if (childQ > bestQ) {
389          bestChildrenBuf.Clear();
390          bestChildrenBuf.Add(ch);
391          bestQ = childQ;
392        } else if (childQ >= bestQ) {
393          bestChildrenBuf.Add(ch);
394        }
395      }
396      return bestChildrenBuf.Count > 0 ? bestChildrenBuf[rand.Next(bestChildrenBuf.Count)] : bestChildrenBuf[0];
397    }
398
399    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
400      // if one of the new children leads to a final state then go there
401      // otherwise choose a random child
402      int selectedChildIdx = -1;
403      // find first final state if there is one
404      for (int i = 0; i < tree.children.Length; i++) {
405        if (automaton.IsFinalState(tree.children[i].state)) {
406          selectedChildIdx = i;
407          break;
408        }
409      }
410      // no final state -> select a random child
411      if (selectedChildIdx == -1) {
412        selectedChildIdx = rand.Next(tree.children.Length);
413      }
414      return tree.children[selectedChildIdx];
415    }
416
417    // scales data and extracts values from dataset into arrays
418    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
419      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
420      xs = new double[problemData.AllowedInputVariables.Count()][];
421
422      var i = 0;
423      if (scaleVariables) {
424        scalingFactor = new double[xs.Length];
425        scalingOffset = new double[xs.Length];
426      } else {
427        scalingFactor = null;
428        scalingOffset = null;
429      }
430      foreach (var var in problemData.AllowedInputVariables) {
431        if (scaleVariables) {
432          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
433          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
434          var range = maxX - minX;
435
436          // scaledX = (x - min) / range
437          var sf = 1.0 / range;
438          var offset = -minX / range;
439          scalingFactor[i] = sf;
440          scalingOffset[i] = offset;
441          i++;
442        }
443      }
444
445      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
446    }
447
448    // extract values from dataset into arrays
449    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
450     out double[][] xs, out double[] y) {
451      xs = new double[problemData.AllowedInputVariables.Count()][];
452
453      int i = 0;
454      foreach (var var in problemData.AllowedInputVariables) {
455        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
456        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
457        xs[i++] =
458          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
459      }
460
461      y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
462    }
463  }
464}
Note: See TracBrowser for help on using the repository browser.