Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 15441

Last change on this file since 15441 was 15441, checked in by gkronber, 6 years ago

#2796 more bug fixing

File size: 38.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using System.Text;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Optimization;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
33using HeuristicLab.Random;
34
35namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
36  public static class MctsSymbolicRegressionStatic {
37    // OBJECTIVES:
38    // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
39    //    - e.g. Keijzer, Nguyen ... where no numeric constants are involved
40    //    - assumptions:
41    //      - we don't know the necessary operations or functions -> all available functions could be necessary
42    //      - but we do not need to tune numeric constants -> no scaling of input variables x!
43    // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
44    //    This is important for real world applications.
45    //    - e.g. Korns or Vladislavleva problems where numeric constants are involved
46    //    - assumptions:
47    //      - any numeric constant is possible (a-priori we might assume that small abs. constants are more likely)
48    //      - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
49    //      - to simplify the problem we can restrict the set of functions e.g. we assume which functions are necessary for the problem instance
    //        -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
51    // 3) efficiency and effectiveness for real-world problems
52    //    - e.g. Tower problem
53    //    - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
54    //   
55
56    // TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general!
57    //       --> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show
58    //       --> Therefore, looking at rollout statistics for arm selection is useless in the general case!
59    //       --> It is necessary to rely on other features for the arm selection.
60    //       --> TODO: Which heuristics can we apply?
61    // TODO: Solve Poly-10
62    // TODO: rename everything as this is not MCTS anymore
63    // TODO: when a path to an expression is explored first (e.g. x1 + x2)
    //       and later we find a longer form x1 + x1 + x2 where the number of variable references
65    //       exceeds the maximum in the automaton this leads to an error (see unit tests)
66    // ~~obsolete TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
67    // ~~obsolete TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
68    // ~~obsolete TODO: check if we can use a quality measure with range [-1..1] in policies
69    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
70    // TODO: check if transformation of y is correct and works (Obj 2)
71    // TODO: The algorithm is not invariant to location and scale of variables.
72    //       Include offset for variables as parameter (for Objective 2)
73    // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
74    // TODO: support e(-x) and possibly (1/-x) (Obj 1)
75    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
76    // TODO: improve memory usage
77    // TODO: analyze / improve perf of ExprHashing (canonical form for expressions)
78    // TODO: support empty test partition
79    // TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax ) ) for invertible transformations A --> unit tests
80    #region static API
81
82    public interface IState {
83      bool Done { get; }
84      ISymbolicRegressionModel BestModel { get; }
85      double BestSolutionTrainingQuality { get; }
86      double BestSolutionTestQuality { get; }
87      IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
88      int TotalRollouts { get; }
89      int EffectiveRollouts { get; }
90      int FuncEvaluations { get; }
91      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
92      // TODO other stats on LM optimizer might be interesting here
93    }
94
95    // created through factory method
96    private class State : IState {
97      private const int MaxParams = 100;
98
99      // state variables used by MCTS
100      internal readonly Automaton automaton;
101      internal IRandom random { get; private set; }
102      internal readonly Tree tree;
103      internal readonly Func<byte[], int, double> evalFun;
104      // MCTS might get stuck. Track statistics on the number of effective rollouts
105      internal int totalRollouts;
106      internal int effectiveRollouts;
107
108
109      // state variables used only internally (for eval function)
110      private readonly IRegressionProblemData problemData;
111      private readonly double[][] x;
112      private readonly double[] y;
113      private readonly double[][] testX;
114      private readonly double[] testY;
115      private readonly double[] scalingFactor;
116      private readonly double[] scalingOffset;
117      private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
118      private readonly int constOptIterations;
119      private readonly double lambda; // weight of penalty term for regularization
120      private readonly double lowerEstimationLimit, upperEstimationLimit;
121      private readonly bool collectParetoOptimalModels;
122      private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
123      private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models
124
125      private readonly ExpressionEvaluator evaluator, testEvaluator;
126
127      internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
128      internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
129      internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
130
131      // values for best solution
132      private double bestR;
133      private byte[] bestCode;
134      private int bestNParams;
135      private double[] bestConsts;
136
137      // stats
138      private int funcEvaluations;
139      private int gradEvaluations;
140
141      // buffers
142      private readonly double[] ones; // vector of ones (as default params)
143      private readonly double[] constsBuf;
144      private readonly double[] predBuf, testPredBuf;
145      private readonly double[][] gradBuf;
146
147      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
148        int constOptIterations, double lambda,
149        bool collectParetoOptimalModels = false,
150        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
151        bool allowProdOfVars = true,
152        bool allowExp = true,
153        bool allowLog = true,
154        bool allowInv = true,
155        bool allowMultipleTerms = false) {
156
157        if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");
158
159        this.problemData = problemData;
160        this.constOptIterations = constOptIterations;
161        this.lambda = lambda;
162        this.evalFun = this.Eval;
163        this.lowerEstimationLimit = lowerEstimationLimit;
164        this.upperEstimationLimit = upperEstimationLimit;
165        this.collectParetoOptimalModels = collectParetoOptimalModels;
166
167        random = new MersenneTwister(randSeed);
168
169        // prepare data for evaluation
170        double[][] x;
171        double[] y;
172        double[][] testX;
173        double[] testY;
174        double[] scalingFactor;
175        double[] scalingOffset;
176        // get training and test datasets (scale linearly based on training set if required)
177        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
178        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
179        this.x = x;
180        this.y = y;
181        this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
182        this.testX = testX;
183        this.testY = testY;
184        this.scalingFactor = scalingFactor;
185        this.scalingOffset = scalingOffset;
186        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
187        // we need a separate evaluator because the vector length for the test dataset might differ
188        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
189
190        this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables);
191        this.tree = new Tree() {
192          state = automaton.CurrentState,
193          expr = "",
194          level = 0
195        };
196
197        // reset best solution
198        this.bestR = 0;
199        // code for default solution (constant model)
200        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
201        this.bestNParams = 0;
202        this.bestConsts = null;
203
204        // init buffers
205        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
206        constsBuf = new double[MaxParams];
207        this.predBuf = new double[y.Length];
208        this.testPredBuf = new double[testY.Length];
209
210        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
211      }
212
213      #region IState inferface
214      public bool Done { get { return tree != null && tree.Done; } }
215
216      public double BestSolutionTrainingQuality {
217        get {
218          evaluator.Exec(bestCode, x, bestConsts, predBuf);
219          return Rho(y, predBuf);
220        }
221      }
222
223      public double BestSolutionTestQuality {
224        get {
225          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
226          return Rho(testY, testPredBuf);
227        }
228      }
229
230      // takes the code of the best solution and creates and equivalent symbolic regression model
231      public ISymbolicRegressionModel BestModel {
232        get {
233          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
234          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
235
236          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
237          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
238          model.Scale(problemData); // apply linear scaling
239          return model;
240        }
241      }
242      public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
243        get { return paretoBestModels; }
244      }
245
246      public int TotalRollouts { get { return totalRollouts; } }
247      public int EffectiveRollouts { get { return effectiveRollouts; } }
248      public int FuncEvaluations { get { return funcEvaluations; } }
249      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
250
251      #endregion
252
253
254#if DEBUG
255      public string ExprStr(Automaton automaton) {
256        byte[] code;
257        int nParams;
258        automaton.GetCode(out code, out nParams);
259        var generator = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
260        var @params = Enumerable.Repeat(1.0, nParams).ToArray();
261        var root = generator.Exec(code, @params, nParams, null, null);
262        var formatter = new InfixExpressionFormatter();
263        return formatter.Format(new SymbolicExpressionTree(root));
264      }
265#endif
266
267      private double Eval(byte[] code, int nParams) {
268        double[] optConsts;
269        double q;
270        Eval(code, nParams, out q, out optConsts);
271
272        // single objective best
273        if (q > bestR) {
274          bestR = q;
275          bestNParams = nParams;
276          this.bestCode = new byte[code.Length];
277          this.bestConsts = new double[bestNParams];
278
279          Array.Copy(code, bestCode, code.Length);
280          Array.Copy(optConsts, bestConsts, bestNParams);
281        }
282        if (collectParetoOptimalModels) {
283          // multi-objective best
284          var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
285            Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit);  // use length of expression as surrogate for complexity
286          UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
287        }
288        return q;
289      }
290
291      private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
292        // we make a first pass to determine a valid starting configuration for all constants
293        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
294        // scale and offset are set to optimal starting configuration
295        // assumes scale is the first param and offset is the last param
296
297        // reset constants
298        Array.Copy(ones, constsBuf, nParams);
299        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
300        funcEvaluations++;
301
302        if (nParams == 0 || constOptIterations < 0) {
303          // if we don't need to optimize parameters then we are done
304          // changing scale and offset does not influence r²
305          rho = Rho(y, predBuf);
306          optConsts = constsBuf;
307        } else {
308          // optimize constants using the starting point calculated above
309          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
310
311          evaluator.Exec(code, x, constsBuf, predBuf);
312          funcEvaluations++;
313
314          rho = Rho(y, predBuf);
315          optConsts = constsBuf;
316        }
317      }
318
319
320
321      #region helpers
322      private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
323        OnlineCalculatorError error;
324        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
325        return error == OnlineCalculatorError.None ? r : 0.0;
326      }
327
328
329      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
330        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
331        Array.Copy(consts, optConsts, nParams);
332
333        // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
334        alglib.minlmstate state;
335        alglib.minlmreport rep = null;
336        alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
337        // Using the change of the gradient as stopping criterion is recommended in alglib manual.
338        // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
339        alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
340        // alglib.minlmsetgradientcheck(state, 1E-5);
341        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
342        alglib.minlmresults(state, out optConsts, out rep);
343        funcEvaluations += rep.nfunc;
344        gradEvaluations += rep.njac * nParams;
345
346        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);
347
348        // only use optimized constants if successful
349        if (rep.terminationtype >= 0) {
350          Array.Copy(optConsts, consts, optConsts.Length);
351        }
352      }
353
354      private void Func(double[] arg, double[] fi, object obj) {
355        var code = (byte[])obj;
356        int n = predBuf.Length;
357        evaluator.Exec(code, x, arg, predBuf); // gradients are nParams x vLen
358        for (int r = 0; r < n; r++) {
359          var res = predBuf[r] - y[r];
360          fi[r] = res;
361        }
362
363        var penaltyIdx = fi.Length - 1;
364        fi[penaltyIdx] = 0.0;
365        // calc length of parameter vector for regularization
366        var aa = 0.0;
367        for (int i = 0; i < arg.Length; i++) {
368          aa += arg[i] * arg[i];
369        }
370        if (lambda > 0 && aa > 0) {
371          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
372          // scale lambda using n to make parameter independent of the number of training points
373          // take the root because LM squares the result
374          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
375        }
376      }
377
378      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
379        int n = predBuf.Length;
380        int nParams = arg.Length;
381        var code = (byte[])obj;
382        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
383        for (int r = 0; r < n; r++) {
384          var res = predBuf[r] - y[r];
385          fi[r] = res;
386
387          for (int k = 0; k < nParams; k++) {
388            jac[r, k] = gradBuf[k][r];
389          }
390        }
391        // calc length of parameter vector for regularization
392        double aa = 0.0;
393        for (int i = 0; i < arg.Length; i++) {
394          aa += arg[i] * arg[i];
395        }
396
397        var penaltyIdx = fi.Length - 1;
398        if (lambda > 0 && aa > 0) {
399          fi[penaltyIdx] = 0.0;
400          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
401          // scale lambda using n to make parameter independent of the number of training points
402          // take the root because alglib LM squares the result
403          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
404
405          for (int i = 0; i < arg.Length; i++) {
406            jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
407          }
408        } else {
409          fi[penaltyIdx] = 0.0;
410          for (int i = 0; i < arg.Length; i++) {
411            jac[penaltyIdx, i] = 0.0;
412          }
413        }
414      }
415
416
417      private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
418        double[] scalingFactor, double[] scalingOffset) {
419        double[] best = new double[2];
420        double[] cur = new double[2] { q, complexity };
421        bool[] max = new[] { true, false };
422        var isNonDominated = true;
423        foreach (var e in paretoFront) {
424          var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
425          if (domRes == DominationResult.IsDominated) {
426            isNonDominated = false;
427            break;
428          }
429        }
430        if (isNonDominated) {
431          paretoFront.Add(cur);
432
433          // create model
434          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
435          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
436
437          var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
438          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
439          model.Scale(problemData); // apply linear scaling
440
441          var sol = model.CreateRegressionSolution(this.problemData);
442          sol.Name = string.Format("{0:N5} {1}", q, complexity);
443
444          paretoBestModels.Add(sol);
445        }
446        for (int i = paretoFront.Count - 2; i >= 0; i--) {
447          var @ref = paretoFront[i];
448          var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
449          if (domRes == DominationResult.Dominates) {
450            paretoFront.RemoveAt(i);
451            paretoBestModels.RemoveAt(i);
452          }
453        }
454      }
455
456      #endregion
457
458
459    }
460
461
462    /// <summary>
463    /// Static method to initialize a state for the algorithm
464    /// </summary>
465    /// <param name="problemData">The problem data</param>
466    /// <param name="randSeed">Random seed.</param>
467    /// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
468    /// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended)</param>
469    /// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt)</param>
470    /// <param name="lambda">Penalty factor for regularization (0..inf.), small penalty disabled regularization.</param>
471    /// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...)</param>
472    /// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
473    /// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
474    /// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
475    /// <param name="allowProdOfVars">Allow products of expressions.</param>
476    /// <param name="allowExp">Allow expressions with exponentials.</param>
477    /// <param name="allowLog">Allow expressions with logarithms</param>
478    /// <param name="allowInv">Allow expressions with 1/x</param>
479    /// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
480    /// <returns></returns>
481
482    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
483      bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
484      bool collectParameterOptimalModels = false,
485      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
486      bool allowProdOfVars = true,
487      bool allowExp = true,
488      bool allowLog = true,
489      bool allowInv = true,
490      bool allowMultipleTerms = false
491      ) {
492      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
493        collectParameterOptimalModels,
494        lowerEstimationLimit, upperEstimationLimit,
495        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
496    }
497
498    // returns the quality of the evaluated solution
499    public static double MakeStep(IState state) {
500      var mctsState = state as State;
501      if (mctsState == null) throw new ArgumentException("state");
502      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");
503
504      return TreeSearch(mctsState);
505    }
506    #endregion
507
508    private static double TreeSearch(State mctsState) {
509      var automaton = mctsState.automaton;
510      var tree = mctsState.tree;
511      var eval = mctsState.evalFun;
512      var rand = mctsState.random;
513      double q = 0;
514      bool success = false;
515      do {
516
517        automaton.Reset();
518        success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q);
519        mctsState.totalRollouts++;
520      } while (!success && !tree.Done);
521      if (success) {
522        mctsState.effectiveRollouts++;
523
524#if DEBUG
525        Console.WriteLine(mctsState.ExprStr(automaton));
526#endif
527
528        return q;
529      } else return 0.0;
530    }
531
    // Search forward: performs one rollout through the search graph, starting at `tree` with `automaton`
    // positioned in the matching state (TreeSearch calls automaton.Reset() and passes the root node).
    // Despite its name the method is not recursive; the rollout is a loop.
    // Returns true if a final automaton state was reached and the expression was evaluated; `q` then holds the
    // value returned by `eval`. Returns false and sets q = 0.0 when the rollout got stuck (dead end, or all
    // children of a node are done).
    // Side effects: adds nodes/edges to state.children, state.parents and state.nodes, sets Done flags, and
    // updates statistics through BackpropagateStatistics, BackpropagateQuality, BackpropagateDone and
    // BackpropagateDebugStats.
    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton,
      Func<byte[], int, double> eval,
      State state,
      out double q) {
      // ROLLOUT AND EXPANSION
      // We are navigating a graph (states might be reached via different paths) instead of a tree.
      // State equivalence is checked through ExprHash (based on the generated code through the path).

      // We switch between rollout-mode and expansion mode
      // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
      // In expansion mode we might re-enter the graph and switch back to rollout-mode
      // We do this until we reach a complete expression (final state)

      // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
      // Sub-graphs which have been completely searched are marked as done.
      // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

      while (!automaton.IsFinalState(automaton.CurrentState)) {
        // Console.WriteLine(automaton.stateNames[automaton.CurrentState]);
        if (state.children.ContainsKey(tree)) {
          // node already expanded: if every child is done, this node is done as well (ineffective rollout)
          if (state.children[tree].All(ch => ch.Done)) {
            tree.Done = true;
            break;
          }
          // ROLLOUT INSIDE TREE
          // select a child within the existing graph; SelectInternal prefers the first open, unvisited child
          // and otherwise draws a random open child (a single child is taken directly)
          int selectedIdx = 0;
          if (state.children[tree].Count > 1) {
            selectedIdx = SelectInternal(state.children[tree], rand);
          }

          tree = state.children[tree][selectedIdx];

          // the follow states are only used by the Debug.Assert below, which checks that the selected child
          // is a legal transition from the current automaton state
          // NOTE(review): Contains scans the whole buffer, not only the first nFs entries
          // TODO: simplification of the automaton
          int[] possibleFollowStates = new int[1000];
          int nFs;
          automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);
          Debug.Assert(possibleFollowStates.Contains(tree.state));
          automaton.Goto(tree.state);
        } else {
          // EXPAND
          // create child nodes for all follow states of the current automaton state
          int[] possibleFollowStates = new int[1000];
          int nFs;
          string actionString = ""; // never changed below; used as (empty) prefix of the child labels
          automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);

          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }
          var newChildren = new List<Tree>(nFs);
          state.children.Add(tree, newChildren);
          for (int i = 0; i < nFs; i++) {
            Tree child = null;
            // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
            if (automaton.IsEvalState(possibleFollowStates[i])) {
              // NOTE(review): hc is derived from the automaton's current code and tree.state only, not from
              // possibleFollowStates[i]; if several follow states of this node are eval states they map to the
              // same node. Confirm whether that can happen.
              var hc = Hashcode(automaton);
              hc = ((hc << 5) + hc) ^ (ulong)tree.state; // TODO fix unit test for structure enumeration
              if (!state.nodes.TryGetValue(hc, out child)) {
                // Console.WriteLine("New expression (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
                child = new Tree() {
                  state = possibleFollowStates[i],
                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                  level = tree.level + 1
                };
                state.nodes.Add(hc, child);
              }
              // only allow forward edges (don't add the child if we would go back in the graph)
              else if (child.level > tree.level) {
                // Console.WriteLine("Existing expression (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
                // to all parents
                BackpropagateStatistics(tree, state, child.visits);
              } else {
                // Console.WriteLine("Cycle (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
                // prevent cycles
                Debug.Assert(child.level <= tree.level);
                child = null;
              }
            } else {
              // non-eval states are never unified: always create a fresh node
              child = new Tree() {
                state = possibleFollowStates[i],
                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                level = tree.level + 1
              };
            }
            if (child != null)
              newChildren.Add(child);
          }

          if (!newChildren.Any()) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }

          // register the reverse edges (child -> parent) used for backpropagation
          foreach (var ch in newChildren) {
            if (!state.parents.ContainsKey(ch)) {
              state.parents.Add(ch, new List<Tree>());
            }
            state.parents[ch].Add(tree);
          }


          // follow one of the children
          tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
          automaton.Goto(tree.state);
        }
      }

      bool success;

      // EVALUATE TREE
      // evaluate only if the loop stopped in a final automaton state and the node is not done yet
      // (a node that was evaluated before, or was marked done by one of the breaks above, is skipped)
      if (!tree.Done && automaton.IsFinalState(automaton.CurrentState)) {
        tree.Done = true; // a final node is evaluated only once
        tree.expr = state.ExprStr(automaton); // readable form of the evaluated expression
        byte[] code; int nParams;
        automaton.GetCode(out code, out nParams);
        q = eval(code, nParams);
        success = true;
        BackpropagateQuality(tree, q, state);
      } else {
        // we got stuck in roll-out (no evaluation necessary)
        q = 0.0;
        success = false;
      }

      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
      // Update statistics
      // Set branch to done if all children are done.
      BackpropagateDone(tree, state);
      BackpropagateDebugStats(tree, q, state);


      return success;
    }
