Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 15420

Last change on this file since 15420 was 15420, checked in by gkronber, 7 years ago

#2796: debugging

File size: 44.5 KB
RevLine 
[13645]1#region License Information
2/* HeuristicLab
[14185]3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[13645]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
[15410]24using System.Diagnostics;
[13645]25using System.Diagnostics.Contracts;
26using System.Linq;
[15414]27using System.Text;
[13658]28using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
[13645]29using HeuristicLab.Core;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[15360]31using HeuristicLab.Optimization;
[13645]32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
38  public static class MctsSymbolicRegressionStatic {
[15403]39    // OBJECTIVES:
40    // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
41    //    - e.g. Keijzer, Nguyen ... where no numeric constants are involved
42    //    - assumptions:
43    //      - we don't know the necessary operations or functions -> all available functions could be necessary
44    //      - but we do not need to tune numeric constants -> no scaling of input variables x!
45    // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
46    //    This is important for real world applications.
47    //    - e.g. Korns or Vladislavleva problems where numeric constants are involved
48    //    - assumptions:
49    //      - any numeric constant is possible (a-priori we might assume that small abs. constants are more likely)
50    //      - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
51    //      - to simplify the problem we can restrict the set of functions e.g. we assume which functions are necessary for the problem instance
52    //        -> several steps: (a) polyinomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
53    // 3) efficiency and effectiveness for real-world problems
54    //    - e.g. Tower problem
55    //    - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
56    //   
57
[15410]58    // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should
59    //       weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
60    //       branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
[15414]61    // TODO: Solve Poly-10
[15410]62    // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
[15416]63    // TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
64    // TODO: check if we can use a quality measure with range [-1..1] in policies
[15414]65    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
[15404]66    // TODO: check if transformation of y is correct and works (Obj 2)
67    // TODO: The algorithm is not invariant to location and scale of variables.
68    //       Include offset for variables as parameter (for Objective 2)
69    // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
70    // TODO: support e(-x) and possibly (1/-x) (Obj 1)
71    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
[15414]72    // TODO: improve memory usage
[15420]73    // TODO: support empty test partition
[13645]74    #region static API
75
76    public interface IState {
77      bool Done { get; }
78      ISymbolicRegressionModel BestModel { get; }
79      double BestSolutionTrainingQuality { get; }
80      double BestSolutionTestQuality { get; }
[15360]81      IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
[13651]82      int TotalRollouts { get; }
83      int EffectiveRollouts { get; }
84      int FuncEvaluations { get; }
85      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
86      // TODO other stats on LM optimizer might be interesting here
[13645]87    }
88
89    // created through factory method
90    private class State : IState {
91      private const int MaxParams = 100;
92
93      // state variables used by MCTS
94      internal readonly Automaton automaton;
95      internal IRandom random { get; private set; }
96      internal readonly Tree tree;
97      internal readonly Func<byte[], int, double> evalFun;
[13658]98      internal readonly IPolicy treePolicy;
[13651]99      // MCTS might get stuck. Track statistics on the number of effective rollouts
100      internal int totalRollouts;
101      internal int effectiveRollouts;
[13645]102
103
104      // state variables used only internally (for eval function)
105      private readonly IRegressionProblemData problemData;
106      private readonly double[][] x;
107      private readonly double[] y;
108      private readonly double[][] testX;
109      private readonly double[] testY;
110      private readonly double[] scalingFactor;
111      private readonly double[] scalingOffset;
[15403]112      private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
[13645]113      private readonly int constOptIterations;
[15403]114      private readonly double lambda; // weight of penalty term for regularization
[13645]115      private readonly double lowerEstimationLimit, upperEstimationLimit;
[15403]116      private readonly bool collectParetoOptimalModels;
[15360]117      private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
118      private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models
[13645]119
120      private readonly ExpressionEvaluator evaluator, testEvaluator;
121
[15416]122      internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
123      internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
124      internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
125
[13645]126      // values for best solution
[15416]127      private double bestR;
[13645]128      private byte[] bestCode;
129      private int bestNParams;
130      private double[] bestConsts;
131
[13651]132      // stats
133      private int funcEvaluations;
134      private int gradEvaluations;
135
[13645]136      // buffers
137      private readonly double[] ones; // vector of ones (as default params)
138      private readonly double[] constsBuf;
139      private readonly double[] predBuf, testPredBuf;
140      private readonly double[][] gradBuf;
141
[15420]142      // debugging stats
143      // calculate for each level the number of alternatives the average 'inequality' of tries and 'inequality' of quality over the alternatives for each trie
144      // inequality can be calculated using the Gini coefficient
145      internal readonly double[] giniCoeffs = new double[100];
146     
147
[15403]148      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
149        int constOptIterations, double lambda,
[13658]150        IPolicy treePolicy = null,
[15403]151        bool collectParetoOptimalModels = false,
[13645]152        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
153        bool allowProdOfVars = true,
154        bool allowExp = true,
155        bool allowLog = true,
156        bool allowInv = true,
157        bool allowMultipleTerms = false) {
158
[15403]159        if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");
160
[13645]161        this.problemData = problemData;
162        this.constOptIterations = constOptIterations;
[15403]163        this.lambda = lambda;
[13645]164        this.evalFun = this.Eval;
165        this.lowerEstimationLimit = lowerEstimationLimit;
166        this.upperEstimationLimit = upperEstimationLimit;
[15403]167        this.collectParetoOptimalModels = collectParetoOptimalModels;
[13645]168
169        random = new MersenneTwister(randSeed);
170
171        // prepare data for evaluation
172        double[][] x;
173        double[] y;
174        double[][] testX;
175        double[] testY;
176        double[] scalingFactor;
177        double[] scalingOffset;
178        // get training and test datasets (scale linearly based on training set if required)
179        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
180        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
181        this.x = x;
182        this.y = y;
[15403]183        this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
[13645]184        this.testX = testX;
185        this.testY = testY;
186        this.scalingFactor = scalingFactor;
187        this.scalingOffset = scalingOffset;
188        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
189        // we need a separate evaluator because the vector length for the test dataset might differ
190        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
191
[15414]192        this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
[13658]193        this.treePolicy = treePolicy ?? new Ucb();
[15410]194        this.tree = new Tree() {
195          state = automaton.CurrentState,
196          actionStatistics = treePolicy.CreateActionStatistics(),
[15414]197          expr = "",
198          level = 0
[15410]199        };
[13645]200
201        // reset best solution
[15416]202        this.bestR = 0;
[13645]203        // code for default solution (constant model)
204        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
205        this.bestNParams = 0;
206        this.bestConsts = null;
207
208        // init buffers
209        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
210        constsBuf = new double[MaxParams];
211        this.predBuf = new double[y.Length];
212        this.testPredBuf = new double[testY.Length];
213
214        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
215      }
216
217      #region IState inferface
[13658]218      public bool Done { get { return tree != null && tree.Done; } }
[13645]219
220      public double BestSolutionTrainingQuality {
221        get {
222          evaluator.Exec(bestCode, x, bestConsts, predBuf);
[15416]223          return Rho(y, predBuf);
[13645]224        }
225      }
226
227      public double BestSolutionTestQuality {
228        get {
229          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
[15416]230          return Rho(testY, testPredBuf);
[13645]231        }
232      }
233
234      // takes the code of the best solution and creates and equivalent symbolic regression model
235      public ISymbolicRegressionModel BestModel {
236        get {
237          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
238          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
239
240          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
[13941]241          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
[15403]242          model.Scale(problemData); // apply linear scaling
[13645]243          return model;
244        }
245      }
[15360]246      public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
247        get { return paretoBestModels; }
248      }
[13651]249
250      public int TotalRollouts { get { return totalRollouts; } }
251      public int EffectiveRollouts { get { return effectiveRollouts; } }
252      public int FuncEvaluations { get { return funcEvaluations; } }
253      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
254
[13645]255      #endregion
256
257      private double Eval(byte[] code, int nParams) {
258        double[] optConsts;
259        double q;
260        Eval(code, nParams, out q, out optConsts);
261
[15360]262        // single objective best
[15416]263        if (q > bestR) {
264          bestR = q;
[13645]265          bestNParams = nParams;
266          this.bestCode = new byte[code.Length];
267          this.bestConsts = new double[bestNParams];
268
269          Array.Copy(code, bestCode, code.Length);
270          Array.Copy(optConsts, bestConsts, bestNParams);
271        }
[15403]272        if (collectParetoOptimalModels) {
273          // multi-objective best
274          var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
275            Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit);  // use length of expression as surrogate for complexity
276          UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
277        }
[13645]278        return q;
279      }
280
[15416]281      private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
[13645]282        // we make a first pass to determine a valid starting configuration for all constants
283        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
284        // scale and offset are set to optimal starting configuration
285        // assumes scale is the first param and offset is the last param
286
287        // reset constants
288        Array.Copy(ones, constsBuf, nParams);
289        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
[13651]290        funcEvaluations++;
[13645]291
[15403]292        if (nParams == 0 || constOptIterations < 0) {
[13645]293          // if we don't need to optimize parameters then we are done
294          // changing scale and offset does not influence r²
[15416]295          rho = Rho(y, predBuf);
[13645]296          optConsts = constsBuf;
297        } else {
298          // optimize constants using the starting point calculated above
299          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
[13651]300
[13645]301          evaluator.Exec(code, x, constsBuf, predBuf);
[13651]302          funcEvaluations++;
303
[15416]304          rho = Rho(y, predBuf);
[13645]305          optConsts = constsBuf;
306        }
307      }
308
309
310
311      #region helpers
[15416]312      private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
[13645]313        OnlineCalculatorError error;
314        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
[15416]315        return error == OnlineCalculatorError.None ? r : 0.0;
[13645]316      }
317
318
319      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
[13651]320        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
[13645]321        Array.Copy(consts, optConsts, nParams);
322
[15403]323        // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
[13645]324        alglib.minlmstate state;
325        alglib.minlmreport rep = null;
[15403]326        alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
327        // Using the change of the gradient as stopping criterion is recommended in alglib manual.
328        // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
329        alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
330        // alglib.minlmsetgradientcheck(state, 1E-5);
[13645]331        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
332        alglib.minlmresults(state, out optConsts, out rep);
[13651]333        funcEvaluations += rep.nfunc;
334        gradEvaluations += rep.njac * nParams;
[13645]335
336        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);
337
338        // only use optimized constants if successful
339        if (rep.terminationtype >= 0) {
340          Array.Copy(optConsts, consts, optConsts.Length);
341        }
342      }
343
344      private void Func(double[] arg, double[] fi, object obj) {
345        var code = (byte[])obj;
[15403]346        int n = predBuf.Length;
[13645]347        evaluator.Exec(code, x, arg, predBuf); // gradients are nParams x vLen
[15403]348        for (int r = 0; r < n; r++) {
[13645]349          var res = predBuf[r] - y[r];
350          fi[r] = res;
351        }
[15403]352
353        var penaltyIdx = fi.Length - 1;
354        fi[penaltyIdx] = 0.0;
355        // calc length of parameter vector for regularization
356        var aa = 0.0;
357        for (int i = 0; i < arg.Length; i++) {
358          aa += arg[i] * arg[i];
359        }
360        if (lambda > 0 && aa > 0) {
361          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
362          // scale lambda using n to make parameter independent of the number of training points
363          // take the root because LM squares the result
364          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
365        }
[13645]366      }
[15403]367
[13645]368      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
[15403]369        int n = predBuf.Length;
[13645]370        int nParams = arg.Length;
371        var code = (byte[])obj;
372        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
[15403]373        for (int r = 0; r < n; r++) {
[13645]374          var res = predBuf[r] - y[r];
375          fi[r] = res;
376
377          for (int k = 0; k < nParams; k++) {
378            jac[r, k] = gradBuf[k][r];
379          }
380        }
[15403]381        // calc length of parameter vector for regularization
382        double aa = 0.0;
383        for (int i = 0; i < arg.Length; i++) {
384          aa += arg[i] * arg[i];
385        }
386
387        var penaltyIdx = fi.Length - 1;
388        if (lambda > 0 && aa > 0) {
389          fi[penaltyIdx] = 0.0;
390          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
391          // scale lambda using n to make parameter independent of the number of training points
392          // take the root because alglib LM squares the result
393          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
394
395          for (int i = 0; i < arg.Length; i++) {
396            jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
397          }
398        } else {
399          fi[penaltyIdx] = 0.0;
400          for (int i = 0; i < arg.Length; i++) {
401            jac[penaltyIdx, i] = 0.0;
402          }
403        }
[13645]404      }
[15403]405
406
407      private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
408        double[] scalingFactor, double[] scalingOffset) {
409        double[] best = new double[2];
410        double[] cur = new double[2] { q, complexity };
411        bool[] max = new[] { true, false };
412        var isNonDominated = true;
413        foreach (var e in paretoFront) {
414          var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
415          if (domRes == DominationResult.IsDominated) {
416            isNonDominated = false;
417            break;
418          }
419        }
420        if (isNonDominated) {
421          paretoFront.Add(cur);
422
423          // create model
424          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
425          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
426
427          var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
428          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
429          model.Scale(problemData); // apply linear scaling
430
431          var sol = model.CreateRegressionSolution(this.problemData);
432          sol.Name = string.Format("{0:N5} {1}", q, complexity);
433
434          paretoBestModels.Add(sol);
435        }
436        for (int i = paretoFront.Count - 2; i >= 0; i--) {
437          var @ref = paretoFront[i];
438          var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
439          if (domRes == DominationResult.Dominates) {
440            paretoFront.RemoveAt(i);
441            paretoBestModels.RemoveAt(i);
442          }
443        }
444      }
[15420]445
[13645]446      #endregion
[15420]447
448#if DEBUG
449      internal void ClearStats() {
450        for (int i = 0; i < giniCoeffs.Length; i++) giniCoeffs[i] = -1;
451      }
452      internal void WriteStats() {
453        Console.WriteLine(string.Join("\t", giniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x))));
454      }
455
456#endif
457
[13645]458    }
459
[15403]460
461    /// <summary>
462    /// Static method to initialize a state for the algorithm
463    /// </summary>
464    /// <param name="problemData">The problem data</param>
465    /// <param name="randSeed">Random seed.</param>
466    /// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
467    /// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended)</param>
468    /// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt)</param>
469    /// <param name="lambda">Penalty factor for regularization (0..inf.), small penalty disabled regularization.</param>
470    /// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...)</param>
471    /// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
472    /// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
473    /// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
474    /// <param name="allowProdOfVars">Allow products of expressions.</param>
475    /// <param name="allowExp">Allow expressions with exponentials.</param>
476    /// <param name="allowLog">Allow expressions with logarithms</param>
477    /// <param name="allowInv">Allow expressions with 1/x</param>
478    /// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
479    /// <returns></returns>
480
[13658]481    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
[15403]482      bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
[13658]483      IPolicy policy = null,
[15403]484      bool collectParameterOptimalModels = false,
[13658]485      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
[13645]486      bool allowProdOfVars = true,
487      bool allowExp = true,
488      bool allowLog = true,
489      bool allowInv = true,
490      bool allowMultipleTerms = false
491      ) {
[15403]492      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
493        policy, collectParameterOptimalModels,
[13645]494        lowerEstimationLimit, upperEstimationLimit,
495        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
496    }
497
498    // returns the quality of the evaluated solution
499    public static double MakeStep(IState state) {
500      var mctsState = state as State;
501      if (mctsState == null) throw new ArgumentException("state");
502      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");
503
504      return TreeSearch(mctsState);
505    }
506    #endregion
507
508    private static double TreeSearch(State mctsState) {
509      var automaton = mctsState.automaton;
510      var tree = mctsState.tree;
511      var eval = mctsState.evalFun;
512      var rand = mctsState.random;
[13658]513      var treePolicy = mctsState.treePolicy;
[13651]514      double q = 0;
515      bool success = false;
516      do {
[15420]517#if DEBUG
518        mctsState.ClearStats();
519#endif
[13651]520        automaton.Reset();
[15416]521        success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q);
[13651]522        mctsState.totalRollouts++;
[13658]523      } while (!success && !tree.Done);
[13651]524      mctsState.effectiveRollouts++;
[15410]525
[15420]526#if DEBUG
527      mctsState.WriteStats();
528#endif
[15416]529      //if (mctsState.effectiveRollouts % 100 == 1) {
[15420]530      // Console.WriteLine(WriteTree(tree, mctsState));
531      // Console.WriteLine(TraceTree(tree, mctsState));
[15416]532      //}
[13651]533      return q;
[13645]534    }
535
[15410]536    // search forward
537    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
[15416]538      State state,
[15410]539      out double q) {
540      // ROLLOUT AND EXPANSION
541      // We are navigating a graph (states might be reached via different paths) instead of a tree.
542      // State equivalence is checked through ExprHash (based on the generated code through the path).
543
544      // We switch between rollout-mode and expansion mode
545      // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
546      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
547      // In expansion mode we might re-enter the graph and switch back to rollout-mode
548      // We do this until we reach a complete expression (final state)
549
[15414]550      // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
[15410]551      // Sub-graphs which have been completely searched are marked as done.
552      // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.
553
554      while (!automaton.IsFinalState(automaton.CurrentState)) {
[15416]555        if (state.children.ContainsKey(tree)) {
556          if (state.children[tree].All(ch => ch.Done)) {
[15414]557            tree.Done = true;
558            break;
559          }
[15410]560          // ROLLOUT INSIDE TREE
561          // UCT selection within tree
562          int selectedIdx = 0;
[15416]563          if (state.children[tree].Count > 1) {
564            selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);
[15410]565          }
[15420]566
567          // STATS
568          state.giniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality));
569
[15416]570          tree = state.children[tree][selectedIdx];
[15410]571
572          // move the automaton forward until reaching the state
573          // all steps where no alternatives are possible are immediately taken
574          // TODO: simplification of the automaton
575          int[] possibleFollowStates;
576          int nFs;
577          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
[15420]578          while (automaton.CurrentState != tree.state && nFs == 1 &&
579            !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
[15410]580            automaton.Goto(possibleFollowStates[0]);
581            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
582          }
583          Debug.Assert(possibleFollowStates.Contains(tree.state));
584          automaton.Goto(tree.state);
585        } else {
586          // EXPAND
587          int[] possibleFollowStates;
588          int nFs;
[15416]589          string actionString = "";
[15410]590          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
[15414]591          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
[15416]592            actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]);
[15410]593            // no alternatives -> just go to the next state
594            automaton.Goto(possibleFollowStates[0]);
595            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
596          }
597          if (nFs == 0) {
598            // stuck in a dead end (no final state and no allowed follow states)
599            tree.Done = true;
600            break;
601          }
602          var newChildren = new List<Tree>(nFs);
[15416]603          state.children.Add(tree, newChildren);
[15410]604          for (int i = 0; i < nFs; i++) {
605            Tree child = null;
[15414]606            // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
[15410]607            if (automaton.IsEvalState(possibleFollowStates[i])) {
608              var hc = Hashcode(automaton);
[15416]609              if (!state.nodes.TryGetValue(hc, out child)) {
[15410]610                child = new Tree() {
611                  children = null,
612                  state = possibleFollowStates[i],
613                  actionStatistics = treePolicy.CreateActionStatistics(),
[15416]614                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
[15414]615                  level = tree.level + 1
[15410]616                };
[15416]617                state.nodes.Add(hc, child);
[15414]618              }
619              // only allow forward edges (don't add the child if we would go back in the graph)
[15416]620              else if (child.level > tree.level)  {
[15410]621                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
622                // to all parents
[15416]623                BackpropagateStatistics(child.actionStatistics, tree, state);
[15414]624              } else {
625                // prevent cycles
626                Debug.Assert(child.level <= tree.level);
627                child = null;
[15416]628              }   
[15410]629            } else {
630              child = new Tree() {
631                children = null,
632                state = possibleFollowStates[i],
633                actionStatistics = treePolicy.CreateActionStatistics(),
[15416]634                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
[15414]635                level = tree.level + 1
[15410]636              };
637            }
[15414]638            if (child != null)
639              newChildren.Add(child);
[15410]640          }
641
[15414]642          if (!newChildren.Any()) {
643            // stuck in a dead end (no final state and no allowed follow states)
644            tree.Done = true;
645            break;
646          }
647
[15410]648          foreach (var ch in newChildren) {
[15416]649            if (!state.parents.ContainsKey(ch)) {
650              state.parents.Add(ch, new List<Tree>());
[15410]651            }
[15416]652            state.parents[ch].Add(tree);
[15410]653          }
654
[15414]655
[15416]656          // follow one of the children
657          tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
[15410]658          automaton.Goto(tree.state);
659        }
660      }
661
662      bool success;
663
664      // EVALUATE TREE
665      if (automaton.IsFinalState(automaton.CurrentState)) {
666        tree.Done = true;
[15414]667        tree.expr = ExprStr(automaton);
[15410]668        byte[] code; int nParams;
669        automaton.GetCode(out code, out nParams);
670        q = eval(code, nParams);
[15416]671        // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);
[15410]672        q = TransformQuality(q);
673        success = true;
674      } else {
675        // we got stuck in roll-out (not evaluation necessary!)
[15416]676        // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");
[15410]677        q = 0.0;
678        success = false;
679      }
680
681      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
682      // Update statistics
683      // Set branch to done if all children are done.
[15416]684      BackpropagateQuality(tree, q, treePolicy, state);
[15410]685
686      return success;
687    }
688
[15420]689    private static double InequalityCoefficient(IEnumerable<double> xs) {
690      var arr = xs.ToArray();
691      var sad = 0.0;
692      var sum = 0.0;
[15410]693
[15420]694      for(int i=0;i<arr.Length;i++) {
695        for(int j=0;j<arr.Length;j++) {
696          sad += Math.Abs(arr[i] - arr[j]);
697          sum += arr[j];
698        }
699      }
700      return 0.5 * sad / sum;     
701    }
702
[15410]703    private static double TransformQuality(double q) {
704      // no transformation
[15416]705      // return q;
[15410]706
707      // EXPERIMENTAL!
[15416]708
709      // Fisher transformation
710      // (assumes q is Correl(pred, target)
711
712      q = Math.Min(q,  0.99999999);
713      q = Math.Max(q, -0.99999999);
714      return 0.5 * Math.Log((1 + q) / (1 - q));
715
[15410]716      // optimal result: q = 1 -> return huge value
[15414]717      // if (q >= 1.0) return 1E16;
718      // // return number of 9s in R²
719      // return -Math.Log10(1 - q);
[15410]720    }
721
722    // backpropagate existing statistics to all parents
[15416]723    private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
[15410]724      tree.actionStatistics.Add(stats);
[15416]725      if (state.parents.ContainsKey(tree)) {
726        foreach (var parent in state.parents[tree]) {
727          BackpropagateStatistics(stats, parent, state);
[15410]728        }
729      }
730    }
731
732    private static ulong Hashcode(Automaton automaton) {
733      byte[] code;
734      int nParams;
735      automaton.GetCode(out code, out nParams);
736      return ExprHash.GetHash(code, nParams);
737    }
738
[15416]739    private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
[15410]740      if (q > 0) policy.Update(tree.actionStatistics, q);
[15416]741      if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) {
[15410]742        tree.Done = true;
743        // children[tree] = null; keep all nodes
744      }
745
[15416]746      if (state.parents.ContainsKey(tree)) {
747        foreach (var parent in state.parents[tree]) {
748          BackpropagateQuality(parent, q, policy, state);
[15410]749        }
750      }
751    }
752
[15416]753    private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
754      // find the child with the smallest state value (smaller values are closer to the final state)
755      int selectedChildIdx = 0;
756      var children = state.children[tree];
757      Tree minChild = children.First();
758      for (int i = 1; i < children.Count; i++) {
759        if(children[i].state < minChild.state)
[15410]760          selectedChildIdx = i;
761      }
762      return children[selectedChildIdx];
763    }
764
[13651]765    // tree search might fail because of constraints for expressions
766    // in this case we get stuck we just restart
767    // see ConstraintHandler.cs for more info
[13658]768    private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
769      out double q) {
[13645]770      Tree selectedChild = null;
771      Contract.Assert(tree.state == automaton.CurrentState);
[13658]772      Contract.Assert(!tree.Done);
[13645]773      if (tree.children == null) {
774        if (automaton.IsFinalState(tree.state)) {
775          // final state
[13658]776          tree.Done = true;
[13645]777
778          // EVALUATE
779          byte[] code; int nParams;
780          automaton.GetCode(out code, out nParams);
781          q = eval(code, nParams);
[13658]782
783          treePolicy.Update(tree.actionStatistics, q);
[13651]784          return true; // we reached a final state
[13645]785        } else {
786          // EXPAND
787          int[] possibleFollowStates;
788          int nFs;
789          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
[13651]790          if (nFs == 0) {
791            // stuck in a dead end (no final state and no allowed follow states)
792            q = 0;
[13658]793            tree.Done = true;
[13651]794            tree.children = null;
795            return false;
796          }
[13645]797          tree.children = new Tree[nFs];
[13651]798          for (int i = 0; i < tree.children.Length; i++)
[15410]799            tree.children[i] = new Tree() {
800              children = null,
801              state = possibleFollowStates[i],
802              actionStatistics = treePolicy.CreateActionStatistics()
803            };
[13645]804
[13657]805          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
[13645]806        }
807      } else {
808        // tree.children != null
809        // UCT selection within tree
[13658]810        int selectedIdx = 0;
811        if (tree.children.Length > 1) {
812          selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
813        }
814        selectedChild = tree.children[selectedIdx];
[13645]815      }
816      // make selected step and recurse
817      automaton.Goto(selectedChild.state);
[13658]818      var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
[13651]819      if (success) {
820        // only update if successful
[13658]821        treePolicy.Update(tree.actionStatistics, q);
[13651]822      }
[13645]823
[13658]824      tree.Done = tree.children.All(ch => ch.Done);
825      if (tree.Done) {
[13651]826        tree.children = null; // cut off the sub-branch if it has been fully explored
[13645]827      }
[13651]828      return success;
[13645]829    }
830
831    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
832      // if one of the new children leads to a final state then go there
833      // otherwise choose a random child
834      int selectedChildIdx = -1;
835      // find first final state if there is one
836      for (int i = 0; i < tree.children.Length; i++) {
837        if (automaton.IsFinalState(tree.children[i].state)) {
838          selectedChildIdx = i;
839          break;
840        }
841      }
[13669]842      // no final state -> select a the first child
[13645]843      if (selectedChildIdx == -1) {
[13669]844        selectedChildIdx = 0;
[13645]845      }
846      return tree.children[selectedChildIdx];
847    }
848
849    // scales data and extracts values from dataset into arrays
850    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
851      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
852      xs = new double[problemData.AllowedInputVariables.Count()][];
853
854      var i = 0;
855      if (scaleVariables) {
[15403]856        scalingFactor = new double[xs.Length + 1];
857        scalingOffset = new double[xs.Length + 1];
[13645]858      } else {
859        scalingFactor = null;
860        scalingOffset = null;
861      }
862      foreach (var var in problemData.AllowedInputVariables) {
863        if (scaleVariables) {
864          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
865          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
866          var range = maxX - minX;
867
868          // scaledX = (x - min) / range
869          var sf = 1.0 / range;
870          var offset = -minX / range;
871          scalingFactor[i] = sf;
872          scalingOffset[i] = offset;
873          i++;
874        }
875      }
876
[15403]877      if (scaleVariables) {
878        // transform target variable to zero-mean
879        scalingFactor[i] = 1.0;
880        scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
881      }
882
[13645]883      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
884    }
885
886    // extract values from dataset into arrays
887    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
888     out double[][] xs, out double[] y) {
889      xs = new double[problemData.AllowedInputVariables.Count()][];
890
891      int i = 0;
892      foreach (var var in problemData.AllowedInputVariables) {
893        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
894        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
895        xs[i++] =
896          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
897      }
898
[15403]899      {
900        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
901        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
902        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * sf + offset).ToArray();
903      }
[13645]904    }
[15410]905
906    // for debugging only
907
908
909    private static string ExprStr(Automaton automaton) {
910      byte[] code;
911      int nParams;
912      automaton.GetCode(out code, out nParams);
913      return Disassembler.CodeToString(code);
914    }
915
[15420]916
[15416]917    private static string WriteStatistics(Tree tree, State state) {
[15410]918      var sb = new System.IO.StringWriter();
919      sb.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
[15416]920      if (state.children.ContainsKey(tree)) {
921        foreach (var ch in state.children[tree]) {
[15410]922          sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
923        }
924      }
925      return sb.ToString();
926    }
[15414]927
[15416]928    private static string TraceTree(Tree tree, State state) {
[15414]929      var sb = new StringBuilder();
930      sb.Append(
931@"digraph {
932  ratio = fill;
933  node [style=filled];
934");
935      int nodeId = 0;
936
[15416]937      TraceTreeRec(tree, 0, sb, ref nodeId, state);
[15414]938      sb.Append("}");
939      return sb.ToString();
940    }
941
[15416]942    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
[15414]943      var avgNodeQ = tree.actionStatistics.AverageQuality;
944      var tries = tree.actionStatistics.Tries;
945      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
946      var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
[15416]947      hue = 0.0;
[15414]948
[15416]949      sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();
[15414]950
951      var list = new List<Tuple<int, int, Tree>>();
[15416]952      if (state.children.ContainsKey(tree)) {
953        foreach (var ch in state.children[tree]) {
[15414]954          nextId++;
955          avgNodeQ = ch.actionStatistics.AverageQuality;
956          tries = ch.actionStatistics.Tries;
957          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
958          hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
[15416]959          hue = 0.0;
960          sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
961          sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine();
[15414]962          list.Add(Tuple.Create(tries, nextId, ch));
963        }
[15416]964
965        foreach(var tup in list) {
966          var ch = tup.Item3;
967          var chId = tup.Item2;
968          if(state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
969            var chch = state.children[ch].First();
970            nextId++;
971            avgNodeQ = chch.actionStatistics.AverageQuality;
972            tries = chch.actionStatistics.Tries;
973            if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
974            hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
975            hue = 0.0;
976            sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
977            sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine();
978          }
979        }
980
[15414]981        foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
[15416]982          TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state);
[15414]983        }
984      }
985    }
986
[15416]987    private static string WriteTree(Tree tree, State state) {
[15410]988      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
989      var nodeIds = new Dictionary<Tree, int>();
990      sb.Write(
991@"digraph {
992  ratio = fill;
993  node [style=filled];
994");
[15416]995      int threshold = /* state.nodes.Count > 500 ? 10 : */ 0;
996      foreach (var kvp in state.children) {
[15410]997        var parent = kvp.Key;
998        int parentId;
[15414]999        if (!nodeIds.TryGetValue(parent, out parentId)) {
[15410]1000          parentId = nodeIds.Count + 1;
[15414]1001          var avgNodeQ = parent.actionStatistics.AverageQuality;
[15410]1002          var tries = parent.actionStatistics.Tries;
1003          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
[15414]1004          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
[15416]1005          hue = 0.0;
[15414]1006          if (parent.actionStatistics.Tries > threshold)
[15416]1007            sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
[15410]1008          nodeIds.Add(parent, parentId);
1009        }
[15414]1010        foreach (var child in kvp.Value) {
[15410]1011          int childId;
[15414]1012          if (!nodeIds.TryGetValue(child, out childId)) {
[15410]1013            childId = nodeIds.Count + 1;
1014            nodeIds.Add(child, childId);
1015          }
1016          var avgNodeQ = child.actionStatistics.AverageQuality;
1017          var tries = child.actionStatistics.Tries;
1018          if (tries < 1) continue;
1019          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
[15414]1020          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
[15416]1021          hue = 0.0;
[15414]1022          if (tries > threshold) {
[15416]1023            sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
[15414]1024            var edgeLabel = child.expr;
1025            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
1026            sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
1027          }
[15410]1028        }
1029      }
1030
1031      sb.Write("}");
1032      return sb.ToString();
1033    }
[13645]1034  }
1035}
Note: See TracBrowser for help on using the repository browser.