
source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 13657

Last change on this file: r13657, checked in by gkronber, 8 years ago

#2581: update quality estimate in parent nodes when a branch is completely explored. added ucbtuned selection

File size: 22.8 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
  public static class MctsSymbolicRegressionStatic {
    // TODO: SGD with adagrad instead of lbfgs?
    // TODO: check Taylor expansion capabilities (ln(x), sqrt(x), exp(x)) in combination with GBT
    // TODO: optimize for multiple targets concurrently (y, 1/y, exp(y), and log(y))? Would reduce the number of possible expressions again
    #region static API

    public interface IState {
      bool Done { get; }
      ISymbolicRegressionModel BestModel { get; }
      double BestSolutionTrainingQuality { get; }
      double BestSolutionTestQuality { get; }
      int TotalRollouts { get; }
      int EffectiveRollouts { get; }
      int FuncEvaluations { get; }
      int GradEvaluations { get; } // number of gradient evaluations (multiplied by the number of parameters) so that the value is comparable to the number of function evaluations
      // TODO: other stats on the LM optimizer might be interesting here
    }

    // created through factory method
    private class State : IState {
      private const int MaxParams = 100;

      // state variables used by MCTS
      internal readonly Automaton automaton;
      internal IRandom random { get; private set; }
      internal readonly double c;
      internal readonly Tree tree;
      internal readonly List<Tree> bestChildrenBuf;
      internal readonly Func<byte[], int, double> evalFun;
      // MCTS might get stuck. Track statistics on the number of effective rollouts
      internal int totalRollouts;
      internal int effectiveRollouts;

      // state variables used only internally (for eval function)
      private readonly IRegressionProblemData problemData;
      private readonly double[][] x;
      private readonly double[] y;
      private readonly double[][] testX;
      private readonly double[] testY;
      private readonly double[] scalingFactor;
      private readonly double[] scalingOffset;
      private readonly int constOptIterations;
      private readonly double lowerEstimationLimit, upperEstimationLimit;

      private readonly ExpressionEvaluator evaluator, testEvaluator;

      // values for best solution
      private double bestRSq;
      private byte[] bestCode;
      private int bestNParams;
      private double[] bestConsts;

      // stats
      private int funcEvaluations;
      private int gradEvaluations;

      // buffers
      private readonly double[] ones; // vector of ones (as default params)
      private readonly double[] constsBuf;
      private readonly double[] predBuf, testPredBuf;
      private readonly double[][] gradBuf;

      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, double c, bool scaleVariables, int constOptIterations,
        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
        bool allowProdOfVars = true,
        bool allowExp = true,
        bool allowLog = true,
        bool allowInv = true,
        bool allowMultipleTerms = false) {

        this.problemData = problemData;
        this.c = c;
        this.constOptIterations = constOptIterations;
        this.evalFun = this.Eval;
        this.lowerEstimationLimit = lowerEstimationLimit;
        this.upperEstimationLimit = upperEstimationLimit;

        random = new MersenneTwister(randSeed);

        // prepare data for evaluation
        double[][] x;
        double[] y;
        double[][] testX;
        double[] testY;
        double[] scalingFactor;
        double[] scalingOffset;
        // get training and test datasets (scale linearly based on training set if required)
        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
        this.x = x;
        this.y = y;
        this.testX = testX;
        this.testY = testY;
        this.scalingFactor = scalingFactor;
        this.scalingOffset = scalingOffset;
        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
        // we need a separate evaluator because the vector length for the test dataset might differ
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
        this.tree = new Tree() { state = automaton.CurrentState };

        // reset best solution
        this.bestRSq = 0;
        // code for default solution (constant model)
        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
        this.bestNParams = 0;
        this.bestConsts = null;

        // init buffers
        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
        constsBuf = new double[MaxParams];
        this.bestChildrenBuf = new List<Tree>(2 * x.Length); // the number of follow states in the automaton is O(number of variables); a capacity of 2 * number of variables should be sufficient (the capacity is increased if necessary anyway)
        this.predBuf = new double[y.Length];
        this.testPredBuf = new double[testY.Length];

        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
      }

      #region IState interface
      public bool Done { get { return tree != null && tree.done; } }

      public double BestSolutionTrainingQuality {
        get {
          evaluator.Exec(bestCode, x, bestConsts, predBuf);
          return RSq(y, predBuf);
        }
      }

      public double BestSolutionTestQuality {
        get {
          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
          return RSq(testY, testPredBuf);
        }
      }

      // takes the code of the best solution and creates an equivalent symbolic regression model
      public ISymbolicRegressionModel BestModel {
        get {
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
          var model = new SymbolicRegressionModel(t, interpreter, lowerEstimationLimit, upperEstimationLimit);

          // model has already been scaled linearly in Eval
          return model;
        }
      }

      public int TotalRollouts { get { return totalRollouts; } }
      public int EffectiveRollouts { get { return effectiveRollouts; } }
      public int FuncEvaluations { get { return funcEvaluations; } }
      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (multiplied by the number of parameters) so that the value is comparable to the number of function evaluations

      #endregion

      private double Eval(byte[] code, int nParams) {
        double[] optConsts;
        double q;
        Eval(code, nParams, out q, out optConsts);

        if (q > bestRSq) {
          bestRSq = q;
          bestNParams = nParams;
          this.bestCode = new byte[code.Length];
          this.bestConsts = new double[bestNParams];

          Array.Copy(code, bestCode, code.Length);
          Array.Copy(optConsts, bestConsts, bestNParams);
        }

        return q;
      }

      private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
        // we make a first pass to determine a valid starting configuration for all constants
        // the constant c in log(c + f(x)) is adjusted to guarantee that the argument is positive (see expression evaluator)
        // scale and offset are set to the optimal starting configuration
        // assumes scale is the first param and offset is the last param
        double alpha;
        double beta;

        // reset constants
        Array.Copy(ones, constsBuf, nParams);
        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
        funcEvaluations++;

        // calc opt scaling (beta * f(x) + alpha)
        OnlineCalculatorError error;
        OnlineLinearScalingParameterCalculator.Calculate(predBuf, y, out alpha, out beta, out error);
        if (error == OnlineCalculatorError.None) {
          constsBuf[0] *= beta;
          constsBuf[nParams - 1] = constsBuf[nParams - 1] * beta + alpha;
        }
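        // assuming OnlineLinearScalingParameterCalculator returns the least-squares fit
        // y ≈ beta * pred + alpha (i.e. beta = cov(pred, y) / var(pred) and
        // alpha = mean(y) - beta * mean(pred)), folding beta into the first constant (scale)
        // and rewriting the last constant (offset) as offset * beta + alpha rescales the
        // model without re-evaluating it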
        if (nParams <= 2 || constOptIterations <= 0) {
          // if we don't need to optimize parameters then we are done
          // changing scale and offset does not influence r²
          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        } else {
          // optimize constants using the starting point calculated above
          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

          evaluator.Exec(code, x, constsBuf, predBuf);
          funcEvaluations++;

          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        }
      }

      #region helpers
      private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
        return error == OnlineCalculatorError.None ? r * r : 0.0;
      }
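      // R² is the squared Pearson correlation between predictions and targets; it is
      // invariant under linear transformations of the predictions, which is why changing
      // scale and offset cannot influence it (see Eval above)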
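      // optimizes the constants with Levenberg-Marquardt (ALGLIB minlm in VJ mode:
      // residual vector plus Jacobian); minimizes the sum of squared residuals
      // f(x_r; consts) - y_r over all training rows r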
      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
        Array.Copy(consts, optConsts, nParams);

        alglib.minlmstate state;
        alglib.minlmreport rep = null;
        alglib.minlmcreatevj(y.Length, optConsts, out state);
        alglib.minlmsetcond(state, 0.0, epsF, 0.0, nIters);
        //alglib.minlmsetgradientcheck(state, 0.000001);
        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
        alglib.minlmresults(state, out optConsts, out rep);
        funcEvaluations += rep.nfunc;
        gradEvaluations += rep.njac * nParams;

        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

        // only use optimized constants if successful
        if (rep.terminationtype >= 0) {
          Array.Copy(optConsts, consts, optConsts.Length);
        }
      }

      private void Func(double[] arg, double[] fi, object obj) {
        var code = (byte[])obj;
        evaluator.Exec(code, x, arg, predBuf); // residuals only; no gradients needed here
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;
        }
      }
      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
        int nParams = arg.Length;
        var code = (byte[])obj;
        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;

          for (int k = 0; k < nParams; k++) {
            jac[r, k] = gradBuf[k][r];
          }
        }
      }
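      // note: alglib.minlmoptimize calls Func whenever only the residual vector is needed
      // and FuncAndJacobian whenever the Jacobian is needed as well; gradBuf is transposed
      // into jac because ALGLIB expects rows = residuals and columns = parameters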
      #endregion
    }

    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, double c = 1.0,
      bool scaleVariables = true, int constOptIterations = 0, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, c, scaleVariables, constOptIterations,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }

    // returns the quality of the evaluated solution
    public static double MakeStep(IState state) {
      var mctsState = state as State;
      if (mctsState == null) throw new ArgumentException("state");
      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

      return TreeSearch(mctsState);
    }
    #endregion
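
    // Typical driver loop for the static API above (a sketch; assumes an
    // IRegressionProblemData instance named problemData is available):
    //   var state = MctsSymbolicRegressionStatic.CreateState(problemData, randSeed: 1234,
    //                                                        maxVariables: 3, constOptIterations: 10);
    //   while (!state.Done) {
    //     double q = MctsSymbolicRegressionStatic.MakeStep(state); // one rollout; returns R² of the evaluated expression
    //   }
    //   var bestModel = state.BestModel;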

    private static double TreeSearch(State mctsState) {
      var automaton = mctsState.automaton;
      var tree = mctsState.tree;
      var eval = mctsState.evalFun;
      var bestChildrenBuf = mctsState.bestChildrenBuf;
      var rand = mctsState.random;
      double c = mctsState.c;
      double q = 0;
      double deltaQ = 0;
      double deltaSqrQ = 0;
      int deltaVisits = 0;
      bool success = false;
      do {
        automaton.Reset();
        success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q, out deltaQ, out deltaSqrQ, out deltaVisits);
        mctsState.totalRollouts++;
      } while (!success && !tree.done);
      mctsState.effectiveRollouts++;
      return q;
    }

    // tree search might fail because of constraints on expressions;
    // if we get stuck in such a dead end we simply restart
    // (see ConstraintHandler.cs for more info)
    private static bool TryTreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf,
      out double q, // quality of the expression
      out double deltaQ, out double deltaSqrQ, out int deltaVisits // the updates for total quality and number of visits (can be negative if branches have been fully explored)
      ) {
      Tree selectedChild = null;
      Contract.Assert(tree.state == automaton.CurrentState);
      Contract.Assert(!tree.done);
      if (tree.children == null) {
        if (automaton.IsFinalState(tree.state)) {
          // final state
          tree.done = true;

          // EVALUATE
          byte[] code; int nParams;
          automaton.GetCode(out code, out nParams);
          q = eval(code, nParams);
          tree.visits += 1;
          tree.sumQuality += q;
          tree.sumSqrQuality += q * q;
          deltaQ = q;
          deltaVisits = 1;
          deltaSqrQ = q * q;
          return true; // we reached a final state
        } else {
          // EXPAND
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            q = 0;
            deltaQ = 0;
            deltaSqrQ = 0.0;
            deltaVisits = 0;
            tree.done = true;
            tree.children = null;
            tree.visits = 1;
            return false;
          }
          tree.children = new Tree[nFs];
          for (int i = 0; i < tree.children.Length; i++)
            tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };

          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
        }
      } else {
        // tree.children != null
        // UCT selection within tree
        selectedChild = tree.children.Length > 1 ? SelectUctTuned(tree, rand, c, bestChildrenBuf) : tree.children[0];
      }
      // make selected step and recurse
      automaton.Goto(selectedChild.state);
      var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf,
        out q, out deltaQ, out deltaSqrQ, out deltaVisits);
      if (success) {
        // only update if successful
        tree.sumQuality += deltaQ;
        tree.sumSqrQuality += deltaSqrQ;
        tree.visits += deltaVisits;
      }

      if (tree.children.All(ch => ch.done)) {
        tree.done = true;
        // update parent nodes to remove information from this branch
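        // the deltas returned to the caller become negative: ancestors subtract everything
        // this branch ever contributed, the current rollout included (e.g. a node with
        // sumQuality = 3.5 before this rollout hands deltaQ = -3.5 upwards)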
        if (tree.children.Length > 1) {
          deltaQ = -(tree.sumQuality - deltaQ);
          deltaSqrQ = -(tree.sumSqrQuality - deltaSqrQ);
          deltaVisits = -(tree.visits - deltaVisits);
        }
        tree.children = null; // cut off the sub-branch if it has been fully explored
      }
      return success;
    }
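    // standard UCB1 selection: among the children that are not fully explored yet, pick
    // the one maximizing AverageQuality + c * sqrt(ln(totalTries) / visits)
    // (kept as an alternative; TryTreeSearchRec uses SelectUctTuned in this revision)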
    private static Tree SelectUct(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
      // determine total tries of still active children
      int totalTries = 0;
      bestChildrenBuf.Clear();
      for (int i = 0; i < tree.children.Length; i++) {
        var ch = tree.children[i];
        if (ch.done) continue;
        if (ch.visits == 0) bestChildrenBuf.Add(ch);
        else totalTries += tree.children[i].visits;
      }
      // if there are unvisited children select a random child
      if (bestChildrenBuf.Any()) {
        return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
      }
      Contract.Assert(totalTries > 0); // the tree is not done yet, so there is at least one child that is not done
      double logTotalTries = Math.Log(totalTries);
      var bestQ = double.NegativeInfinity;
      for (int i = 0; i < tree.children.Length; i++) {
        var ch = tree.children[i];
        if (ch.done) continue;
        var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits);
        if (childQ > bestQ) {
          bestChildrenBuf.Clear();
          bestChildrenBuf.Add(ch);
          bestQ = childQ;
        } else if (childQ >= bestQ) {
          bestChildrenBuf.Add(ch);
        }
      }
      return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    }
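    // UCB1-Tuned selection (cf. Auer, Cesa-Bianchi and Fischer, "Finite-time Analysis of
    // the Multiarmed Bandit Problem", 2002): the exploration term is scaled by an upper
    // bound on the reward variance, V = QualityVariance + sqrt(2 * ln(totalTries) / visits)
    // capped at 1/4, giving AverageQuality + c * sqrt(ln(totalTries) / visits * V)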
    private static Tree SelectUctTuned(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
      // determine total tries of still active children
      int totalTries = 0;
      bestChildrenBuf.Clear();
      for (int i = 0; i < tree.children.Length; i++) {
        var ch = tree.children[i];
        if (ch.done) continue;
        if (ch.visits == 0) bestChildrenBuf.Add(ch);
        else totalTries += tree.children[i].visits;
      }
      // if there are unvisited children select a random child
      if (bestChildrenBuf.Any()) {
        return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
      }
      Contract.Assert(totalTries > 0); // the tree is not done yet, so there is at least one child that is not done
      double logTotalTries = Math.Log(totalTries);
      var bestQ = double.NegativeInfinity;
      for (int i = 0; i < tree.children.Length; i++) {
        var ch = tree.children[i];
        if (ch.done) continue;
        var varianceBound = ch.QualityVariance + Math.Sqrt(2.0 * logTotalTries / ch.visits);
        if (varianceBound > 0.25) varianceBound = 0.25;
        var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits * varianceBound);
        if (childQ > bestQ) {
          bestChildrenBuf.Clear();
          bestChildrenBuf.Add(ch);
          bestQ = childQ;
        } else if (childQ >= bestQ) {
          bestChildrenBuf.Add(ch);
        }
      }
      return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    }

    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
      // if one of the new children leads to a final state then go there
      // otherwise choose a random child
      int selectedChildIdx = -1;
      // find first final state if there is one
      for (int i = 0; i < tree.children.Length; i++) {
        if (automaton.IsFinalState(tree.children[i].state)) {
          selectedChildIdx = i;
          break;
        }
      }
      // no final state -> select a random child
      if (selectedChildIdx == -1) {
        selectedChildIdx = rand.Next(tree.children.Length);
      }
      return tree.children[selectedChildIdx];
    }

    // scales data and extracts values from dataset into arrays
    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      var i = 0;
      if (scaleVariables) {
        scalingFactor = new double[xs.Length];
        scalingOffset = new double[xs.Length];
      } else {
        scalingFactor = null;
        scalingOffset = null;
      }
      foreach (var var in problemData.AllowedInputVariables) {
        if (scaleVariables) {
          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
          var range = maxX - minX;

          // scaledX = (x - min) / range
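          // (assumes range > 0; a constant input variable would yield an infinite
          // scaling factor here)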
          var sf = 1.0 / range;
          var offset = -minX / range;
          scalingFactor[i] = sf;
          scalingOffset[i] = offset;
          i++;
        }
      }

      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
    }

    // extract values from dataset into arrays
    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
      out double[][] xs, out double[] y) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      int i = 0;
      foreach (var var in problemData.AllowedInputVariables) {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
        xs[i++] =
          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
      }

      y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
    }
  }
}