
source: branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 15438

Last change on this file since 15438 was 15438, checked in by gkronber, 7 years ago

#2796 refactoring to simplify the code

File size: 37.2 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Text;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
  public static class MctsSymbolicRegressionStatic {
    // OBJECTIVES:
    // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
    //    - e.g. Keijzer, Nguyen ... where no numeric constants are involved
    //    - assumptions:
    //      - we don't know the necessary operations or functions -> all available functions could be necessary
    //      - but we do not need to tune numeric constants -> no scaling of input variables x!
    // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
    //    This is important for real-world applications.
    //    - e.g. Korns or Vladislavleva problems where numeric constants are involved
    //    - assumptions:
    //      - any numeric constant is possible (a priori we might assume that small abs. constants are more likely)
    //      - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
    //      - to simplify the problem we can restrict the set of functions, e.g. we assume which functions are necessary for the problem instance
    //        -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
    // 3) efficiency and effectiveness for real-world problems
    //    - e.g. Tower problem
    //    - (1) and (2) combined; structure search must be effective in combination with numeric optimization of constants

    // TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general!
    //       --> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show
    //       --> Therefore, looking at rollout statistics for arm selection is useless in the general case!
    //       --> It is necessary to rely on other features for the arm selection.
    //       --> TODO: Which heuristics can we apply?
    // TODO: Solve Poly-10
    // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
    // ~~obsolete TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
    // ~~obsolete TODO: check if we can use a quality measure with range [-1..1] in policies
    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
    // TODO: check if the transformation of y is correct and works (Obj 2)
    // TODO: The algorithm is not invariant to location and scale of variables.
    //       Include an offset for variables as a parameter (for Objective 2)
    // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
    // TODO: support exp(-x) and possibly (1/-x) (Obj 1)
    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
    // TODO: improve memory usage
    // TODO: support an empty test partition
    // TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax ) ) for invertible transformations A --> unit tests
    #region static API

    public interface IState {
      bool Done { get; }
      ISymbolicRegressionModel BestModel { get; }
      double BestSolutionTrainingQuality { get; }
      double BestSolutionTestQuality { get; }
      IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
      int TotalRollouts { get; }
      int EffectiveRollouts { get; }
      int FuncEvaluations { get; }
      int GradEvaluations { get; } // number of gradient evaluations (multiplied by the number of parameters) so the value is comparable to the number of function evaluations
      // TODO: other stats of the LM optimizer might be interesting here
    }

    // created through factory method
    private class State : IState {
      private const int MaxParams = 100;

      // state variables used by MCTS
      internal readonly Automaton automaton;
      internal IRandom random { get; private set; }
      internal readonly Tree tree;
      internal readonly Func<byte[], int, double> evalFun;
      // MCTS might get stuck. Track statistics on the number of effective rollouts.
      internal int totalRollouts;
      internal int effectiveRollouts;


      // state variables used only internally (for the eval function)
      private readonly IRegressionProblemData problemData;
      private readonly double[][] x;
      private readonly double[] y;
      private readonly double[][] testX;
      private readonly double[] testY;
      private readonly double[] scalingFactor;
      private readonly double[] scalingOffset;
      private readonly double yStdDev; // for scaling parameters (e.g. the stopping condition for LM)
      private readonly int constOptIterations;
      private readonly double lambda; // weight of the penalty term for regularization
      private readonly double lowerEstimationLimit, upperEstimationLimit;
      private readonly bool collectParetoOptimalModels;
      private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
      private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models

      private readonly ExpressionEvaluator evaluator, testEvaluator;

      internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
      internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
      internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();

      // values for the best solution
      private double bestR;
      private byte[] bestCode;
      private int bestNParams;
      private double[] bestConsts;

      // stats
      private int funcEvaluations;
      private int gradEvaluations;

      // buffers
      private readonly double[] ones; // vector of ones (as default params)
      private readonly double[] constsBuf;
      private readonly double[] predBuf, testPredBuf;
      private readonly double[][] gradBuf;

      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
        int constOptIterations, double lambda,
        bool collectParetoOptimalModels = false,
        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
        bool allowProdOfVars = true,
        bool allowExp = true,
        bool allowLog = true,
        bool allowInv = true,
        bool allowMultipleTerms = false) {

        if (lambda < 0) throw new ArgumentException("Lambda must be larger than or equal to zero", "lambda");

        this.problemData = problemData;
        this.constOptIterations = constOptIterations;
        this.lambda = lambda;
        this.evalFun = this.Eval;
        this.lowerEstimationLimit = lowerEstimationLimit;
        this.upperEstimationLimit = upperEstimationLimit;
        this.collectParetoOptimalModels = collectParetoOptimalModels;

        random = new MersenneTwister(randSeed);

        // prepare data for evaluation
        double[][] x;
        double[] y;
        double[][] testX;
        double[] testY;
        double[] scalingFactor;
        double[] scalingOffset;
        // get training and test datasets (scale linearly based on the training set if required)
        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
        this.x = x;
        this.y = y;
        this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
        this.testX = testX;
        this.testY = testY;
        this.scalingFactor = scalingFactor;
        this.scalingOffset = scalingOffset;
        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
        // we need a separate evaluator because the vector length for the test dataset might differ
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables);
        this.tree = new Tree() {
          state = automaton.CurrentState,
          expr = "",
          level = 0
        };

        // reset best solution
        this.bestR = 0;
        // code for the default solution (constant model)
        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
        this.bestNParams = 0;
        this.bestConsts = null;

        // init buffers
        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
        constsBuf = new double[MaxParams];
        this.predBuf = new double[y.Length];
        this.testPredBuf = new double[testY.Length];

        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
      }

      #region IState interface
      public bool Done { get { return tree != null && tree.Done; } }

      public double BestSolutionTrainingQuality {
        get {
          evaluator.Exec(bestCode, x, bestConsts, predBuf);
          return Rho(y, predBuf);
        }
      }

      public double BestSolutionTestQuality {
        get {
          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
          return Rho(testY, testPredBuf);
        }
      }

      // takes the code of the best solution and creates an equivalent symbolic regression model
      public ISymbolicRegressionModel BestModel {
        get {
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
          model.Scale(problemData); // apply linear scaling
          return model;
        }
      }
      public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
        get { return paretoBestModels; }
      }

      public int TotalRollouts { get { return totalRollouts; } }
      public int EffectiveRollouts { get { return effectiveRollouts; } }
      public int FuncEvaluations { get { return funcEvaluations; } }
      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (multiplied by the number of parameters) so the value is comparable to the number of function evaluations

      #endregion

      private double Eval(byte[] code, int nParams) {
        double[] optConsts;
        double q;
        Eval(code, nParams, out q, out optConsts);

        // single-objective best
        if (q > bestR) {
          bestR = q;
          bestNParams = nParams;
          this.bestCode = new byte[code.Length];
          this.bestConsts = new double[bestNParams];

          Array.Copy(code, bestCode, code.Length);
          Array.Copy(optConsts, bestConsts, bestNParams);
        }
        if (collectParetoOptimalModels) {
          // multi-objective best
          var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
            Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit);  // use the length of the expression as a surrogate for complexity
          UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
        }
        return q;
      }

      private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
        // we make a first pass to determine a valid starting configuration for all constants
        // the constant c in log(c + f(x)) is adjusted to guarantee that the argument is positive (see expression evaluator)
        // scale and offset are set to the optimal starting configuration
        // assumes scale is the first param and offset is the last param

        // reset constants
        Array.Copy(ones, constsBuf, nParams);
        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
        funcEvaluations++;

        if (nParams == 0 || constOptIterations < 0) {
          // if we don't need to optimize parameters then we are done
          // changing scale and offset does not influence r²
          rho = Rho(y, predBuf);
          optConsts = constsBuf;
        } else {
          // optimize constants using the starting point calculated above
          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

          evaluator.Exec(code, x, constsBuf, predBuf);
          funcEvaluations++;

          rho = Rho(y, predBuf);
          optConsts = constsBuf;
        }
      }

      #region helpers
      private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
        return error == OnlineCalculatorError.None ? r : 0.0;
      }
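      // Note: Pearson's r is invariant to a positive linear scaling of the predictions (and r²
      // to any linear scaling), which is why the scale and offset parameters need not be
      // optimized before computing the quality (cf. the comment in Eval above).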

      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
        Array.Copy(consts, optConsts, nParams);

        // direct usage of LM is recommended in the alglib manual for better performance than the lsfit interface (which uses LM internally).
        alglib.minlmstate state;
        alglib.minlmreport rep = null;
        alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for the penalty term
        // Using the change of the gradient as stopping criterion is recommended in the alglib manual.
        // However, the most recent version of alglib (as of Oct 2017) only supports epsX as a stopping criterion.
        alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
        // alglib.minlmsetgradientcheck(state, 1E-5);
        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
        alglib.minlmresults(state, out optConsts, out rep);
        funcEvaluations += rep.nfunc;
        gradEvaluations += rep.njac * nParams;

        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

        // use the optimized constants (LM succeeded if we reach this point; negative termination types throw above)
        Array.Copy(optConsts, consts, optConsts.Length);
      }

      private void Func(double[] arg, double[] fi, object obj) {
        var code = (byte[])obj;
        int n = predBuf.Length;
        evaluator.Exec(code, x, arg, predBuf);
        for (int r = 0; r < n; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;
        }

        var penaltyIdx = fi.Length - 1;
        fi[penaltyIdx] = 0.0;
        // calc the squared length of the parameter vector for regularization
        var aa = 0.0;
        for (int i = 0; i < arg.Length; i++) {
          aa += arg[i] * arg[i];
        }
        if (lambda > 0 && aa > 0) {
          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
          // scale lambda using n to make the parameter independent of the number of training points
          // take the root because LM squares the result
          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
        }
      }
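      // Derivation sketch: alglib's LM minimizes sum_i fi[i]^2, so together with the penalty
      // entry above the effective objective is
      //   sum_i (f(x_i) - y_i)^2 + (n * lambda / stdDev(y)) * ||arg||^2,
      // i.e. least squares with an L2 (ridge) penalty on the parameter vector.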

      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
        int n = predBuf.Length;
        int nParams = arg.Length;
        var code = (byte[])obj;
        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
        for (int r = 0; r < n; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;

          for (int k = 0; k < nParams; k++) {
            jac[r, k] = gradBuf[k][r];
          }
        }
        // calc the squared length of the parameter vector for regularization
        double aa = 0.0;
        for (int i = 0; i < arg.Length; i++) {
          aa += arg[i] * arg[i];
        }

        var penaltyIdx = fi.Length - 1;
        if (lambda > 0 && aa > 0) {
          fi[penaltyIdx] = 0.0;
          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
          // scale lambda using n to make the parameter independent of the number of training points
          // take the root because alglib LM squares the result
          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);

          for (int i = 0; i < arg.Length; i++) {
            jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
          }
        } else {
          fi[penaltyIdx] = 0.0;
          for (int i = 0; i < arg.Length; i++) {
            jac[penaltyIdx, i] = 0.0;
          }
        }
      }
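      // Sanity check for the penalty gradient above (with c = n * lambda / stdDev(y)):
      //   fi[penaltyIdx] = sqrt(c * ||arg||^2)
      //   d fi[penaltyIdx] / d arg_i = c * arg_i / sqrt(c * ||arg||^2) = 0.5 / fi[penaltyIdx] * 2 * c * arg_i,
      // which matches the expression assigned to jac[penaltyIdx, i].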


      private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
        double[] scalingFactor, double[] scalingOffset) {
        double[] cur = new double[2] { q, complexity };
        bool[] max = new[] { true, false }; // maximize quality, minimize complexity
        var isNonDominated = true;
        foreach (var e in paretoFront) {
          var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
          if (domRes == DominationResult.IsDominated) {
            isNonDominated = false;
            break;
          }
        }
        if (isNonDominated) {
          paretoFront.Add(cur);

          // create model
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

          var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
          model.Scale(problemData); // apply linear scaling

          var sol = model.CreateRegressionSolution(this.problemData);
          sol.Name = string.Format("{0:N5} {1}", q, complexity);

          paretoBestModels.Add(sol);
        }
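        // Remove existing front entries that are dominated by the current point; the loop
        // starts at paretoFront.Count - 2 so that the entry just added for cur is skipped.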
        for (int i = paretoFront.Count - 2; i >= 0; i--) {
          var @ref = paretoFront[i];
          var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
          if (domRes == DominationResult.Dominates) {
            paretoFront.RemoveAt(i);
            paretoBestModels.RemoveAt(i);
          }
        }
      }

      #endregion
    }

    /// <summary>
    /// Static method to initialize a state for the algorithm.
    /// </summary>
    /// <param name="problemData">The problem data.</param>
    /// <param name="randSeed">Random seed.</param>
    /// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
    /// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended).</param>
    /// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt).</param>
    /// <param name="lambda">Penalty factor for regularization (0..inf.); a value of zero disables regularization.</param>
    /// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
    /// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
    /// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
    /// <param name="allowProdOfVars">Allow products of variables.</param>
    /// <param name="allowExp">Allow expressions with exponentials.</param>
    /// <param name="allowLog">Allow expressions with logarithms.</param>
    /// <param name="allowInv">Allow expressions with 1/x.</param>
    /// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
    /// <returns>An initialized state that can be advanced with MakeStep.</returns>
    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
      bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
      bool collectParameterOptimalModels = false,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
        collectParameterOptimalModels,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }

    // returns the quality of the evaluated solution
    public static double MakeStep(IState state) {
      var mctsState = state as State;
      if (mctsState == null) throw new ArgumentException("state");
      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

      return TreeSearch(mctsState);
    }
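    // Illustrative usage sketch (assuming an IRegressionProblemData instance `problemData`
    // and an arbitrary step budget of 10000):
    //
    //   var state = MctsSymbolicRegressionStatic.CreateState(problemData, randSeed: 1234, maxVariables: 3);
    //   for (int i = 0; i < 10000 && !state.Done; i++) {
    //     double q = MctsSymbolicRegressionStatic.MakeStep(state); // Pearson's r of the expression evaluated in this step
    //   }
    //   ISymbolicRegressionModel best = state.BestModel;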
    #endregion

    private static double TreeSearch(State mctsState) {
      var automaton = mctsState.automaton;
      var tree = mctsState.tree;
      var eval = mctsState.evalFun;
      var rand = mctsState.random;
      double q = 0;
      bool success = false;
      do {
        automaton.Reset();
        success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q);
        mctsState.totalRollouts++;
      } while (!success && !tree.Done);
      mctsState.effectiveRollouts++;

#if DEBUG
      Console.WriteLine(ExprStr(automaton));
#endif
      return q;
    }

    // search forward
    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton,
      Func<byte[], int, double> eval,
      State state,
      out double q) {
      // ROLLOUT AND EXPANSION
      // We are navigating a graph (states might be reached via different paths) instead of a tree.
      // State equivalence is checked through ExprHash (based on the code generated along the path).

      // We switch between rollout mode and expansion mode.
      // Rollout mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB).
      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression).
      // In expansion mode we might re-enter the graph and switch back to rollout mode.
      // We do this until we reach a complete expression (final state).

      // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent.
      // Sub-graphs which have been completely searched are marked as done.
      // A roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

      while (!automaton.IsFinalState(automaton.CurrentState)) {
#if DEBUG
        Console.WriteLine(automaton.stateNames[automaton.CurrentState]);
#endif
        if (state.children.ContainsKey(tree)) {
          if (state.children[tree].All(ch => ch.Done)) {
            tree.Done = true;
            break;
          }
          // ROLLOUT INSIDE TREE
          // selection within the tree (currently random among non-done children, see SelectInternal)
          int selectedIdx = 0;
          if (state.children[tree].Count > 1) {
            selectedIdx = SelectInternal(state.children[tree], rand);
          }

          tree = state.children[tree][selectedIdx];

          // move the automaton forward until it reaches the selected state;
          // all steps where no alternatives exist are taken immediately (without expanding the tree)
          // TODO: simplification of the automaton
          int[] possibleFollowStates = new int[1000];
          int nFs;
          automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);
          Debug.Assert(possibleFollowStates.Contains(tree.state));
          automaton.Goto(tree.state);
        } else {
          // EXPAND
          int[] possibleFollowStates = new int[1000];
          int nFs;
          string actionString = "";
          automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);

          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }
          var newChildren = new List<Tree>(nFs);
          state.children.Add(tree, newChildren);
          for (int i = 0; i < nFs; i++) {
            Tree child = null;
            // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
            if (automaton.IsEvalState(possibleFollowStates[i])) {
              var hc = Hashcode(automaton); // TODO fix unit test for structure enumeration
              if (!state.nodes.TryGetValue(hc, out child)) {
                child = new Tree() {
                  state = possibleFollowStates[i],
                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                  level = tree.level + 1
                };
                state.nodes.Add(hc, child);
              }
              // only allow forward edges (don't add the child if we would go back in the graph)
              else if (child.level > tree.level) {
                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
                // to all parents
                BackpropagateStatistics(tree, state, child.visits);
              } else {
                // prevent cycles
                Debug.Assert(child.level <= tree.level);
                child = null;
              }
            } else {
              child = new Tree() {
                state = possibleFollowStates[i],
                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                level = tree.level + 1
              };
            }
            if (child != null)
              newChildren.Add(child);
          }

          if (!newChildren.Any()) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }

          foreach (var ch in newChildren) {
            if (!state.parents.ContainsKey(ch)) {
              state.parents.Add(ch, new List<Tree>());
            }
            state.parents[ch].Add(tree);
          }

          // follow one of the children
          tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
          automaton.Goto(tree.state);
        }
      }

      bool success;

      // EVALUATE TREE
      if (automaton.IsFinalState(automaton.CurrentState)) {
        tree.Done = true;
        tree.expr = ExprStr(automaton);
        byte[] code; int nParams;
        automaton.GetCode(out code, out nParams);
        q = eval(code, nParams);
        success = true;
        BackpropagateQuality(tree, q, state);
      } else {
        // we got stuck in the roll-out (no evaluation necessary!)
        q = 0.0;
        success = false;
      }

      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
      // Update statistics.
      // Set a branch to done if all of its children are done.
      BackpropagateDone(tree, state);
      BackpropagateDebugStats(tree, q, state);

      return success;
    }
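    // Selection note: the rejection-sampling loop below terminates because at least one child
    // is not done (guarded by the Debug.Assert). A UCB-style policy could be substituted here
    // (cf. the TODO notes at the top of this file on arm-selection heuristics).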
    private static int SelectInternal(List<Tree> list, IRandom rand) {
      // choose a random node
      Debug.Assert(list.Any(t => !t.Done));

      var idx = rand.Next(list.Count);
      while (list[idx].Done) { idx = rand.Next(list.Count); }
      return idx;
    }

    // backpropagate existing statistics to all parents
    private static void BackpropagateStatistics(Tree tree, State state, int numVisits) {
      tree.visits += numVisits;

      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateStatistics(parent, state, numVisits);
        }
      }
    }

    private static ulong Hashcode(Automaton automaton) {
      byte[] code;
      int nParams;
      automaton.GetCode(out code, out nParams);
      return ExprHash.GetHash(code, nParams);
    }

    private static void BackpropagateQuality(Tree tree, double q, State state) {
      tree.visits++;
      // TODO: q is ignored for now

      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateQuality(parent, q, state);
        }
      }
    }

    private static void BackpropagateDone(Tree tree, State state) {
      if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) {
        tree.Done = true;
        // children[tree] = null; keep all nodes
      }

      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateDone(parent, state);
        }
      }
    }

    private static void BackpropagateDebugStats(Tree tree, double q, State state) {
      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateDebugStats(parent, q, state);
        }
      }
    }

    private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
      // find the child with the smallest state value (smaller values are closer to the final state)
      int selectedChildIdx = 0;
      var children = state.children[tree];
      Tree minChild = children.First();
      for (int i = 1; i < children.Count; i++) {
        if (children[i].state < minChild.state) {
          minChild = children[i]; // keep track of the current minimum
          selectedChildIdx = i;
        }
      }
      return children[selectedChildIdx];
    }

    // scales the data and extracts values from the dataset into arrays
    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      var i = 0;
      if (scaleVariables) {
        scalingFactor = new double[xs.Length + 1];
        scalingOffset = new double[xs.Length + 1];
      } else {
        scalingFactor = null;
        scalingOffset = null;
      }
      foreach (var var in problemData.AllowedInputVariables) {
        if (scaleVariables) {
          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
          var range = maxX - minX;

          // scaledX = (x - min) / range
          var sf = 1.0 / range;
          var offset = -minX / range;
          scalingFactor[i] = sf;
          scalingOffset[i] = offset;
          i++;
        }
      }

      if (scaleVariables) {
        // transform the target variable to zero mean
        scalingFactor[i] = 1.0;
        scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
      }

      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
    }
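    // Worked example: for a variable with min = 2 and max = 6 on the training rows,
    // scalingFactor = 1/(6-2) = 0.25 and scalingOffset = -2/4 = -0.5, so x = 4 maps to
    // 4 * 0.25 - 0.5 = 0.5, i.e. the variable is scaled into [0..1].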

    // extracts values from the dataset into arrays
    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
     out double[][] xs, out double[] y) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      int i = 0;
      foreach (var var in problemData.AllowedInputVariables) {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
        xs[i++] =
          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
      }

      {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * sf + offset).ToArray();
      }
    }

    // for debugging only

    private static string ExprStr(Automaton automaton) {
      byte[] code;
      int nParams;
      automaton.GetCode(out code, out nParams);
      return Disassembler.CodeToString(code);
    }

    private static string TraceTree(Tree tree, State state) {
      var sb = new StringBuilder();
      sb.Append(
@"digraph {
  ratio = fill;
  node [style=filled];
");
      int nodeId = 0;

      TraceTreeRec(tree, 0, sb, ref nodeId, state);
      sb.Append("}");
      return sb.ToString();
    }

    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
      var tries = tree.visits;

      sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tries).AppendLine();

      var list = new List<Tuple<int, int, Tree>>();
      if (state.children.ContainsKey(tree)) {
        foreach (var ch in state.children[tree]) {
          nextId++;
          tries = ch.visits;
          sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
          sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine();
          list.Add(Tuple.Create(tries, nextId, ch));
        }

        foreach (var tup in list) {
          var ch = tup.Item3;
          var chId = tup.Item2;
          if (state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
            var chch = state.children[ch].First();
            nextId++;
            tries = chch.visits;
            sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
            sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", chId, nextId, chch.expr).AppendLine();
          }
        }

        foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
          TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state);
        }
      }
    }

    private static string WriteTree(Tree tree, State state) {
      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
      var nodeIds = new Dictionary<Tree, int>();
      sb.Write(
@"digraph {
  ratio = fill;
  node [style=filled];
");
      int threshold = /* state.nodes.Count > 500 ? 10 : */ 0;
      foreach (var kvp in state.children) {
        var parent = kvp.Key;
        int parentId;
        if (!nodeIds.TryGetValue(parent, out parentId)) {
          parentId = nodeIds.Count + 1;
          var tries = parent.visits;
          if (tries > threshold)
            sb.Write("{0} [label=\"{1}\"]; ", parentId, tries);
          nodeIds.Add(parent, parentId);
        }
        foreach (var child in kvp.Value) {
          int childId;
          if (!nodeIds.TryGetValue(child, out childId)) {
            childId = nodeIds.Count + 1;
            nodeIds.Add(child, childId);
          }
          var tries = child.visits;
          if (tries < 1) continue;
          if (tries > threshold) {
            sb.Write("{0} [label=\"{1}\"]; ", childId, tries);
            var edgeLabel = child.expr;
            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
            sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
          }
        }
      }

      sb.Write("}");
      return sb.ToString();
    }
  }
}