672
// Picks the index of the child to descend into during tree search.
// Children that are Done are never selected. The first child (in list order) that is not Done and
// has not been visited yet is preferred; otherwise a random not-Done child is chosen by rejection sampling.
// Throws InvalidOperationException when no child is selectable (empty list or all children Done).
// An explicit exception is used instead of Debug.Assert: Debug.Assert is removed from release builds,
// where the rejection loop below would otherwise never terminate.
private static int SelectInternal(List<Tree> list, IRandom rand) {
  bool anySelectable = false;

  // check if there is any node which has not been visited (and remember whether any node is selectable at all)
  for (int i = 0; i < list.Count; i++) {
    if (list[i].Done) {
      continue;
    }
    if (list[i].visits == 0) {
      return i;
    }
    anySelectable = true;
  }

  if (!anySelectable) {
    throw new InvalidOperationException("Cannot select a child: all children are done.");
  }

  // choose a random node that is not done (terminates because at least one such node exists)
  var idx = rand.Next(list.Count);
  while (list[idx].Done) { idx = rand.Next(list.Count); }
  return idx;
}
686
// Backpropagates existing statistics to all parents: used when a newly created link joins an existing
// node, so that node's visit count must also be credited to every ancestor of the joining node.
// Adds numVisits to 'tree' and to every ancestor reachable via state.parents; an ancestor reachable
// along k distinct parent paths receives the increment k times.
private static void BackpropagateStatistics(Tree tree, State state, int numVisits) {
  var open = new Stack<Tree>();
  open.Push(tree);
  while (open.Count > 0) {
    var node = open.Pop();
    node.visits += numVisits;
    if (state.parents.ContainsKey(node)) {
      foreach (var parent in state.parents[node]) {
        open.Push(parent);
      }
    }
  }
}
697
// Returns the ExprHashSymbolic hash of the code the automaton currently produces
// (as obtained from automaton.GetCode), converted to ulong.
private static ulong Hashcode(Automaton automaton) {
  byte[] program;
  int paramCount;
  automaton.GetCode(out program, out paramCount);
  var hash = ExprHashSymbolic.GetHash(program, paramCount);
  return (ulong)hash;
}
704
// Records one additional visit for 'tree' and for every ancestor reachable via state.parents.
// An ancestor reachable along k distinct parent paths is counted k times.
// TODO: q is ignored for now
private static void BackpropagateQuality(Tree tree, double q, State state) {
  var pending = new Stack<Tree>();
  pending.Push(tree);
  while (pending.Count > 0) {
    var current = pending.Pop();
    current.visits++;
    if (!state.parents.ContainsKey(current)) {
      continue;
    }
    foreach (var parent in state.parents[current]) {
      pending.Push(parent);
    }
  }
}
715
// Marks 'tree' as Done when it has an entry in state.children and every one of those children is Done
// (an empty child list also counts as "all done"). Child nodes are kept in state.children.
// The same check is then applied recursively, depth-first, to every parent listed in state.parents.
private static void BackpropagateDone(Tree tree, State state) {
  if (state.children.ContainsKey(tree)) {
    bool everyChildDone = state.children[tree].All(child => child.Done);
    if (everyChildDone) {
      tree.Done = true;
    }
  }

  if (!state.parents.ContainsKey(tree)) {
    return;
  }
  foreach (var parentNode in state.parents[tree]) {
    BackpropagateDone(parentNode, state);
  }
}
728
// Apparently a hook for collecting debug statistics: walks from 'tree' to every ancestor reachable via
// state.parents (once per distinct path) but currently records nothing; q is unused.
private static void BackpropagateDebugStats(Tree tree, double q, State state) {
  if (!state.parents.ContainsKey(tree)) {
    return;
  }
  foreach (var ancestor in state.parents[tree]) {
    BackpropagateDebugStats(ancestor, q, state);
  }
}
737
// Roll-out policy: returns the child of 'tree' (from state.children[tree], which must be non-empty)
// with the smallest automaton state value; on ties the first such child in list order wins.
// The heuristic assumes smaller state values are closer to a final state.
// 'automaton' and 'rand' are currently unused.
// Throws InvalidOperationException (from First) when 'tree' has no children.
private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
  var children = state.children[tree];
  Tree minChild = children.First();
  foreach (var child in children) {
    // keep the running minimum up to date so we really return the smallest state
    if (child.state < minChild.state) {
      minChild = child;
    }
  }
  return minChild;
}
749
// scales data and extracts values from dataset into arrays
//
// When scaleVariables is true, computes per-variable scaling parameters over the given rows:
//  - input i (in AllowedInputVariables order) is min-max scaled: scaledX = (x - min) / range.
//    A constant input (range == 0) would divide by zero and turn the whole column into NaN,
//    so such inputs are only shifted: scaledX = x - min (= 0).
//  - the target is shifted to zero mean (factor 1.0); its parameters are stored at the last index.
// When scaleVariables is false, scalingFactor and scalingOffset are null and the values are copied unscaled.
// The (possibly scaled) inputs are returned in xs, the target values in y.
private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
  out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
  scalingFactor = null;
  scalingOffset = null;

  if (scaleVariables) {
    var inputVariables = problemData.AllowedInputVariables.ToArray();
    scalingFactor = new double[inputVariables.Length + 1];
    scalingOffset = new double[inputVariables.Length + 1];

    for (int i = 0; i < inputVariables.Length; i++) {
      // materialize once instead of enumerating the column twice for min and max
      var values = problemData.Dataset.GetDoubleValues(inputVariables[i], rows).ToArray();
      var minX = values.Min();
      var range = values.Max() - minX;
      if (range > 0.0) {
        scalingFactor[i] = 1.0 / range;
        scalingOffset[i] = -minX / range;
      } else {
        // constant (or NaN) column: avoid division by zero, just shift
        scalingFactor[i] = 1.0;
        scalingOffset[i] = -minX;
      }
    }

    // transform target variable to zero-mean
    int targetIdx = inputVariables.Length;
    scalingFactor[targetIdx] = 1.0;
    scalingOffset[targetIdx] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
  }

  GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
}
786
// extract values from dataset into arrays
//
// Copies the values of all allowed input variables (xs[i] holds the i-th variable in
// AllowedInputVariables order) and of the target variable (y) for the given rows, mapping each
// value v to v * factor + offset. Index i of the scaling arrays belongs to the i-th input variable;
// the entry directly after the last input belongs to the target. A null scalingFactor means factor 1.0,
// a null scalingOffset means offset 0.0 (each array is checked independently).
private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
  out double[][] xs, out double[] y) {
  var inputVariables = problemData.AllowedInputVariables.ToArray();
  xs = new double[inputVariables.Length][];

  for (int i = 0; i < inputVariables.Length; i++) {
    var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
    var offset = scalingOffset == null ? 0.0 : scalingOffset[i];
    xs[i] = problemData.Dataset.GetDoubleValues(inputVariables[i], rows).Select(xi => xi * sf + offset).ToArray();
  }

  int targetIdx = inputVariables.Length;
  var targetFactor = scalingFactor == null ? 1.0 : scalingFactor[targetIdx];
  var targetOffset = scalingOffset == null ? 0.0 : scalingOffset[targetIdx];
  y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows)
    .Select(yi => yi * targetFactor + targetOffset)
    .ToArray();
}
806
// for debugging only

