Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 15416

Last change on this file since 15416 was 15416, checked in by gkronber, 5 years ago

#2796 worked on MCTS for symbreg

File size: 43.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Diagnostics.Contracts;
26using System.Linq;
27using System.Text;
28using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
29using HeuristicLab.Core;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31using HeuristicLab.Optimization;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
38  public static class MctsSymbolicRegressionStatic {
39    // OBJECTIVES:
40    // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
41    //    - e.g. Keijzer, Nguyen ... where no numeric constants are involved
42    //    - assumptions:
43    //      - we don't know the necessary operations or functions -> all available functions could be necessary
44    //      - but we do not need to tune numeric constants -> no scaling of input variables x!
45    // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
46    //    This is important for real world applications.
47    //    - e.g. Korns or Vladislavleva problems where numeric constants are involved
48    //    - assumptions:
49    //      - any numeric constant is possible (a-priori we might assume that small abs. constants are more likely)
50    //      - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
51    //      - to simplify the problem we can restrict the set of functions e.g. we assume which functions are necessary for the problem instance
    //        -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
53    // 3) efficiency and effectiveness for real-world problems
54    //    - e.g. Tower problem
55    //    - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
56    //   
57
58    // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should
59    //       weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
60    //       branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
61    // TODO: Solve Poly-10
62    // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
63    // TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
64    // TODO: check if we can use a quality measure with range [-1..1] in policies
65    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
66    // TODO: check if transformation of y is correct and works (Obj 2)
67    // TODO: The algorithm is not invariant to location and scale of variables.
68    //       Include offset for variables as parameter (for Objective 2)
69    // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
70    // TODO: support e(-x) and possibly (1/-x) (Obj 1)
71    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
72    // TODO: improve memory usage
73    #region static API
74
    /// <summary>
    /// Read-only view of the algorithm state exposed to callers of <c>CreateState</c>/<c>MakeStep</c>.
    /// </summary>
    public interface IState {
      /// <summary>True when the search space has been fully enumerated (no further steps possible).</summary>
      bool Done { get; }
      /// <summary>Symbolic regression model for the best solution found so far (linearly scaled to the target).</summary>
      ISymbolicRegressionModel BestModel { get; }
      /// <summary>Quality (Pearson's correlation of predictions and targets) of the best solution on the training partition.</summary>
      double BestSolutionTrainingQuality { get; }
      /// <summary>Quality (Pearson's correlation of predictions and targets) of the best solution on the test partition.</summary>
      double BestSolutionTestQuality { get; }
      /// <summary>Pareto-optimal solutions (quality vs. expression length); only filled when collection was enabled at creation.</summary>
      IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
      /// <summary>Total number of rollouts performed, including ineffective ones (rollouts that got stuck).</summary>
      int TotalRollouts { get; }
      /// <summary>Number of rollouts that reached a complete expression and produced an evaluation.</summary>
      int EffectiveRollouts { get; }
      /// <summary>Number of expression evaluations on the training data.</summary>
      int FuncEvaluations { get; }
      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
      // TODO other stats on LM optimizer might be interesting here
    }
87
88    // created through factory method
89    private class State : IState {
90      private const int MaxParams = 100;
91
92      // state variables used by MCTS
93      internal readonly Automaton automaton;
94      internal IRandom random { get; private set; }
95      internal readonly Tree tree;
96      internal readonly Func<byte[], int, double> evalFun;
97      internal readonly IPolicy treePolicy;
98      // MCTS might get stuck. Track statistics on the number of effective rollouts
99      internal int totalRollouts;
100      internal int effectiveRollouts;
101
102
103      // state variables used only internally (for eval function)
104      private readonly IRegressionProblemData problemData;
105      private readonly double[][] x;
106      private readonly double[] y;
107      private readonly double[][] testX;
108      private readonly double[] testY;
109      private readonly double[] scalingFactor;
110      private readonly double[] scalingOffset;
111      private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
112      private readonly int constOptIterations;
113      private readonly double lambda; // weight of penalty term for regularization
114      private readonly double lowerEstimationLimit, upperEstimationLimit;
115      private readonly bool collectParetoOptimalModels;
116      private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
117      private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models
118
119      private readonly ExpressionEvaluator evaluator, testEvaluator;
120
121      internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
122      internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
123      internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
124
125      // values for best solution
126      private double bestR;
127      private byte[] bestCode;
128      private int bestNParams;
129      private double[] bestConsts;
130
131      // stats
132      private int funcEvaluations;
133      private int gradEvaluations;
134
135      // buffers
136      private readonly double[] ones; // vector of ones (as default params)
137      private readonly double[] constsBuf;
138      private readonly double[] predBuf, testPredBuf;
139      private readonly double[][] gradBuf;
140
141      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
142        int constOptIterations, double lambda,
143        IPolicy treePolicy = null,
144        bool collectParetoOptimalModels = false,
145        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
146        bool allowProdOfVars = true,
147        bool allowExp = true,
148        bool allowLog = true,
149        bool allowInv = true,
150        bool allowMultipleTerms = false) {
151
152        if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");
153
154        this.problemData = problemData;
155        this.constOptIterations = constOptIterations;
156        this.lambda = lambda;
157        this.evalFun = this.Eval;
158        this.lowerEstimationLimit = lowerEstimationLimit;
159        this.upperEstimationLimit = upperEstimationLimit;
160        this.collectParetoOptimalModels = collectParetoOptimalModels;
161
162        random = new MersenneTwister(randSeed);
163
164        // prepare data for evaluation
165        double[][] x;
166        double[] y;
167        double[][] testX;
168        double[] testY;
169        double[] scalingFactor;
170        double[] scalingOffset;
171        // get training and test datasets (scale linearly based on training set if required)
172        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
173        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
174        this.x = x;
175        this.y = y;
176        this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
177        this.testX = testX;
178        this.testY = testY;
179        this.scalingFactor = scalingFactor;
180        this.scalingOffset = scalingOffset;
181        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
182        // we need a separate evaluator because the vector length for the test dataset might differ
183        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
184
185        this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
186        this.treePolicy = treePolicy ?? new Ucb();
187        this.tree = new Tree() {
188          state = automaton.CurrentState,
189          actionStatistics = treePolicy.CreateActionStatistics(),
190          expr = "",
191          level = 0
192        };
193
194        // reset best solution
195        this.bestR = 0;
196        // code for default solution (constant model)
197        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
198        this.bestNParams = 0;
199        this.bestConsts = null;
200
201        // init buffers
202        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
203        constsBuf = new double[MaxParams];
204        this.predBuf = new double[y.Length];
205        this.testPredBuf = new double[testY.Length];
206
207        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
208      }
209
210      #region IState inferface
211      public bool Done { get { return tree != null && tree.Done; } }
212
213      public double BestSolutionTrainingQuality {
214        get {
215          evaluator.Exec(bestCode, x, bestConsts, predBuf);
216          return Rho(y, predBuf);
217        }
218      }
219
220      public double BestSolutionTestQuality {
221        get {
222          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
223          return Rho(testY, testPredBuf);
224        }
225      }
226
227      // takes the code of the best solution and creates and equivalent symbolic regression model
228      public ISymbolicRegressionModel BestModel {
229        get {
230          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
231          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
232
233          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
234          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
235          model.Scale(problemData); // apply linear scaling
236          return model;
237        }
238      }
239      public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
240        get { return paretoBestModels; }
241      }
242
243      public int TotalRollouts { get { return totalRollouts; } }
244      public int EffectiveRollouts { get { return effectiveRollouts; } }
245      public int FuncEvaluations { get { return funcEvaluations; } }
246      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
247
248      #endregion
249
250      private double Eval(byte[] code, int nParams) {
251        double[] optConsts;
252        double q;
253        Eval(code, nParams, out q, out optConsts);
254
255        // single objective best
256        if (q > bestR) {
257          bestR = q;
258          bestNParams = nParams;
259          this.bestCode = new byte[code.Length];
260          this.bestConsts = new double[bestNParams];
261
262          Array.Copy(code, bestCode, code.Length);
263          Array.Copy(optConsts, bestConsts, bestNParams);
264        }
265        if (collectParetoOptimalModels) {
266          // multi-objective best
267          var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
268            Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit);  // use length of expression as surrogate for complexity
269          UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
270        }
271        return q;
272      }
273
274      private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
275        // we make a first pass to determine a valid starting configuration for all constants
276        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
277        // scale and offset are set to optimal starting configuration
278        // assumes scale is the first param and offset is the last param
279
280        // reset constants
281        Array.Copy(ones, constsBuf, nParams);
282        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
283        funcEvaluations++;
284
285        if (nParams == 0 || constOptIterations < 0) {
286          // if we don't need to optimize parameters then we are done
287          // changing scale and offset does not influence r²
288          rho = Rho(y, predBuf);
289          optConsts = constsBuf;
290        } else {
291          // optimize constants using the starting point calculated above
292          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
293
294          evaluator.Exec(code, x, constsBuf, predBuf);
295          funcEvaluations++;
296
297          rho = Rho(y, predBuf);
298          optConsts = constsBuf;
299        }
300      }
301
302
303
304      #region helpers
305      private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
306        OnlineCalculatorError error;
307        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
308        return error == OnlineCalculatorError.None ? r : 0.0;
309      }
310
311
312      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
313        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
314        Array.Copy(consts, optConsts, nParams);
315
316        // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
317        alglib.minlmstate state;
318        alglib.minlmreport rep = null;
319        alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
320        // Using the change of the gradient as stopping criterion is recommended in alglib manual.
321        // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
322        alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
323        // alglib.minlmsetgradientcheck(state, 1E-5);
324        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
325        alglib.minlmresults(state, out optConsts, out rep);
326        funcEvaluations += rep.nfunc;
327        gradEvaluations += rep.njac * nParams;
328
329        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);
330
331        // only use optimized constants if successful
332        if (rep.terminationtype >= 0) {
333          Array.Copy(optConsts, consts, optConsts.Length);
334        }
335      }
336
337      private void Func(double[] arg, double[] fi, object obj) {
338        var code = (byte[])obj;
339        int n = predBuf.Length;
340        evaluator.Exec(code, x, arg, predBuf); // gradients are nParams x vLen
341        for (int r = 0; r < n; r++) {
342          var res = predBuf[r] - y[r];
343          fi[r] = res;
344        }
345
346        var penaltyIdx = fi.Length - 1;
347        fi[penaltyIdx] = 0.0;
348        // calc length of parameter vector for regularization
349        var aa = 0.0;
350        for (int i = 0; i < arg.Length; i++) {
351          aa += arg[i] * arg[i];
352        }
353        if (lambda > 0 && aa > 0) {
354          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
355          // scale lambda using n to make parameter independent of the number of training points
356          // take the root because LM squares the result
357          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
358        }
359      }
360
361      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
362        int n = predBuf.Length;
363        int nParams = arg.Length;
364        var code = (byte[])obj;
365        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
366        for (int r = 0; r < n; r++) {
367          var res = predBuf[r] - y[r];
368          fi[r] = res;
369
370          for (int k = 0; k < nParams; k++) {
371            jac[r, k] = gradBuf[k][r];
372          }
373        }
374        // calc length of parameter vector for regularization
375        double aa = 0.0;
376        for (int i = 0; i < arg.Length; i++) {
377          aa += arg[i] * arg[i];
378        }
379
380        var penaltyIdx = fi.Length - 1;
381        if (lambda > 0 && aa > 0) {
382          fi[penaltyIdx] = 0.0;
383          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
384          // scale lambda using n to make parameter independent of the number of training points
385          // take the root because alglib LM squares the result
386          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
387
388          for (int i = 0; i < arg.Length; i++) {
389            jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
390          }
391        } else {
392          fi[penaltyIdx] = 0.0;
393          for (int i = 0; i < arg.Length; i++) {
394            jac[penaltyIdx, i] = 0.0;
395          }
396        }
397      }
398
399
400      private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
401        double[] scalingFactor, double[] scalingOffset) {
402        double[] best = new double[2];
403        double[] cur = new double[2] { q, complexity };
404        bool[] max = new[] { true, false };
405        var isNonDominated = true;
406        foreach (var e in paretoFront) {
407          var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
408          if (domRes == DominationResult.IsDominated) {
409            isNonDominated = false;
410            break;
411          }
412        }
413        if (isNonDominated) {
414          paretoFront.Add(cur);
415
416          // create model
417          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
418          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
419
420          var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
421          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
422          model.Scale(problemData); // apply linear scaling
423
424          var sol = model.CreateRegressionSolution(this.problemData);
425          sol.Name = string.Format("{0:N5} {1}", q, complexity);
426
427          paretoBestModels.Add(sol);
428        }
429        for (int i = paretoFront.Count - 2; i >= 0; i--) {
430          var @ref = paretoFront[i];
431          var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
432          if (domRes == DominationResult.Dominates) {
433            paretoFront.RemoveAt(i);
434            paretoBestModels.RemoveAt(i);
435          }
436        }
437      }
438      #endregion
439    }
440
441
    // NOTE(review): the parameter is named collectParameterOptimalModels but it actually controls
    // collection of *Pareto*-optimal models (it is passed to State's collectParetoOptimalModels);
    // renaming would break callers that use named arguments, so only the docs are corrected here.
    /// <summary>
    /// Static method to initialize a state for the algorithm
    /// </summary>
    /// <param name="problemData">The problem data</param>
    /// <param name="randSeed">Random seed.</param>
    /// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
    /// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended)</param>
    /// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt); a negative value disables constants optimization.</param>
    /// <param name="lambda">Penalty factor for regularization (0..inf.); a value of zero disables regularization.</param>
    /// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...); UCB is used when null.</param>
    /// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
    /// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
    /// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
    /// <param name="allowProdOfVars">Allow products of expressions.</param>
    /// <param name="allowExp">Allow expressions with exponentials.</param>
    /// <param name="allowLog">Allow expressions with logarithms</param>
    /// <param name="allowInv">Allow expressions with 1/x</param>
    /// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
    /// <returns>A fresh search state; pass it to MakeStep to run search iterations.</returns>

    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
      bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
      IPolicy policy = null,
      bool collectParameterOptimalModels = false,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
        policy, collectParameterOptimalModels,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }
478
479    // returns the quality of the evaluated solution
480    public static double MakeStep(IState state) {
481      var mctsState = state as State;
482      if (mctsState == null) throw new ArgumentException("state");
483      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");
484
485      return TreeSearch(mctsState);
486    }
487    #endregion
488
489    private static double TreeSearch(State mctsState) {
490      var automaton = mctsState.automaton;
491      var tree = mctsState.tree;
492      var eval = mctsState.evalFun;
493      var rand = mctsState.random;
494      var treePolicy = mctsState.treePolicy;
495      double q = 0;
496      bool success = false;
497      do {
498        automaton.Reset();
499        success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q);
500        mctsState.totalRollouts++;
501      } while (!success && !tree.Done);
502      mctsState.effectiveRollouts++;
503
504      //if (mctsState.effectiveRollouts % 100 == 1) {
505        // Console.WriteLine(WriteTree(tree, mctsState));
506        // Console.WriteLine(TraceTree(tree, mctsState));
507      //}
508      return q;
509    }
510
511
    // search forward: performs one rollout from 'tree', alternating between rollout-mode
    // (following existing graph edges via the tree policy) and expansion mode (creating new
    // nodes/edges). Returns true and the transformed quality via 'q' when a complete
    // expression was reached and evaluated; false (q = 0) when the rollout got stuck.
    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
      State state,
      out double q) {
      // ROLLOUT AND EXPANSION
      // We are navigating a graph (states might be reached via different paths) instead of a tree.
      // State equivalence is checked through ExprHash (based on the generated code through the path).

      // We switch between rollout-mode and expansion mode
      // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
      // In expansion mode we might re-enter the graph and switch back to rollout-mode
      // We do this until we reach a complete expression (final state)

      // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
      // Sub-graphs which have been completely searched are marked as done.
      // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

      while (!automaton.IsFinalState(automaton.CurrentState)) {
        if (state.children.ContainsKey(tree)) {
          // node was expanded before: if every child sub-graph is exhausted, mark this node done
          if (state.children[tree].All(ch => ch.Done)) {
            tree.Done = true;
            break;
          }
          // ROLLOUT INSIDE TREE
          // UCT selection within tree
          int selectedIdx = 0;
          if (state.children[tree].Count > 1) {
            selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);
          }
          tree = state.children[tree][selectedIdx];

          // move the automaton forward until reaching the state
          // all steps where no alternatives are possible are immediately taken
          // TODO: simplification of the automaton
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
            automaton.Goto(possibleFollowStates[0]);
            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          }
          Debug.Assert(possibleFollowStates.Contains(tree.state));
          automaton.Goto(tree.state);
        } else {
          // EXPAND
          int[] possibleFollowStates;
          int nFs;
          string actionString = "";
          // skip forward over all forced transitions, accumulating the action string for the new nodes
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
            actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]);
            // no alternatives -> just go to the next state
            automaton.Goto(possibleFollowStates[0]);
            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          }
          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }
          var newChildren = new List<Tree>(nFs);
          state.children.Add(tree, newChildren);
          for (int i = 0; i < nFs; i++) {
            Tree child = null;
            // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
            if (automaton.IsEvalState(possibleFollowStates[i])) {
              var hc = Hashcode(automaton);
              if (!state.nodes.TryGetValue(hc, out child)) {
                child = new Tree() {
                  children = null,
                  state = possibleFollowStates[i],
                  actionStatistics = treePolicy.CreateActionStatistics(),
                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                  level = tree.level + 1
                };
                state.nodes.Add(hc, child);
              }
              // only allow forward edges (don't add the child if we would go back in the graph)
              else if (child.level > tree.level)  {
                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
                // to all parents
                BackpropagateStatistics(child.actionStatistics, tree, state);
              } else {
                // prevent cycles
                Debug.Assert(child.level <= tree.level);
                child = null;
              }   
            } else {
              // non-eval states are never unified; always create a fresh node
              child = new Tree() {
                children = null,
                state = possibleFollowStates[i],
                actionStatistics = treePolicy.CreateActionStatistics(),
                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                level = tree.level + 1
              };
            }
            if (child != null)
              newChildren.Add(child);
          }

          if (!newChildren.Any()) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }

          // register the back-edges (a child can have multiple parents because of unification)
          foreach (var ch in newChildren) {
            if (!state.parents.ContainsKey(ch)) {
              state.parents.Add(ch, new List<Tree>());
            }
            state.parents[ch].Add(tree);
          }


          // follow one of the children
          tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
          automaton.Goto(tree.state);
        }
      }

      bool success;

      // EVALUATE TREE
      if (automaton.IsFinalState(automaton.CurrentState)) {
        tree.Done = true;
        tree.expr = ExprStr(automaton);
        byte[] code; int nParams;
        automaton.GetCode(out code, out nParams);
        q = eval(code, nParams);
        // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);
        q = TransformQuality(q);
        success = true;
      } else {
        // we got stuck in roll-out (not evaluation necessary!)
        // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");
        q = 0.0;
        success = false;
      }

      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
      // Update statistics
      // Set branch to done if all children are done.
      BackpropagateQuality(tree, q, treePolicy, state);

      return success;
    }
