Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 15414

Last change on this file since 15414 was 15414, checked in by gkronber, 5 years ago

#2796 worked on MCTS

File size: 41.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Diagnostics.Contracts;
26using System.Linq;
27using System.Text;
28using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
29using HeuristicLab.Core;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31using HeuristicLab.Optimization;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
38  public static class MctsSymbolicRegressionStatic {
39    // OBJECTIVES:
40    // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
41    //    - e.g. Keijzer, Nguyen ... where no numeric constants are involved
42    //    - assumptions:
43    //      - we don't know the necessary operations or functions -> all available functions could be necessary
44    //      - but we do not need to tune numeric constants -> no scaling of input variables x!
45    // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
46    //    This is important for real world applications.
47    //    - e.g. Korns or Vladislavleva problems where numeric constants are involved
48    //    - assumptions:
49    //      - any numeric constant is possible (a-priori we might assume that small abs. constants are more likely)
50    //      - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
51    //      - to simplify the problem we can restrict the set of functions e.g. we assume which functions are necessary for the problem instance
52    //        -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
53    // 3) efficiency and effectiveness for real-world problems
54    //    - e.g. Tower problem
55    //    - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
56    //   
57
58    // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should
59    //       weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
60    //       branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
61    // TODO: Solve Poly-10
62    // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
63    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
64    // TODO: check if transformation of y is correct and works (Obj 2)
65    // TODO: The algorithm is not invariant to location and scale of variables.
66    //       Include offset for variables as parameter (for Objective 2)
67    // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
68    // TODO: support e(-x) and possibly (1/-x) (Obj 1)
69    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
70    // TODO: improve memory usage
71    #region static API
72
/// <summary>
/// Read-only view on the progress and the results of one MCTS symbolic regression run.
/// States are created by CreateState and advanced by MakeStep.
/// </summary>
public interface IState {
  // ---- search progress ----

  /// <summary>True once the tree search has enumerated all possible solutions.</summary>
  bool Done { get; }

  /// <summary>Number of rollouts started, including rollouts that got stuck.</summary>
  int TotalRollouts { get; }

  /// <summary>Number of tree-search steps performed.</summary>
  int EffectiveRollouts { get; }

  // ---- effort statistics ----

  /// <summary>Number of function evaluations.</summary>
  int FuncEvaluations { get; }

  /// <summary>
  /// Number of gradient evaluations multiplied by the number of parameters, so that the value is
  /// comparable to the number of function evaluations.
  /// </summary>
  int GradEvaluations { get; }
  // TODO other stats on LM optimizer might be interesting here

  // ---- results ----

  /// <summary>Symbolic regression model equivalent to the best solution found so far.</summary>
  ISymbolicRegressionModel BestModel { get; }

  /// <summary>Quality of the best solution on the training partition.</summary>
  double BestSolutionTrainingQuality { get; }

  /// <summary>Quality of the best solution on the test partition.</summary>
  double BestSolutionTestQuality { get; }

  /// <summary>Solutions collected on the Pareto front (only filled when Pareto collection is enabled).</summary>
  IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
}
85
// created through factory method (CreateState)
/// <summary>
/// Mutable state of a single MCTS symbolic regression run. Holds the automaton and the root of the
/// search tree used by the tree search, the training and test data, the best solution found so far,
/// the Pareto front of (R², expression length) if requested, evaluation statistics and
/// pre-allocated evaluation buffers.
/// </summary>
private class State : IState {
  // Capacity of the parameter buffers (ones, constsBuf, gradBuf). Expressions with more numeric
  // parameters are not supported (the buffer copy in Eval would fail).
  private const int MaxParams = 100;

  // state variables used by MCTS
  internal readonly Automaton automaton;
  internal IRandom random { get; private set; }
  internal readonly Tree tree;
  internal readonly Func<byte[], int, double> evalFun;
  internal readonly IPolicy treePolicy;
  // MCTS might get stuck. Track statistics on the number of effective rollouts
  internal int totalRollouts;
  internal int effectiveRollouts;

  // state variables used only internally (for eval function)
  private readonly IRegressionProblemData problemData;
  private readonly double[][] x;
  private readonly double[] y;
  private readonly double[][] testX;
  private readonly double[] testY;
  private readonly double[] scalingFactor;
  private readonly double[] scalingOffset;
  private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
  private readonly int constOptIterations;
  private readonly double lambda; // weight of penalty term for regularization
  private readonly double lowerEstimationLimit, upperEstimationLimit;
  private readonly bool collectParetoOptimalModels;
  private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
  private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models

  private readonly ExpressionEvaluator evaluator, testEvaluator;

  // values for best solution
  private double bestRSq;
  private byte[] bestCode;
  private int bestNParams;
  private double[] bestConsts;

  // stats
  private int funcEvaluations;
  private int gradEvaluations;

  // buffers
  private readonly double[] ones; // vector of ones (as default params)
  private readonly double[] constsBuf;
  private readonly double[] predBuf, testPredBuf;
  private readonly double[][] gradBuf;

  /// <summary>
  /// Prepares the (optionally scaled) training and test data, the evaluators, the automaton and the
  /// root node of the search tree.
  /// </summary>
  /// <param name="treePolicy">Tree search policy; UCB is used when null.</param>
  /// <exception cref="ArgumentException">lambda is negative.</exception>
  public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
    int constOptIterations, double lambda,
    IPolicy treePolicy = null,
    bool collectParetoOptimalModels = false,
    double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
    bool allowProdOfVars = true,
    bool allowExp = true,
    bool allowLog = true,
    bool allowInv = true,
    bool allowMultipleTerms = false) {

    if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");

    this.problemData = problemData;
    this.constOptIterations = constOptIterations;
    this.lambda = lambda;
    this.evalFun = this.Eval;
    this.lowerEstimationLimit = lowerEstimationLimit;
    this.upperEstimationLimit = upperEstimationLimit;
    this.collectParetoOptimalModels = collectParetoOptimalModels;

    random = new MersenneTwister(randSeed);

    // prepare data for evaluation
    double[][] x;
    double[] y;
    double[][] testX;
    double[] testY;
    double[] scalingFactor;
    double[] scalingOffset;
    // get training and test datasets (scale linearly based on training set if required)
    GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
    GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
    this.x = x;
    this.y = y;
    this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
    this.testX = testX;
    this.testY = testY;
    this.scalingFactor = scalingFactor;
    this.scalingOffset = scalingOffset;
    this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
    // we need a separate evaluator because the vector length for the test dataset might differ
    this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

    this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    this.treePolicy = treePolicy ?? new Ucb();
    this.tree = new Tree() {
      state = automaton.CurrentState,
      // Use the field, not the parameter: the parameter is null when the default policy (UCB) is used.
      actionStatistics = this.treePolicy.CreateActionStatistics(),
      expr = "",
      level = 0
    };

    // reset best solution
    this.bestRSq = 0;
    // code for default solution (constant model)
    this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
    this.bestNParams = 0;
    this.bestConsts = null;

    // init buffers
    this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
    this.constsBuf = new double[MaxParams];
    this.predBuf = new double[y.Length];
    this.testPredBuf = new double[testY.Length];

    this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
  }

  #region IState interface
  public bool Done { get { return tree != null && tree.Done; } }

  // R² of the best solution on the training data (re-evaluated on every access, uses predBuf)
  public double BestSolutionTrainingQuality {
    get {
      evaluator.Exec(bestCode, x, bestConsts, predBuf);
      return RSq(y, predBuf);
    }
  }

  // R² of the best solution on the test data (re-evaluated on every access, uses testPredBuf)
  public double BestSolutionTestQuality {
    get {
      testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
      return RSq(testY, testPredBuf);
    }
  }

  // takes the code of the best solution and creates an equivalent symbolic regression model
  public ISymbolicRegressionModel BestModel {
    get {
      var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

      var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
      var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
      model.Scale(problemData); // apply linear scaling
      return model;
    }
  }
  public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
    get { return paretoBestModels; }
  }

  public int TotalRollouts { get { return totalRollouts; } }
  public int EffectiveRollouts { get { return effectiveRollouts; } }
  public int FuncEvaluations { get { return funcEvaluations; } }
  public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations

  #endregion

  // Evaluation callback used by the tree search (evalFun). Evaluates the expression, optimizing its
  // constants if enabled. Updates the best solution and, if requested, the Pareto front.
  // Returns the R² of the expression on the training data.
  private double Eval(byte[] code, int nParams) {
    double[] optConsts;
    double q;
    Eval(code, nParams, out q, out optConsts);

    // single objective best
    if (q > bestRSq) {
      bestRSq = q;
      bestNParams = nParams;
      this.bestCode = new byte[code.Length];
      this.bestConsts = new double[bestNParams];

      // copies are necessary: optConsts is the shared constsBuf
      Array.Copy(code, bestCode, code.Length);
      Array.Copy(optConsts, bestConsts, bestNParams);
    }
    if (collectParetoOptimalModels) {
      // multi-objective best
      var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
        Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit);  // use length of expression as surrogate for complexity
      UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
    }
    return q;
  }

  // Evaluates the expression on the training data and returns its R² (rsq).
  // optConsts is set to the shared constsBuf, which holds the (optimized) constants. The buffer is
  // overwritten by the next evaluation, so callers must copy it if they keep it.
  private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
    // we make a first pass to determine a valid starting configuration for all constants
    // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
    // scale and offset are set to optimal starting configuration
    // assumes scale is the first param and offset is the last param

    // reset constants
    Array.Copy(ones, constsBuf, nParams);
    evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
    funcEvaluations++;

    if (nParams == 0 || constOptIterations < 0) {
      // if we don't need to optimize parameters then we are done
      // changing scale and offset does not influence r²
      rsq = RSq(y, predBuf);
      optConsts = constsBuf;
    } else {
      // optimize constants using the starting point calculated above
      OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

      evaluator.Exec(code, x, constsBuf, predBuf);
      funcEvaluations++;

      rsq = RSq(y, predBuf);
      optConsts = constsBuf;
    }
  }

  #region helpers
  // Squared Pearson correlation of x and y. Returns 0 if the calculator reports an error.
  private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
    OnlineCalculatorError error;
    double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
    return error == OnlineCalculatorError.None ? r * r : 0.0;
  }

  // Optimizes the first nParams entries of consts in place with alglib's Levenberg-Marquardt
  // optimizer. The objective is the sum of squared residuals plus the regularization penalty.
  // NOTE: nIters is passed as alglib's maxits; alglib documents maxits = 0 as "no iteration limit".
  // NOTE(review): epsg and the penalty are scaled by yStdDev; for a constant target (yStdDev == 0)
  // the penalty divides by zero -- confirm callers never hit this with lambda > 0.
  private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
    double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
    Array.Copy(consts, optConsts, nParams);

    // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
    alglib.minlmstate state;
    alglib.minlmreport rep;
    alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
    // Using the change of the gradient (epsg) as stopping criterion is recommended in alglib manual.
    // NOTE: newer alglib versions (as of Oct 2017) only support epsX as stopping criterion.
    alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
    alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
    alglib.minlmresults(state, out optConsts, out rep);
    funcEvaluations += rep.nfunc;
    gradEvaluations += rep.njac * nParams;

    // a negative termination type signals that the optimizer failed
    if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

    // success: write the optimized constants back to the caller's buffer
    Array.Copy(optConsts, consts, optConsts.Length);
  }

  // alglib residual callback: fi[0..n-1] = prediction - target, fi[n] = regularization penalty.
  private void Func(double[] arg, double[] fi, object obj) {
    var code = (byte[])obj;
    int n = predBuf.Length;
    evaluator.Exec(code, x, arg, predBuf);
    for (int r = 0; r < n; r++) {
      fi[r] = predBuf[r] - y[r];
    }

    var penaltyIdx = fi.Length - 1;
    fi[penaltyIdx] = 0.0;
    // squared length of the parameter vector for regularization
    var aa = 0.0;
    for (int i = 0; i < arg.Length; i++) {
      aa += arg[i] * arg[i];
    }
    if (lambda > 0 && aa > 0) {
      // scale lambda using stdDev(y) to make the parameter independent of the scale of y
      // scale lambda using n to make parameter independent of the number of training points
      // take the root because LM squares the result
      fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
    }
  }

  // alglib residual + Jacobian callback. Same residuals as Func; jac[r, k] = d fi[r] / d arg[k].
  private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
    int n = predBuf.Length;
    int nParams = arg.Length;
    var code = (byte[])obj;
    evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
    for (int r = 0; r < n; r++) {
      fi[r] = predBuf[r] - y[r];

      for (int k = 0; k < nParams; k++) {
        jac[r, k] = gradBuf[k][r];
      }
    }
    // squared length of the parameter vector for regularization
    double aa = 0.0;
    for (int i = 0; i < nParams; i++) {
      aa += arg[i] * arg[i];
    }

    var penaltyIdx = fi.Length - 1;
    if (lambda > 0 && aa > 0) {
      // scale lambda using stdDev(y) to make the parameter independent of the scale of y
      // scale lambda using n to make parameter independent of the number of training points
      // take the root because alglib LM squares the result
      fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);

      // d/d arg[i] sqrt(c * |arg|^2) = c * arg[i] / sqrt(c * |arg|^2)   with c = n * lambda / yStdDev
      for (int i = 0; i < nParams; i++) {
        jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
      }
    } else {
      fi[penaltyIdx] = 0.0;
      for (int i = 0; i < nParams; i++) {
        jac[penaltyIdx, i] = 0.0;
      }
    }
  }

  // Maintains the front of solutions that are non-dominated w.r.t. (maximize q, minimize complexity).
  // Adds the new point (and a solution built from code/param) if no existing point dominates it.
  // Removes existing points dominated by the new one.
  private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
    double[] scalingFactor, double[] scalingOffset) {
    double[] cur = new double[2] { q, complexity };
    bool[] max = new[] { true, false };
    var isNonDominated = true;
    foreach (var e in paretoFront) {
      var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
      if (domRes == DominationResult.IsDominated) {
        isNonDominated = false;
        break;
      }
    }
    if (isNonDominated) {
      paretoFront.Add(cur);

      // create model
      var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

      var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
      var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
      model.Scale(problemData); // apply linear scaling

      var sol = model.CreateRegressionSolution(this.problemData);
      sol.Name = string.Format("{0:N5} {1}", q, complexity);

      paretoBestModels.Add(sol);
    }
    // remove points dominated by cur (the last entry is skipped: it is cur itself if cur was added)
    for (int i = paretoFront.Count - 2; i >= 0; i--) {
      var @ref = paretoFront[i];
      var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
      if (domRes == DominationResult.Dominates) {
        paretoFront.RemoveAt(i);
        paretoBestModels.RemoveAt(i);
      }
    }
  }
  #endregion
}
434
435
/// <summary>
/// Static method to initialize a state for the algorithm.
/// Each call starts a fresh search graph, which invalidates the search graph of previously created
/// states. Do not advance several states in an interleaved fashion or in parallel.
/// </summary>
/// <param name="problemData">The problem data (training and test partitions, target and allowed input variables).</param>
/// <param name="randSeed">Random seed.</param>
/// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
/// <param name="scaleVariables">Optionally scale input variables linearly based on the training partition, e.g. to the interval [0..1] (recommended).</param>
/// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt).
/// A negative value disables constants optimization. 0 is passed to alglib as maxits, which alglib documents as "no limit".</param>
/// <param name="lambda">Penalty factor for L2 regularization of the constants (0..inf); 0 disables regularization. Must not be negative.</param>
/// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...); UCB is used when null.</param>
/// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
/// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
/// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
/// <param name="allowProdOfVars">Allow products of expressions.</param>
/// <param name="allowExp">Allow expressions with exponentials.</param>
/// <param name="allowLog">Allow expressions with logarithms.</param>
/// <param name="allowInv">Allow expressions with 1/x.</param>
/// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
/// <returns>A new state that can be advanced with MakeStep.</returns>
/// <exception cref="ArgumentNullException">problemData is null.</exception>
/// <exception cref="ArgumentException">lambda is negative.</exception>
public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
  bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
  IPolicy policy = null,
  bool collectParameterOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false
  ) {
  if (problemData == null) throw new ArgumentNullException("problemData");

  var state = new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
    policy, collectParameterOptimalModels,
    lowerEstimationLimit, upperEstimationLimit,
    allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);

  // The search graph (children, parents, nodes) is stored in static dictionaries. Without a reset,
  // the expression-hash lookup in the tree search would re-use nodes of a previous run, including
  // their statistics and Done flags, possibly for a different problem or automaton configuration.
  // The graph is only cleared after the state was created successfully.
  children.Clear();
  parents.Clear();
  nodes.Clear();

  return state;
}
472
/// <summary>
/// Performs one step of the tree search. Rollouts are repeated until one of them reaches and
/// evaluates a complete expression, or until the whole search space has been explored.
/// </summary>
/// <param name="state">A state created by CreateState.</param>
/// <returns>The quality of the evaluated solution (0 if no solution could be evaluated).</returns>
/// <exception cref="ArgumentException">state was not created by CreateState.</exception>
/// <exception cref="NotSupportedException">The tree search has already enumerated all solutions.</exception>
public static double MakeStep(IState state) {
  var mctsState = state as State;
  if (mctsState == null)
    throw new ArgumentException("The state must have been created by CreateState.", "state");
  if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

  return TreeSearch(mctsState);
}
481    #endregion
482
// Runs rollouts until one of them succeeds (a final state was evaluated) or the root is done.
// Every attempt counts as a rollout; the step as a whole counts as one effective rollout.
// Returns the quality reported by the last rollout.
private static double TreeSearch(State mctsState) {
  var root = mctsState.tree;
  var automaton = mctsState.automaton;
  double quality;

  for (;;) {
    automaton.Reset();
    bool succeeded = TryTreeSearchRec2(mctsState.random, root, automaton, mctsState.evalFun, mctsState.treePolicy, out quality);
    mctsState.totalRollouts++;
    if (succeeded || root.Done) break;
  }

  mctsState.effectiveRollouts++;
  return quality;
}
494        mctsState.totalRollouts++;
495      } while (!success && !tree.Done);
496      mctsState.effectiveRollouts++;
497
498      if (mctsState.effectiveRollouts % 10 == 1) {
499        //Console.WriteLine(WriteTree(tree));
500        //Console.WriteLine(TraceTree(tree));
501      }
502      return q;
503    }
504
505    private static Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
506    private static Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
507    private static Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
508
509
510
511    // search forward
512    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
513      out double q) {
514      // ROLLOUT AND EXPANSION
515      // We are navigating a graph (states might be reached via different paths) instead of a tree.
516      // State equivalence is checked through ExprHash (based on the generated code through the path).
517
518      // We switch between rollout-mode and expansion mode
519      // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
520      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
521      // In expansion mode we might re-enter the graph and switch back to rollout-mode
522      // We do this until we reach a complete expression (final state)
523
524      // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
525      // Sub-graphs which have been completely searched are marked as done.
526      // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.
527
528      while (!automaton.IsFinalState(automaton.CurrentState)) {
529        if (children.ContainsKey(tree)) {
530          if (children[tree].All(ch => ch.Done)) {
531            tree.Done = true;
532            break;
533          }
534          // ROLLOUT INSIDE TREE
535          // UCT selection within tree
536          int selectedIdx = 0;
537          if (children[tree].Count > 1) {
538            selectedIdx = treePolicy.Select(children[tree].Select(ch => ch.actionStatistics), rand);
539          }
540          tree = children[tree][selectedIdx];
541
542          // move the automaton forward until reaching the state
543          // all steps where no alternatives are possible are immediately taken
544          // TODO: simplification of the automaton
545          int[] possibleFollowStates;
546          int nFs;
547          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
548          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
549            automaton.Goto(possibleFollowStates[0]);
550            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
551          }
552          Debug.Assert(possibleFollowStates.Contains(tree.state));
553          automaton.Goto(tree.state);
554        } else {
555          // EXPAND
556          int[] possibleFollowStates;
557          int nFs;
558          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
559          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
560            // no alternatives -> just go to the next state
561            automaton.Goto(possibleFollowStates[0]);
562            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
563          }
564          if (nFs == 0) {
565            // stuck in a dead end (no final state and no allowed follow states)
566            tree.Done = true;
567            break;
568          }
569          var newChildren = new List<Tree>(nFs);
570          children.Add(tree, newChildren);
571          for (int i = 0; i < nFs; i++) {
572            Tree child = null;
573            // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
574            if (automaton.IsEvalState(possibleFollowStates[i])) {
575              var hc = Hashcode(automaton);
576              if (!nodes.TryGetValue(hc, out child)) {
577                child = new Tree() {
578                  children = null,
579                  state = possibleFollowStates[i],
580                  actionStatistics = treePolicy.CreateActionStatistics(),
581                  expr = string.Empty, // ExprStr(automaton),
582                  level = tree.level + 1
583                };
584                nodes.Add(hc, child);
585              }
586              // only allow forward edges (don't add the child if we would go back in the graph)
587              else if (child.level > tree.level) {
588                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
589                // to all parents
590                BackpropagateStatistics(child.actionStatistics, tree);
591              } else {
592                // prevent cycles
593                Debug.Assert(child.level <= tree.level);
594                child = null;
595              }
596            } else {
597              child = new Tree() {
598                children = null,
599                state = possibleFollowStates[i],
600                actionStatistics = treePolicy.CreateActionStatistics(),
601                expr = string.Empty, // ExprStr(automaton),
602                level = tree.level + 1
603              };
604            }
605            if (child != null)
606              newChildren.Add(child);
607          }
608
609          if (!newChildren.Any()) {
610            // stuck in a dead end (no final state and no allowed follow states)
611            tree.Done = true;
612            break;
613          }
614
615          foreach (var ch in newChildren) {
616            if (!parents.ContainsKey(ch)) {
617              parents.Add(ch, new List<Tree>());
618            }
619            parents[ch].Add(tree);
620          }
621
622
623          // follow one of the children
624          tree = SelectFinalOrRandom2(automaton, tree, rand);
625          automaton.Goto(tree.state);
626        }
627      }
628
629      bool success;
630
631      // EVALUATE TREE
632      if (automaton.IsFinalState(automaton.CurrentState)) {
633        tree.Done = true;
634        tree.expr = ExprStr(automaton);
635        byte[] code; int nParams;
636        automaton.GetCode(out code, out nParams);
637        q = eval(code, nParams);
638        q = TransformQuality(q);
639        success = true;
640      } else {
641        // we got stuck in roll-out (not evaluation necessary!)
642        q = 0.0;
643        success = false;
644      }
645
646      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
647      // Update statistics
648      // Set branch to done if all children are done.
649      BackpropagateQuality(tree, q, treePolicy);
650
651      return success;
652    }
653
654
655    private static double TransformQuality(double q) {
656      // no transformation
657      return q;
658
659      // EXPERIMENTAL!
660      // optimal result: q = 1 -> return huge value
661      // if (q >= 1.0) return 1E16;
662      // // return number of 9s in R²
663      // return -Math.Log10(1 - q);
664    }
665
666    // backpropagate existing statistics to all parents
667    private static void BackpropagateStatistics(IActionStatistics stats, Tree tree) {
668      tree.actionStatistics.Add(stats);
669      if (parents.ContainsKey(tree)) {
670        foreach (var parent in parents[tree]) {
671          BackpropagateStatistics(stats, parent);
672        }
673      }
674    }
675
676    private static ulong Hashcode(Automaton automaton) {
677      byte[] code;
678      int nParams;
679      automaton.GetCode(out code, out nParams);
680      return ExprHash.GetHash(code, nParams);
681    }
682
683    private static void BackpropagateQuality(Tree tree, double q, IPolicy policy) {
684      if (q > 0) policy.Update(tree.actionStatistics, q);
685      if (children.ContainsKey(tree) && children[tree].All(ch => ch.Done)) {
686        tree.Done = true;
687        // children[tree] = null; keep all nodes
688      }
689
690      if (parents.ContainsKey(tree)) {
691        foreach (var parent in parents[tree]) {
692          BackpropagateQuality(parent, q, policy);
693        }
694      }
695    }
696
697    private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) {
698      // if one of the new children leads to a final state then go there
699      // otherwise choose a random child
700      int selectedChildIdx = -1;
701      // find first final state if there is one
702      var children = MctsSymbolicRegressionStatic.children[tree];
703      for (int i = 0; i < children.Count; i++) {
704        if (automaton.IsFinalState(children[i].state)) {
705          selectedChildIdx = i;
706          break;
707        }
708      }
709      // no final state -> select the first child
710      if (selectedChildIdx == -1) {
711        selectedChildIdx = 0;
712      }
713      return children[selectedChildIdx];
714    }
715
    // Tree search can fail because of constraints on expressions;
    // when we get stuck in this way we simply restart.
    // See ConstraintHandler.cs for more info.
    //
    // Performs one search pass starting at `tree`, whose state must equal automaton.CurrentState
    // and which must not be done yet (both checked via Contract.Assert):
    //  - unexpanded node in a final state: mark it done, evaluate its code with `eval`,
    //    update its statistics with the result and return true;
    //  - unexpanded node in a non-final state: expand one child per follow state of the automaton.
    //    With no follow states this is a dead end: mark done, q = 0, return false. Otherwise descend
    //    into SelectFinalOrRandom's choice (or the only child);
    //  - expanded node: descend into the child chosen by treePolicy.Select (index 0 if only one child).
    // After descending (automaton.Goto + recursion), this node's statistics are updated with q only
    // when the recursion succeeded. The node becomes done when all children are done, and its
    // children array is then dropped.
    // The automaton is only moved forward via Goto here; it is not reset by this method.
    // Returns true if a final state was reached and evaluated; q then holds the evaluation result.
    // NOTE(review): the raw eval() result is used without TransformQuality (currently the identity).
    private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
      out double q) {
      Tree selectedChild = null;
      Contract.Assert(tree.state == automaton.CurrentState);
      Contract.Assert(!tree.Done);
      if (tree.children == null) {
        if (automaton.IsFinalState(tree.state)) {
          // final state
          tree.Done = true;

          // EVALUATE
          byte[] code; int nParams;
          automaton.GetCode(out code, out nParams);
          q = eval(code, nParams);

          treePolicy.Update(tree.actionStatistics, q);
          return true; // we reached a final state
        } else {
          // EXPAND
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            q = 0;
            tree.Done = true;
            tree.children = null;
            return false;
          }
          // one new, unexpanded child per allowed follow state
          tree.children = new Tree[nFs];
          for (int i = 0; i < tree.children.Length; i++)
            tree.children[i] = new Tree() {
              children = null,
              state = possibleFollowStates[i],
              actionStatistics = treePolicy.CreateActionStatistics()
            };

          // prefer a child in a final state (see SelectFinalOrRandom)
          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
        }
      } else {
        // tree.children != null
        // selection within the tree via treePolicy (e.g. UCT)
        int selectedIdx = 0;
        if (tree.children.Length > 1) {
          selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
        }
        selectedChild = tree.children[selectedIdx];
      }
      // make selected step and recurse
      automaton.Goto(selectedChild.state);
      var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
      if (success) {
        // only update if successful (dead ends do not contribute to the statistics)
        treePolicy.Update(tree.actionStatistics, q);
      }

      tree.Done = tree.children.All(ch => ch.Done);
      if (tree.Done) {
        tree.children = null; // cut off the sub-branch if it has been fully explored
      }
      return success;
    }