// Renders the search tree rooted at 'tree' as a Graphviz DOT digraph by delegating to TraceTreeRec,
// which follows only the most-visited child at each level. Node labels are visit counts,
// edge labels are expression strings.
private static string TraceTree(Tree tree, State state) {
  const string header =
@"digraph {
  ratio = fill;
  node [style=filled];
";
  var builder = new StringBuilder(header);
  int lastUsedId = 0;
  TraceTreeRec(tree, 0, builder, ref lastUsedId, state);
  builder.Append("}");
  return builder.ToString();
}
823
// Debug helper for TraceTree. Appends a DOT node for 'tree' (id parentId, labelled with its visit
// count), then a node and an edge for each of its children. For each child that has exactly one
// child itself, that grandchild is appended as well. Finally, it recurses into the child with the
// most visits (the first one on ties). nextId holds the last id handed out and is advanced for
// every node created.
private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
  sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tree.visits).AppendLine();
  if (!state.children.ContainsKey(tree)) {
    return;
  }

  // (visits, node id, child) for every emitted child
  var emitted = new List<Tuple<int, int, Tree>>();
  foreach (var child in state.children[tree]) {
    nextId++;
    int childId = nextId;
    int childVisits = child.visits;
    sb.AppendFormat("{0} [label=\"{1}\"]; ", childId, childVisits).AppendLine();
    sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, childId, child.expr).AppendLine();
    emitted.Add(Tuple.Create(childVisits, childId, child));
  }

  // expand single-child chains by one more level
  foreach (var entry in emitted) {
    Tree child = entry.Item3;
    if (!state.children.ContainsKey(child) || state.children[child].Count != 1) {
      continue;
    }
    Tree grandChild = state.children[child].First();
    nextId++;
    sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, grandChild.visits).AppendLine();
    sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", entry.Item2, nextId, grandChild.expr).AppendLine();
  }

  // descend only into the most-visited child (first one wins on ties)
  Tuple<int, int, Tree> mostVisited = null;
  foreach (var entry in emitted) {
    if (mostVisited == null || entry.Item1 > mostVisited.Item1) {
      mostVisited = entry;
    }
  }
  if (mostVisited != null) {
    TraceTreeRec(mostVisited.Item3, mostVisited.Item2, sb, ref nextId, state);
  }
}
856
// Debug helper: renders the whole search graph stored in state.children as a Graphviz DOT digraph.
// Each node receives a sequential id (starting at 1) when first encountered. Nodes are labelled with
// their visit counts and edges with the child's expression string. A child with fewer than one visit
// is skipped (no label, no edge), although it still consumes an id. Numbers are formatted with the
// invariant culture. The 'tree' argument is not used.
private static string WriteTree(Tree tree, State state) {
  const string header =
@"digraph {
  ratio = fill;
  node [style=filled];
";
  const int threshold = 0; // nodes/edges are written only when visits exceed this value
  var ids = new Dictionary<Tree, int>();

  using (var writer = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture)) {
    writer.Write(header);
    foreach (var entry in state.children) {
      var parent = entry.Key;
      int parentId;
      if (!ids.TryGetValue(parent, out parentId)) {
        parentId = ids.Count + 1;
        ids.Add(parent, parentId);
        if (parent.visits > threshold) {
          writer.Write("{0} [label=\"{1}\"]; ", parentId, parent.visits);
        }
      }

      foreach (var child in entry.Value) {
        int childId;
        if (!ids.TryGetValue(child, out childId)) {
          childId = ids.Count + 1;
          ids.Add(child, childId);
        }
        var visits = child.visits;
        if (visits < 1 || visits <= threshold) {
          continue;
        }
        writer.Write("{0} [label=\"{1}\"]; ", childId, visits);
        writer.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, child.expr);
      }
    }
    writer.Write("}");
    return writer.ToString();
  }
}
896  }
897}
Note: See TracBrowser for help on using the repository browser.