659
660
661    private static double TransformQuality(double q) {
662      // no transformation
663      // return q;
664
665      // EXPERIMENTAL!
666
667      // Fisher transformation
668      // (assumes q is Correl(pred, target)
669
670      q = Math.Min(q,  0.99999999);
671      q = Math.Max(q, -0.99999999);
672      return 0.5 * Math.Log((1 + q) / (1 - q));
673
674      // optimal result: q = 1 -> return huge value
675      // if (q >= 1.0) return 1E16;
676      // // return number of 9s in R²
677      // return -Math.Log10(1 - q);
678    }
679
680    // backpropagate existing statistics to all parents
681    private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
682      tree.actionStatistics.Add(stats);
683      if (state.parents.ContainsKey(tree)) {
684        foreach (var parent in state.parents[tree]) {
685          BackpropagateStatistics(stats, parent, state);
686        }
687      }
688    }
689
690    private static ulong Hashcode(Automaton automaton) {
691      byte[] code;
692      int nParams;
693      automaton.GetCode(out code, out nParams);
694      return ExprHash.GetHash(code, nParams);
695    }
696
697    private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
698      if (q > 0) policy.Update(tree.actionStatistics, q);
699      if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) {
700        tree.Done = true;
701        // children[tree] = null; keep all nodes
702      }
703
704      if (state.parents.ContainsKey(tree)) {
705        foreach (var parent in state.parents[tree]) {
706          BackpropagateQuality(parent, q, policy, state);
707        }
708      }
709    }
710
711    private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
712      // find the child with the smallest state value (smaller values are closer to the final state)
713      int selectedChildIdx = 0;
714      var children = state.children[tree];
715      Tree minChild = children.First();
716      for (int i = 1; i < children.Count; i++) {
717        if(children[i].state < minChild.state)
718          selectedChildIdx = i;
719      }
720      return children[selectedChildIdx];
721    }
722
    // tree search might fail because of constraints for expressions
    // in this case we get stuck we just restart
    // see ConstraintHandler.cs for more info
    //
    // One iteration of MCTS over the (strict) tree variant of the search space:
    // select a path down the tree using the policy, expand the first unexpanded
    // node, evaluate when a final automaton state is reached, and update the
    // statistics of all nodes on the path on the way back out of the recursion.
    // Returns false (with q = 0) when the roll-out got stuck in a dead end.
    // Precondition: the automaton has been advanced to tree.state by the caller.
    private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
      out double q) {
      Tree selectedChild = null;
      Contract.Assert(tree.state == automaton.CurrentState);
      Contract.Assert(!tree.Done);
      if (tree.children == null) {
        if (automaton.IsFinalState(tree.state)) {
          // final state
          tree.Done = true;

          // EVALUATE
          byte[] code; int nParams;
          automaton.GetCode(out code, out nParams);
          q = eval(code, nParams);

          treePolicy.Update(tree.actionStatistics, q);
          return true; // we reached a final state
        } else {
          // EXPAND
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            q = 0;
            tree.Done = true;
            tree.children = null;
            return false;
          }
          // create one child per allowed follow state; statistics start empty
          tree.children = new Tree[nFs];
          for (int i = 0; i < tree.children.Length; i++)
            tree.children[i] = new Tree() {
              children = null,
              state = possibleFollowStates[i],
              actionStatistics = treePolicy.CreateActionStatistics()
            };

          // after expansion prefer a child leading directly to a final state
          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
        }
      } else {
        // tree.children != null
        // UCT selection within tree
        int selectedIdx = 0;
        if (tree.children.Length > 1) {
          selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
        }
        selectedChild = tree.children[selectedIdx];
      }
      // make selected step and recurse
      automaton.Goto(selectedChild.state);
      var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
      if (success) {
        // only update if successful
        treePolicy.Update(tree.actionStatistics, q);
      }

      // tree.children is non-null here: both early exits above return before this point
      tree.Done = tree.children.All(ch => ch.Done);
      if (tree.Done) {
        tree.children = null; // cut off the sub-branch if it has been fully explored
      }
      return success;
    }