781
782    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
783      // if one of the new children leads to a final state then go there
784      // otherwise choose a random child
785      int selectedChildIdx = -1;
786      // find first final state if there is one
787      for (int i = 0; i < tree.children.Length; i++) {
788        if (automaton.IsFinalState(tree.children[i].state)) {
789          selectedChildIdx = i;
790          break;
791        }
792      }
793      // no final state -> select a the first child
794      if (selectedChildIdx == -1) {
795        selectedChildIdx = 0;
796      }
797      return tree.children[selectedChildIdx];
798    }
799
800    // scales data and extracts values from dataset into arrays
801    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
802      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
803      xs = new double[problemData.AllowedInputVariables.Count()][];
804
805      var i = 0;
806      if (scaleVariables) {
807        scalingFactor = new double[xs.Length + 1];
808        scalingOffset = new double[xs.Length + 1];
809      } else {
810        scalingFactor = null;
811        scalingOffset = null;
812      }
813      foreach (var var in problemData.AllowedInputVariables) {
814        if (scaleVariables) {
815          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
816          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
817          var range = maxX - minX;
818
819          // scaledX = (x - min) / range
820          var sf = 1.0 / range;
821          var offset = -minX / range;
822          scalingFactor[i] = sf;
823          scalingOffset[i] = offset;
824          i++;
825        }
826      }
827
828      if (scaleVariables) {
829        // transform target variable to zero-mean
830        scalingFactor[i] = 1.0;
831        scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
832      }
833
834      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
835    }
836
837    // extract values from dataset into arrays
838    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
839     out double[][] xs, out double[] y) {
840      xs = new double[problemData.AllowedInputVariables.Count()][];
841
842      int i = 0;
843      foreach (var var in problemData.AllowedInputVariables) {
844        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
845        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
846        xs[i++] =
847          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
848      }
849
850      {
851        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
852        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
853        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * sf + offset).ToArray();
854      }
855    }
856
857    // for debugging only
858
859
860    private static string ExprStr(Automaton automaton) {
861      byte[] code;
862      int nParams;
863      automaton.GetCode(out code, out nParams);
864      return Disassembler.CodeToString(code);
865    }
866
867    private static string WriteStatistics(Tree tree) {
868      var sb = new System.IO.StringWriter();
869      sb.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
870      if (children.ContainsKey(tree)) {
871        foreach (var ch in children[tree]) {
872          sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
873        }
874      }
875      return sb.ToString();
876    }
877
878    private static string TraceTree(Tree tree) {
879      var sb = new StringBuilder();
880      sb.Append(
881@"digraph {
882  ratio = fill;
883  node [style=filled];
884");
885      int nodeId = 0;
886
887      TraceTreeRec(tree, 0, sb, ref nodeId);
888      sb.Append("}");
889      return sb.ToString();
890    }
891
892    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId) {
893      var avgNodeQ = tree.actionStatistics.AverageQuality;
894      var tries = tree.actionStatistics.Tries;
895      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
896      var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
897
898      sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();
899
900      var list = new List<Tuple<int, int, Tree>>();
901      if (children.ContainsKey(tree)) {
902        foreach (var ch in children[tree]) {
903          nextId++;
904          avgNodeQ = ch.actionStatistics.AverageQuality;
905          tries = ch.actionStatistics.Tries;
906          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
907          hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
908          sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
909          sb.AppendFormat("{0} -> {1}", parentId, nextId, avgNodeQ).AppendLine();
910          list.Add(Tuple.Create(tries, nextId, ch));
911        }
912        foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
913          TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId);
914        }
915      }
916    }
917
    // Debug helper: writes the whole search graph stored in the static `children` map as a Graphviz
    // "digraph" (numbers formatted with the invariant culture).
    // Node ids are assigned in order of first appearance, starting at 1.
    // When the static `nodes` collection holds more than 500 entries, only nodes with more than
    // 10 tries (and only edges to such children) are written; children with fewer than 1 try are
    // always skipped. Edges are labeled with the child's `expr` string.
    // NOTE(review): the `tree` parameter is not used; the complete `children` map is written.
    // NOTE(review): child.expr is inserted into the DOT label without escaping; a '"' in it would
    // break the output. Verify what Disassembler.CodeToString can produce.
    private static string WriteTree(Tree tree) {
      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
      var nodeIds = new Dictionary<Tree, int>();
      sb.Write(
@"digraph {
  ratio = fill;
  node [style=filled];
");
      // thin out large graphs: only nodes with more than `threshold` tries are drawn
      int threshold = nodes.Count > 500 ? 10 : 0;
      foreach (var kvp in children) {
        var parent = kvp.Key;
        int parentId;
        if (!nodeIds.TryGetValue(parent, out parentId)) {
          parentId = nodeIds.Count + 1;
          var avgNodeQ = parent.actionStatistics.AverageQuality;
          var tries = parent.actionStatistics.Tries;
          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
          if (parent.actionStatistics.Tries > threshold)
            sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
          nodeIds.Add(parent, parentId);
        }
        foreach (var child in kvp.Value) {
          int childId;
          if (!nodeIds.TryGetValue(child, out childId)) {
            childId = nodeIds.Count + 1;
            nodeIds.Add(child, childId);
          }
          var avgNodeQ = child.actionStatistics.AverageQuality;
          var tries = child.actionStatistics.Tries;
          if (tries < 1) continue; // never-tried children are not drawn (their id is still reserved)
          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
          if (tries > threshold) {
            sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
            var edgeLabel = child.expr;
            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
            // note: argument {2} (avgNodeQ) is not referenced by this format string
            sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
          }
        }
      }

      sb.Write("}");
      return sb.ToString();
    }
963  }
964}
Note: See TracBrowser for help on using the repository browser.