788
    // if one of the new children leads to a final state then go there
    // otherwise the FIRST child is chosen.
    // NOTE(review): despite the method name and the unused 'rand' parameter the
    // fallback is deterministic (first child, not random) — confirm whether this
    // is intentional or a leftover from an earlier randomized implementation.
    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
      int selectedChildIdx = -1;
      // find first final state if there is one
      for (int i = 0; i < tree.children.Length; i++) {
        if (automaton.IsFinalState(tree.children[i].state)) {
          selectedChildIdx = i;
          break;
        }
      }
      // no final state -> select the first child
      if (selectedChildIdx == -1) {
        selectedChildIdx = 0;
      }
      return tree.children[selectedChildIdx];
    }
806
807    // scales data and extracts values from dataset into arrays
808    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
809      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
810      xs = new double[problemData.AllowedInputVariables.Count()][];
811
812      var i = 0;
813      if (scaleVariables) {
814        scalingFactor = new double[xs.Length + 1];
815        scalingOffset = new double[xs.Length + 1];
816      } else {
817        scalingFactor = null;
818        scalingOffset = null;
819      }
820      foreach (var var in problemData.AllowedInputVariables) {
821        if (scaleVariables) {
822          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
823          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
824          var range = maxX - minX;
825
826          // scaledX = (x - min) / range
827          var sf = 1.0 / range;
828          var offset = -minX / range;
829          scalingFactor[i] = sf;
830          scalingOffset[i] = offset;
831          i++;
832        }
833      }
834
835      if (scaleVariables) {
836        // transform target variable to zero-mean
837        scalingFactor[i] = 1.0;
838        scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
839      }
840
841      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
842    }
843
844    // extract values from dataset into arrays
845    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
846     out double[][] xs, out double[] y) {
847      xs = new double[problemData.AllowedInputVariables.Count()][];
848
849      int i = 0;
850      foreach (var var in problemData.AllowedInputVariables) {
851        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
852        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
853        xs[i++] =
854          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
855      }
856
857      {
858        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
859        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
860        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * sf + offset).ToArray();
861      }
862    }
863
864    // for debugging only
865
866
867    private static string ExprStr(Automaton automaton) {
868      byte[] code;
869      int nParams;
870      automaton.GetCode(out code, out nParams);
871      return Disassembler.CodeToString(code);
872    }
873
874    private static string WriteStatistics(Tree tree, State state) {
875      var sb = new System.IO.StringWriter();
876      sb.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
877      if (state.children.ContainsKey(tree)) {
878        foreach (var ch in state.children[tree]) {
879          sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
880        }
881      }
882      return sb.ToString();
883    }
884
885    private static string TraceTree(Tree tree, State state) {
886      var sb = new StringBuilder();
887      sb.Append(
888@"digraph {
889  ratio = fill;
890  node [style=filled];
891");
892      int nodeId = 0;
893
894      TraceTreeRec(tree, 0, sb, ref nodeId, state);
895      sb.Append("}");
896      return sb.ToString();
897    }
898
    // for debugging only: recursive helper for TraceTree.
    // Emits a dot node for 'tree' (labelled with average quality and tries),
    // nodes+edges for all of its children, chains of single children one level
    // deeper, and then recurses only into the most-tried child.
    // parentId is the dot id already assigned to 'tree'; nextId is the running
    // id counter shared across the whole recursion.
    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
      var avgNodeQ = tree.actionStatistics.AverageQuality;
      var tries = tree.actionStatistics.Tries;
      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
      var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
      // NOTE(review): hue is immediately overridden — the quality-based coloring is disabled
      hue = 0.0;

      sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();

      // (tries, dotId, child) tuples collected for the second and third passes
      var list = new List<Tuple<int, int, Tree>>();
      if (state.children.ContainsKey(tree)) {
        foreach (var ch in state.children[tree]) {
          nextId++;
          avgNodeQ = ch.actionStatistics.AverageQuality;
          tries = ch.actionStatistics.Tries;
          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
          hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
          hue = 0.0;
          sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
          // format arg {2} (avgNodeQ) is intentionally unused; the edge label is the child's expression
          sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine();
          list.Add(Tuple.Create(tries, nextId, ch));
        }

        // additionally show grandchildren for children that have exactly one child
        foreach(var tup in list) {
          var ch = tup.Item3;
          var chId = tup.Item2;
          if(state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
            var chch = state.children[ch].First();
            nextId++;
            avgNodeQ = chch.actionStatistics.AverageQuality;
            tries = chch.actionStatistics.Tries;
            if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
            hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
            hue = 0.0;
            sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
            sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine();
          }
        }

        // recurse into the single child with the highest number of tries
        foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
          TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state);
        }
      }
    }
943
    // for debugging only: renders the WHOLE search graph (all parent->child
    // relations recorded in state.children) as a GraphViz dot digraph.
    // Nodes are labelled with average quality and number of tries; edges are
    // labelled with the child's expression string. The 'tree' parameter is
    // currently unused — the graph is produced from 'state' alone.
    private static string WriteTree(Tree tree, State state) {
      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
      var nodeIds = new Dictionary<Tree, int>();
      sb.Write(
@"digraph {
  ratio = fill;
  node [style=filled];
");
      // nodes/edges with at most 'threshold' tries are suppressed (currently disabled)
      int threshold = /* state.nodes.Count > 500 ? 10 : */ 0;
      foreach (var kvp in state.children) {
        var parent = kvp.Key;
        int parentId;
        if (!nodeIds.TryGetValue(parent, out parentId)) {
          // first time we see this parent: assign an id and emit its node statement
          parentId = nodeIds.Count + 1;
          var avgNodeQ = parent.actionStatistics.AverageQuality;
          var tries = parent.actionStatistics.Tries;
          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
          // NOTE(review): hue is immediately overridden — quality-based coloring is disabled
          hue = 0.0;
          if (parent.actionStatistics.Tries > threshold)
            sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
          nodeIds.Add(parent, parentId);
        }
        foreach (var child in kvp.Value) {
          int childId;
          if (!nodeIds.TryGetValue(child, out childId)) {
            childId = nodeIds.Count + 1;
            nodeIds.Add(child, childId);
          }
          var avgNodeQ = child.actionStatistics.AverageQuality;
          var tries = child.actionStatistics.Tries;
          // skip children that have never been visited
          if (tries < 1) continue;
          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
          hue = 0.0;
          if (tries > threshold) {
            sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
            var edgeLabel = child.expr;
            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
            // format arg {2} (avgNodeQ) is intentionally unused; the edge label is the expression
            sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
          }
        }
      }

      sb.Write("}");
      return sb.ToString();
    }
991  }
992}
Note: See TracBrowser for help on using the repository browser.