1  #region License Information


2  /* HeuristicLab


 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21 


22  using System;


23  using System.Collections.Generic;


24  using System.Diagnostics;


25  using System.Diagnostics.Contracts;


26  using System.Linq;


27  using System.Text;


28  using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;


29  using HeuristicLab.Core;


30  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


31  using HeuristicLab.Optimization;


32  using HeuristicLab.Problems.DataAnalysis;


33  using HeuristicLab.Problems.DataAnalysis.Symbolic;


34  using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;


35  using HeuristicLab.Random;


36 


37  namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {


38  public static class MctsSymbolicRegressionStatic {


// OBJECTIVES:
// 1) Solve toy problems without numeric constants (to show that structure search is effective / efficient)
//    - e.g. Keijzer, Nguyen ... where no numeric constants are involved
//    - assumptions:
//      - we don't know the necessary operations or functions -> all available functions could be necessary
//      - but we do not need to tune numeric constants -> no scaling of input variables x!
// 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
//    This is important for real-world applications.
//    - e.g. Korns or Vladislavleva problems where numeric constants are involved
//    - assumptions:
//      - any numeric constant is possible (a priori we might assume that small abs. constants are more likely)
//      - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
//      - to simplify the problem we can restrict the set of functions, e.g. we assume which functions are necessary for the problem instance
//        -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
// 3) Efficiency and effectiveness for real-world problems
//    - e.g. Tower problem
//    - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
//

// TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should
//       weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
//       branch and less in the expected value. (-> Review "Extreme Bandit" literature again)
// TODO: Solve Poly-10
// TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
// TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
// TODO: check if transformation of y is correct and works (Obj 2)
// TODO: The algorithm is not invariant to location and scale of variables.
//       Include offset for variables as parameter (for Objective 2)
// TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
// TODO: support e(x) and possibly (1/x) (Obj 1)
// TODO: is it OK to initialize all constants to 1 (Obj 2)?
// TODO: improve memory usage


71  #region static API


72 


/// <summary>
/// Read-only view of a running search. Instances are created by CreateState and advanced by MakeStep.
/// </summary>
public interface IState {
  /// <summary>True once the root of the search tree has been marked as done.</summary>
  bool Done { get; }
  /// <summary>Symbolic regression model built from the best expression found so far.</summary>
  ISymbolicRegressionModel BestModel { get; }
  /// <summary>R² of the best expression on the training partition.</summary>
  double BestSolutionTrainingQuality { get; }
  /// <summary>R² of the best expression on the test partition.</summary>
  double BestSolutionTestQuality { get; }
  /// <summary>Solutions currently kept on the quality/complexity Pareto front.</summary>
  IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
  /// <summary>Number of rollouts that were started, including those that got stuck.</summary>
  int TotalRollouts { get; }
  /// <summary>Number of completed search steps (see TotalRollouts for the raw count).</summary>
  int EffectiveRollouts { get; }
  /// <summary>Number of function evaluations performed so far.</summary>
  int FuncEvaluations { get; }
  /// <summary>
  /// Number of gradient evaluations (* num parameters) to get a value representative of the effort
  /// comparable to the number of function evaluations.
  /// </summary>
  int GradEvaluations { get; }
  // TODO other stats on LM optimizer might be interesting here
}


85 


86  // created through factory method


87  private class State : IState {


// upper bound for the number of numeric parameters; the parameter and gradient buffers below are allocated with this size
private const int MaxParams = 100;

// state variables used by MCTS
internal readonly Automaton automaton;
internal IRandom random { get; private set; }
internal readonly Tree tree; // root node of the search graph
internal readonly Func<byte[], int, double> evalFun; // bound to Eval(code, nParams) in the constructor
internal readonly IPolicy treePolicy;
// MCTS might get stuck. Track statistics on the number of effective rollouts
internal int totalRollouts;
internal int effectiveRollouts;


// state variables used only internally (for eval function)
private readonly IRegressionProblemData problemData;
private readonly double[][] x; // training inputs
private readonly double[] y; // training target
private readonly double[][] testX; // test inputs
private readonly double[] testY; // test target
private readonly double[] scalingFactor;
private readonly double[] scalingOffset;
private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
private readonly int constOptIterations; // negative values switch constants optimization off
private readonly double lambda; // weight of penalty term for regularization
private readonly double lowerEstimationLimit, upperEstimationLimit;
private readonly bool collectParetoOptimalModels;
private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models (same index): { quality, complexity }

// separate evaluators because training and test vectors may differ in length
private readonly ExpressionEvaluator evaluator, testEvaluator;

// values for best solution
private double bestRSq;
private byte[] bestCode;
private int bestNParams;
private double[] bestConsts;

// stats
private int funcEvaluations;
private int gradEvaluations;

// buffers (allocated once in the constructor and reused by every evaluation)
private readonly double[] ones; // vector of ones (as default params)
private readonly double[] constsBuf;
private readonly double[] predBuf, testPredBuf;
private readonly double[][] gradBuf; // MaxParams rows, one entry per training row


134 


/// <summary>
/// Sets up a new search: seeds the random number generator, extracts training and test data,
/// builds the automaton and the root node of the search graph, and allocates all evaluation buffers.
/// </summary>
/// <param name="problemData">Regression problem providing the training and test partitions.</param>
/// <param name="randSeed">Seed for the Mersenne Twister used by the search.</param>
/// <param name="maxVariables">Limit passed to the automaton's constraint handler.</param>
/// <param name="scaleVariables">Forwarded to GenerateData for the training partition.</param>
/// <param name="constOptIterations">Maximum number of LM iterations; negative values disable constants optimization.</param>
/// <param name="lambda">Weight of the regularization penalty, must be >= 0.</param>
/// <param name="treePolicy">Tree policy; a new Ucb policy is used when null.</param>
/// <param name="collectParetoOptimalModels">Whether a quality/complexity Pareto front is maintained.</param>
/// <param name="lowerEstimationLimit">Lower limit handed to the evaluators and to created models.</param>
/// <param name="upperEstimationLimit">Upper limit handed to the evaluators and to created models.</param>
/// <exception cref="ArgumentException">Thrown when <paramref name="lambda"/> is negative.</exception>
public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
  int constOptIterations, double lambda,
  IPolicy treePolicy = null,
  bool collectParetoOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false) {

  if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");

  this.problemData = problemData;
  this.constOptIterations = constOptIterations;
  this.lambda = lambda;
  this.evalFun = this.Eval;
  this.lowerEstimationLimit = lowerEstimationLimit;
  this.upperEstimationLimit = upperEstimationLimit;
  this.collectParetoOptimalModels = collectParetoOptimalModels;

  random = new MersenneTwister(randSeed);

  // prepare data for evaluation
  double[][] trainX;
  double[] trainY;
  double[][] tstX;
  double[] tstY;
  double[] factors;
  double[] offsets;
  // get training and test datasets (scale linearly based on training set if required);
  // the scaling parameters produced for the training partition are reused for the test partition
  GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out trainX, out trainY, out factors, out offsets);
  GenerateData(problemData, problemData.TestIndices, factors, offsets, out tstX, out tstY);
  this.x = trainX;
  this.y = trainY;
  this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(trainY);
  this.testX = tstX;
  this.testY = tstY;
  this.scalingFactor = factors;
  this.scalingOffset = offsets;
  this.evaluator = new ExpressionEvaluator(trainY.Length, lowerEstimationLimit, upperEstimationLimit);
  // we need a separate evaluator because the vector length for the test dataset might differ
  this.testEvaluator = new ExpressionEvaluator(tstY.Length, lowerEstimationLimit, upperEstimationLimit);

  this.automaton = new Automaton(trainX, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
  this.treePolicy = treePolicy ?? new Ucb();
  this.tree = new Tree() {
    state = automaton.CurrentState,
    // must use the field: the parameter is null whenever the default policy is in effect
    actionStatistics = this.treePolicy.CreateActionStatistics(),
    expr = "",
    level = 0
  };

  // reset best solution
  this.bestRSq = 0;
  // code for default solution (constant model)
  this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
  this.bestNParams = 0;
  this.bestConsts = null;

  // init buffers
  this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
  constsBuf = new double[MaxParams];
  this.predBuf = new double[trainY.Length];
  this.testPredBuf = new double[tstY.Length];

  this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[trainY.Length]).ToArray();
}


203 


#region IState interface
// the search is finished once the root node has been marked as done
public bool Done { get { return tree != null && tree.Done; } }

// re-evaluates the best expression on the training data (overwrites predBuf)
// NOTE(review): bestConsts is null until the first improvement has been recorded - confirm that Exec accepts null
public double BestSolutionTrainingQuality {
  get {
    evaluator.Exec(bestCode, x, bestConsts, predBuf);
    return RSq(y, predBuf);
  }
}

// re-evaluates the best expression on the test data (overwrites testPredBuf)
public double BestSolutionTestQuality {
  get {
    testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
    return RSq(testY, testPredBuf);
  }
}

// takes the code of the best solution and creates an equivalent symbolic regression model
// (a new model is built on every access)
public ISymbolicRegressionModel BestModel {
  get {
    var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
    var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
    var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
    model.Scale(problemData); // apply linear scaling
    return model;
  }
}

// exposes the internal list directly; it is modified by subsequent evaluations
public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
  get { return paretoBestModels; }
}

public int TotalRollouts { get { return totalRollouts; } }
public int EffectiveRollouts { get { return effectiveRollouts; } }
public int FuncEvaluations { get { return funcEvaluations; } }
// number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
public int GradEvaluations { get { return gradEvaluations; } }

#endregion


243 


// Evaluates an expression (given as byte code with nParams numeric parameters), remembers it when it
// beats the best R² seen so far, optionally updates the Pareto front, and returns its R².
// This is the delegate handed to the tree search via evalFun.
private double Eval(byte[] code, int nParams) {
  double[] optConsts;
  double q;
  Eval(code, nParams, out q, out optConsts);

  // single objective best
  if (q > bestRSq) {
    bestRSq = q;
    bestNParams = nParams;
    // copy code and constants because the evaluation hands back a shared buffer
    this.bestCode = new byte[code.Length];
    this.bestConsts = new double[bestNParams];

    Array.Copy(code, bestCode, code.Length);
    Array.Copy(optConsts, bestConsts, bestNParams);
  }
  if (collectParetoOptimalModels) {
    // multi-objective best
    // use length of expression (position of the Exit op code) as surrogate for complexity
    // TODO: implement Kommenda's tree complexity directly in the evaluator (SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity())
    var complexity = Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit);
    UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
  }
  return q;
}


267 


// Calculates the R² of an expression on the training data, tuning its numeric parameters first if required.
// rsq:       resulting R² on the training data
// optConsts: the parameter values used; this is the shared constsBuf, so callers must copy what they want to keep
private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
  // we make a first pass to determine a valid starting configuration for all constants
  // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
  // scale and offset are set to optimal starting configuration
  // assumes scale is the first param and offset is the last param

  // reset constants
  Array.Copy(ones, constsBuf, nParams);
  evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
  funcEvaluations++;

  // if we don't need to optimize parameters (no parameters, or optimization switched off) then we are done;
  // changing scale and offset does not influence r²
  bool skipOptimization = nParams == 0 || constOptIterations < 0;
  if (!skipOptimization) {
    // optimize constants using the starting point calculated above
    OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

    // refresh the predictions with the optimized constants
    evaluator.Exec(code, x, constsBuf, predBuf);
    funcEvaluations++;
  }

  rsq = RSq(y, predBuf);
  optConsts = constsBuf;
}


295 


296 


297 


298  #region helpers


// Squared Pearson correlation of the two sequences; 0.0 when the calculator reports an error.
private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
  OnlineCalculatorError error;
  double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
  if (error != OnlineCalculatorError.None) return 0.0;
  return r * r;
}


304 


305 


// Tunes the first nParams entries of consts with Levenberg-Marquardt (alglib minlm) and writes the
// result back into consts.
// code:   byte code of the expression (passed through to Func / FuncAndJacobian)
// consts: starting point on entry, optimized values on return
// epsF:   stopping condition on the change of the function value
// nIters: maximum number of LM iterations
// Throws ArgumentException when alglib reports a negative termination type.
private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
  double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
  Array.Copy(consts, optConsts, nParams);

  // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
  alglib.minlmstate state;
  alglib.minlmreport rep = null;
  alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
  // Using the change of the gradient as stopping criterion is recommended in alglib manual.
  // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
  // The tolerance is relative to the spread of y so that it does not depend on the scale of the target.
  alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
  // alglib.minlmsetgradientcheck(state, 1E-5);
  alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
  alglib.minlmresults(state, out optConsts, out rep);
  funcEvaluations += rep.nfunc;
  gradEvaluations += rep.njac * nParams;

  if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

  // reaching this point means the optimization was successful -> use the optimized constants
  Array.Copy(optConsts, consts, optConsts.Length);
}


330 


// Residual callback for alglib LM.
// arg: current parameter vector
// fi:  output; one residual (prediction - target) per training row, followed by one penalty entry
// obj: byte code of the expression being optimized
private void Func(double[] arg, double[] fi, object obj) {
  var code = (byte[])obj;
  int n = predBuf.Length;
  evaluator.Exec(code, x, arg, predBuf);
  for (int r = 0; r < n; r++) {
    fi[r] = predBuf[r] - y[r];
  }

  // calc (squared) length of parameter vector for regularization
  var aa = 0.0;
  for (int i = 0; i < arg.Length; i++) {
    aa += arg[i] * arg[i];
  }

  // the last element of fi holds the penalty term
  var penaltyIdx = fi.Length - 1;
  if (lambda > 0 && aa > 0) {
    // scale lambda using stdDev(y) to make the parameter independent of the scale of y
    // scale lambda using n to make parameter independent of the number of training points
    // take the root because LM squares the result
    fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
  } else {
    fi[penaltyIdx] = 0.0;
  }
}


354 


// Residual and Jacobian callback for alglib LM.
// arg: current parameter vector
// fi:  output; one residual (prediction - target) per training row, followed by one penalty entry
// jac: output; jac[r, k] is the derivative of entry r of fi with respect to parameter k
// obj: byte code of the expression being optimized
private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
  int n = predBuf.Length;
  int nParams = arg.Length;
  var code = (byte[])obj;
  evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
  for (int r = 0; r < n; r++) {
    fi[r] = predBuf[r] - y[r];

    // transpose: gradBuf is indexed [param][row], the Jacobian [row, param]
    for (int k = 0; k < nParams; k++) {
      jac[r, k] = gradBuf[k][r];
    }
  }

  // calc (squared) length of parameter vector for regularization
  double aa = 0.0;
  for (int i = 0; i < arg.Length; i++) {
    aa += arg[i] * arg[i];
  }

  // the last row of fi / jac holds the penalty term
  var penaltyIdx = fi.Length - 1;
  if (lambda > 0 && aa > 0) {
    // scale lambda using stdDev(y) to make the parameter independent of the scale of y
    // scale lambda using n to make parameter independent of the number of training points
    // take the root because alglib LM squares the result
    fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);

    // derivative of sqrt(c * sum(a_i²)) with respect to a_i, where c = n * lambda / yStdDev
    for (int i = 0; i < arg.Length; i++) {
      jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
    }
  } else {
    fi[penaltyIdx] = 0.0;
    for (int i = 0; i < arg.Length; i++) {
      jac[penaltyIdx, i] = 0.0;
    }
  }
}


392 


393 


// Inserts the point (q, complexity) into the Pareto front (quality is maximized, complexity is minimized)
// unless an existing entry dominates it, and removes the entries the new point dominates.
// paretoFront and paretoBestModels are kept aligned index by index.
private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
  double[] scalingFactor, double[] scalingOffset) {
  double[] cur = new double[2] { q, complexity };
  bool[] max = new[] { true, false };
  var isNonDominated = true;
  foreach (var e in paretoFront) {
    var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
    if (domRes == DominationResult.IsDominated) {
      isNonDominated = false;
      break;
    }
  }
  if (isNonDominated) {
    paretoFront.Add(cur);

    // create model
    var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
    var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
    var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
    model.Scale(problemData); // apply linear scaling

    var sol = model.CreateRegressionSolution(this.problemData);
    sol.Name = string.Format("{0:N5} {1}", q, complexity);

    paretoBestModels.Add(sol);
  }
  // walk backwards so that removals do not shift the entries still to be visited;
  // start at the second to last entry because the last one is the point that has just been appended
  for (int i = paretoFront.Count - 2; i >= 0; i--) {
    var @ref = paretoFront[i];
    var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
    if (domRes == DominationResult.Dominates) {
      paretoFront.RemoveAt(i);
      paretoBestModels.RemoveAt(i);
    }
  }
}


432  #endregion


433  }


434 


435 


/// <summary>
/// Static method to initialize a state for the algorithm
/// </summary>
/// <param name="problemData">The problem data</param>
/// <param name="randSeed">Random seed.</param>
/// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
/// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended)</param>
/// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt). Negative values disable constants optimization.</param>
/// <param name="lambda">Penalty factor for regularization (0..inf.); a penalty of zero disables regularization.</param>
/// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...)</param>
/// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
/// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
/// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
/// <param name="allowProdOfVars">Allow products of expressions.</param>
/// <param name="allowExp">Allow expressions with exponentials.</param>
/// <param name="allowLog">Allow expressions with logarithms</param>
/// <param name="allowInv">Allow expressions with 1/x</param>
/// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
/// <returns>A new search state that can be advanced with MakeStep.</returns>
// NOTE(review): minus signs were lost elsewhere in the text this file was recovered from; confirm that the
// default of 1 for constOptIterations is intended and is not a damaged -1 (= no constants optimization).
public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
  bool scaleVariables = true, int constOptIterations = 1, double lambda = 0.0,
  IPolicy policy = null,
  bool collectParameterOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false
  ) {
  return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
    policy, collectParameterOptimalModels,
    lowerEstimationLimit, upperEstimationLimit,
    allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
}


472 


// Performs one step of the tree search and returns the quality of the evaluated solution.
// Throws ArgumentException when the state was not created by CreateState (this includes null)
// and NotSupportedException when the search space has already been exhausted.
public static double MakeStep(IState state) {
  var mctsState = state as State;
  if (mctsState == null) throw new ArgumentException("The state must be created with CreateState.", "state");
  if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

  return TreeSearch(mctsState);
}


481  #endregion


482 


// Repeats rollouts from the root until one of them reaches a complete expression or the whole
// search graph is done. Returns the quality of the last rollout (0 if it was not successful).
private static double TreeSearch(State mctsState) {
  var automaton = mctsState.automaton;
  var tree = mctsState.tree;
  var eval = mctsState.evalFun;
  var rand = mctsState.random;
  var treePolicy = mctsState.treePolicy;
  double q = 0;
  bool success = false;
  do {
    // every rollout starts at the initial state of the automaton
    automaton.Reset();
    success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, out q);
    mctsState.totalRollouts++;
  } while (!success && !tree.Done);
  // counted once per call, also when the loop ended because the tree is done
  mctsState.effectiveRollouts++;

  // debugging aid (disabled): dump the tree every 10th effective rollout
  // if (mctsState.effectiveRollouts % 10 == 1) {
  //   Console.WriteLine(WriteTree(tree));
  //   Console.WriteLine(TraceTree(tree));
  // }
  return q;
}


504 


// Edges and nodes of the search graph. The graph is kept outside of the Tree nodes because a node
// can be reached via several paths and therefore can have more than one parent.
// NOTE(review): these are static, i.e. shared by every State created in the process - confirm that
// only one search is expected to run at a time.
private static Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>(); // node -> successors
private static Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>(); // node -> predecessors
private static Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>(); // expression hash -> unified node


508 


509 


510 


// search forward
// Performs one rollout starting at tree / the current automaton state.
// q receives the (transformed) quality of the evaluated expression, or 0.0 when the rollout got stuck.
// Returns true when a complete expression was reached and evaluated.
private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  out double q) {
  // ROLLOUT AND EXPANSION
  // We are navigating a graph (states might be reached via different paths) instead of a tree.
  // State equivalence is checked through ExprHash (based on the generated code through the path).

  // We switch between rollout-mode and expansion mode
  // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
  // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
  // In expansion mode we might re-enter the graph and switch back to rollout-mode
  // We do this until we reach a complete expression (final state)

  // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
  // Sub-graphs which have been completely searched are marked as done.
  // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

  while (!automaton.IsFinalState(automaton.CurrentState)) {
    List<Tree> knownChildren;
    if (children.TryGetValue(tree, out knownChildren)) {
      if (knownChildren.All(ch => ch.Done)) {
        tree.Done = true;
        break;
      }
      // ROLLOUT INSIDE TREE
      // UCT selection within tree
      int selectedIdx = 0;
      if (knownChildren.Count > 1) {
        selectedIdx = treePolicy.Select(knownChildren.Select(ch => ch.actionStatistics), rand);
      }
      tree = knownChildren[selectedIdx];

      // move the automaton forward until reaching the state
      // all steps where no alternatives are possible are immediately taken
      // TODO: simplification of the automaton
      int[] possibleFollowStates;
      int nFs;
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
        automaton.Goto(possibleFollowStates[0]);
        automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      }
      Debug.Assert(possibleFollowStates.Contains(tree.state));
      automaton.Goto(tree.state);
    } else {
      // EXPAND
      int[] possibleFollowStates;
      int nFs;
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
        // no alternatives -> just go to the next state
        automaton.Goto(possibleFollowStates[0]);
        automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      }
      if (nFs == 0) {
        // stuck in a dead end (no final state and no allowed follow states)
        tree.Done = true;
        break;
      }
      var newChildren = new List<Tree>(nFs);
      // registered before it is filled; the list stays registered even if it ends up empty
      children.Add(tree, newChildren);
      for (int i = 0; i < nFs; i++) {
        Tree child = null;
        // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
        if (automaton.IsEvalState(possibleFollowStates[i])) {
          // NOTE(review): the hash is calculated from the automaton's current code and does not depend on
          // possibleFollowStates[i] - confirm that this is the intended key
          var hc = Hashcode(automaton);
          if (!nodes.TryGetValue(hc, out child)) {
            child = new Tree() {
              children = null,
              state = possibleFollowStates[i],
              actionStatistics = treePolicy.CreateActionStatistics(),
              expr = string.Empty, // ExprStr(automaton),
              level = tree.level + 1
            };
            nodes.Add(hc, child);
          }
          // only allow forward edges (don't add the child if we would go back in the graph)
          else if (child.level > tree.level) {
            // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
            // to all parents
            BackpropagateStatistics(child.actionStatistics, tree);
          } else {
            // prevent cycles
            Debug.Assert(child.level <= tree.level);
            child = null;
          }
        } else {
          child = new Tree() {
            children = null,
            state = possibleFollowStates[i],
            actionStatistics = treePolicy.CreateActionStatistics(),
            expr = string.Empty, // ExprStr(automaton),
            level = tree.level + 1
          };
        }
        if (child != null)
          newChildren.Add(child);
      }

      if (!newChildren.Any()) {
        // stuck in a dead end (no final state and no allowed follow states)
        tree.Done = true;
        break;
      }

      // register the new edges in the reverse direction as well
      foreach (var ch in newChildren) {
        List<Tree> parentsOfChild;
        if (!parents.TryGetValue(ch, out parentsOfChild)) {
          parentsOfChild = new List<Tree>();
          parents.Add(ch, parentsOfChild);
        }
        parentsOfChild.Add(tree);
      }


      // follow one of the children
      tree = SelectFinalOrRandom2(automaton, tree, rand);
      automaton.Goto(tree.state);
    }
  }

  bool success;

  // EVALUATE TREE
  if (automaton.IsFinalState(automaton.CurrentState)) {
    tree.Done = true;
    tree.expr = ExprStr(automaton);
    byte[] code; int nParams;
    automaton.GetCode(out code, out nParams);
    q = eval(code, nParams);
    q = TransformQuality(q);
    success = true;
  } else {
    // we got stuck in roll-out (no evaluation necessary!)
    q = 0.0;
    success = false;
  }

  // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
  // Update statistics
  // Set branch to done if all children are done.
  BackpropagateQuality(tree, q, treePolicy);

  return success;
}


653 


654 


// Maps the quality returned by the evaluation to the value that is backpropagated through the tree.
// Currently the identity.
private static double TransformQuality(double q) {
  // no transformation
  return q;

  // EXPERIMENTAL! (disabled)
  // optimal result: q = 1 -> return huge value
  // if (q >= 1.0) return 1E16;
  // // return number of 9s in R²
  // return -Math.Log10(1 - q);
}


665 


// backpropagate existing statistics to all parents
// Adds stats to the given node and then recursively to every node registered as one of its parents.
// A node that is reachable via several paths receives the statistics once per path.
private static void BackpropagateStatistics(IActionStatistics stats, Tree tree) {
  tree.actionStatistics.Add(stats);
  List<Tree> parentsOfTree;
  if (parents.TryGetValue(tree, out parentsOfTree)) {
    foreach (var parent in parentsOfTree) {
      BackpropagateStatistics(stats, parent);
    }
  }
}


675 


// Hash of the code the automaton has generated so far; used as key for the unification of equivalent states.
private static ulong Hashcode(Automaton automaton) {
  byte[] code;
  int nParams;
  automaton.GetCode(out code, out nParams);
  return ExprHash.GetHash(code, nParams);
}


682 


// Updates the statistics of the given node with q (only when q > 0), marks the node as done when all
// of its children are done, and repeats this recursively for every registered parent.
private static void BackpropagateQuality(Tree tree, double q, IPolicy policy) {
  if (q > 0) policy.Update(tree.actionStatistics, q);

  List<Tree> childrenOfTree;
  if (children.TryGetValue(tree, out childrenOfTree) && childrenOfTree.All(ch => ch.Done)) {
    tree.Done = true;
    // children[tree] = null; keep all nodes
  }

  List<Tree> parentsOfTree;
  if (parents.TryGetValue(tree, out parentsOfTree)) {
    foreach (var parent in parentsOfTree) {
      BackpropagateQuality(parent, q, policy);
    }
  }
}


696 


// Chooses the child to follow directly after an expansion:
// if one of the new children leads to a final state then go there,
// otherwise take the first child.
// rand is currently not used.
private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) {
  // -1 = nothing found; the sentinel must not be a valid index, otherwise a final state
  // at that position could not be told apart from "not found"
  int selectedChildIdx = -1;
  // find first final state if there is one
  var children = MctsSymbolicRegressionStatic.children[tree];
  for (int i = 0; i < children.Count; i++) {
    if (automaton.IsFinalState(children[i].state)) {
      selectedChildIdx = i;
      break;
    }
  }
  // no final state -> select the first child
  if (selectedChildIdx == -1) {
    selectedChildIdx = 0;
  }
  return children[selectedChildIdx];
}


715 


716  // tree search might fail because of constraints for expressions


717  // in this case we get stuck we just restart


718  // see ConstraintHandler.cs for more info


// One recursive step of the tree search starting at 'tree', whose state must equal the
// current state of the automaton and which must not be done yet.
//  - expanded node (children != null): the tree policy selects the child to descend into
//  - unexpanded node in a final state: the expression is evaluated and q is set to the result
//  - unexpanded node otherwise: the node is expanded with one child per follow state
// Returns true if a final state was reached and evaluated (q holds the evaluation result).
// Returns false if the search ended in a dead end (q is 0 in that case).
// Statistics along the path are only updated for successful searches.
private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  out double q) {
  Contract.Assert(tree.state == automaton.CurrentState);
  Contract.Assert(!tree.Done);

  Tree next;
  if (tree.children != null) {
    // SELECT: the policy chooses among the existing children (no choice for a single child)
    int idx = 0;
    if (tree.children.Length > 1) {
      idx = treePolicy.Select(tree.children.Select(child => child.actionStatistics), rand);
    }
    next = tree.children[idx];
  } else if (automaton.IsFinalState(tree.state)) {
    // EVALUATE: a final state cannot be extended any further
    tree.Done = true;

    byte[] exprCode;
    int paramCount;
    automaton.GetCode(out exprCode, out paramCount);
    q = eval(exprCode, paramCount);

    treePolicy.Update(tree.actionStatistics, q);
    return true;
  } else {
    // EXPAND: create one child for each allowed follow state
    int[] followStates;
    int followStateCount;
    automaton.FollowStates(automaton.CurrentState, out followStates, out followStateCount);

    if (followStateCount == 0) {
      // dead end: neither a final state nor any allowed follow state
      q = 0;
      tree.Done = true;
      tree.children = null;
      return false;
    }

    tree.children = new Tree[followStateCount];
    for (int k = 0; k < tree.children.Length; k++) {
      tree.children[k] = new Tree() {
        children = null,
        state = followStates[k],
        actionStatistics = treePolicy.CreateActionStatistics()
      };
    }

    if (followStateCount > 1) {
      next = SelectFinalOrRandom(automaton, tree, rand);
    } else {
      next = tree.children[0];
    }
  }

  // take the selected step and continue the search below the selected child
  automaton.Goto(next.state);
  var success = TryTreeSearchRec(rand, next, automaton, eval, treePolicy, out q);
  if (success) {
    treePolicy.Update(tree.actionStatistics, q);
  }

  // a fully explored sub-branch is cut off
  tree.Done = tree.children.All(child => child.Done);
  if (tree.Done) {
    tree.children = null;
  }
  return success;
}


781 


// Selects the child of 'tree' to descend into: the first child whose state is a final
// state of the automaton, or the first child if none of the children is final.
// NOTE: despite the name, no random choice is made; 'rand' is currently unused.
private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
  var candidates = tree.children;

  // Return directly on the first final state. The previous sentinel value collided with
  // the valid index 1, so a final-state child at index 1 was replaced by child 0.
  for (int i = 0; i < candidates.Length; i++) {
    if (automaton.IsFinalState(candidates[i].state)) {
      return candidates[i];
    }
  }

  // no final state -> select the first child
  return candidates[0];
}


799 


800  // scales data and extracts values from dataset into arrays


// Determines the scaling parameters (if requested) and extracts the values of all allowed
// input variables and of the target variable for the given rows.
//  scaleVariables == true:  scalingFactor / scalingOffset have one entry per input variable
//                           followed by one entry for the target variable;
//                           inputs are mapped to (x - min) / range, the target is shifted to zero mean
//  scaleVariables == false: scalingFactor and scalingOffset are null and the values are unchanged
// Scaled values are always computed as value * scalingFactor + scalingOffset.
private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
  out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
  if (scaleVariables) {
    var nVars = problemData.AllowedInputVariables.Count();
    scalingFactor = new double[nVars + 1];
    scalingOffset = new double[nVars + 1];

    var i = 0;
    foreach (var variableName in problemData.AllowedInputVariables) {
      // read the column once instead of enumerating the dataset separately for min and max
      var values = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
      var minX = values.Min();
      var maxX = values.Max();
      var range = maxX - minX;

      // scaledX = (x - min) / range = x * (1 / range) + (-min / range)
      // NOTE(review): a constant column (range == 0) leads to a division by zero and thus
      // non-finite scaling values; confirm that such variables are filtered out upstream.
      scalingFactor[i] = 1.0 / range;
      scalingOffset[i] = -minX / range;
      i++;
    }

    // transform target variable to zero-mean: scaledY = y * 1.0 + (-mean)
    scalingFactor[i] = 1.0;
    scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
  } else {
    scalingFactor = null;
    scalingOffset = null;
  }

  GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
}


836 


837  // extract values from dataset into arrays


// Copies the values of all allowed input variables (xs, one array per variable) and of the
// target variable (y) for the given rows into arrays.
// Each value v is stored as v * factor + offset; without scaling parameters
// (scalingFactor == null) factor 1.0 and offset 0.0 are used.
// The entry following the last input variable holds the parameters for the target variable.
private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
  out double[][] xs, out double[] y) {
  var unscaled = scalingFactor == null;
  xs = new double[problemData.AllowedInputVariables.Count()][];

  var col = 0;
  foreach (var variableName in problemData.AllowedInputVariables) {
    double factor;
    double shift;
    if (unscaled) {
      factor = 1.0;
      shift = 0.0;
    } else {
      factor = scalingFactor[col];
      shift = scalingOffset[col];
    }
    xs[col] = problemData.Dataset.GetDoubleValues(variableName, rows).Select(v => v * factor + shift).ToArray();
    col++;
  }

  // col now points to the entry of the target variable
  var targetFactor = unscaled ? 1.0 : scalingFactor[col];
  var targetShift = unscaled ? 0.0 : scalingOffset[col];
  y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(v => v * targetFactor + targetShift).ToArray();
}


856 


857  // for debugging only


858 


859 


// Debugging helper: returns the disassembled code currently held by the automaton.
private static string ExprStr(Automaton automaton) {
  byte[] exprCode;
  int paramCount; // required by GetCode but not needed for the string representation
  automaton.GetCode(out exprCode, out paramCount);
  var text = Disassembler.CodeToString(exprCode);
  return text;
}


866 


// Debugging helper: writes one line "<tries> <average quality>" for the given node,
// followed by one such line for each of its registered children (if any).
private static string WriteStatistics(Tree tree) {
  var writer = new System.IO.StringWriter();
  writer.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);

  if (!children.ContainsKey(tree)) return writer.ToString();

  foreach (var child in children[tree])
    writer.WriteLine("{0} {1:N5}", child.actionStatistics.Tries, child.actionStatistics.AverageQuality);

  return writer.ToString();
}


877 


// Debugging helper: returns a Graphviz (DOT) digraph that traces the search tree starting
// at the given node; the node statements are produced by TraceTreeRec.
private static string TraceTree(Tree tree) {
  var sb = new StringBuilder();

  // DOT preamble, written line by line so that no stray text can end up in the graph
  sb.AppendLine("digraph {");
  sb.AppendLine("  ratio = fill;");
  sb.AppendLine("  node [style=filled];");

  int nodeId = 0; // id 0 is used for the root, further ids are handed out by TraceTreeRec
  TraceTreeRec(tree, 0, sb, ref nodeId);

  sb.Append("}");
  return sb.ToString();
}


891 


// Debugging helper: appends DOT statements for the given node (with id parentId) and for
// all of its registered children, then continues only with the most frequently tried child.
//  nextId is the last id that has been handed out; it is incremented for every child.
// Node labels show "<average quality> <tries>"; the node color encodes the average quality.
private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId) {
  // DOT requires '.' as decimal separator, so formatting must not depend on the current culture
  var invariant = System.Globalization.CultureInfo.InvariantCulture;

  var avgNodeQ = tree.actionStatistics.AverageQuality;
  var tries = tree.actionStatistics.Tries;
  if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
  var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue

  sb.AppendFormat(invariant, "{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();

  if (!children.ContainsKey(tree)) return;

  // the most frequently tried child; the first one wins on ties
  Tree mostTriedChild = null;
  int mostTriedChildId = 0;
  int mostTries = 0;

  foreach (var ch in children[tree]) {
    nextId++;
    avgNodeQ = ch.actionStatistics.AverageQuality;
    tries = ch.actionStatistics.Tries;
    if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue

    sb.AppendFormat(invariant, "{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
    // "->" is the edge operator of directed graphs in DOT
    sb.AppendFormat(invariant, "{0} -> {1}", parentId, nextId).AppendLine();

    if (mostTriedChild == null || tries > mostTries) {
      mostTriedChild = ch;
      mostTriedChildId = nextId;
      mostTries = tries;
    }
  }

  // only the most frequently tried child is traced further
  if (mostTriedChild != null) {
    TraceTreeRec(mostTriedChild, mostTriedChildId, sb, ref nextId);
  }
}


917 


// Debugging helper: returns a Graphviz (DOT) digraph of all parent/child relations stored
// in the static 'children' map.
//  - node labels show "<average quality> <tries>", the node color encodes the average quality
//  - edges are labeled with the expression of the child node
//  - if more than 500 nodes are registered only nodes with more than 10 tries are written
//  - children that have never been tried are skipped
// NOTE: the 'tree' argument is not used; the whole map is written.
private static string WriteTree(Tree tree) {
  // invariant culture: DOT requires '.' as decimal separator
  var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
  var nodeIds = new Dictionary<Tree, int>();

  // DOT preamble, written line by line so that no stray text can end up in the graph
  sb.WriteLine("digraph {");
  sb.WriteLine("  ratio = fill;");
  sb.WriteLine("  node [style=filled];");

  int threshold = nodes.Count > 500 ? 10 : 0;
  foreach (var kvp in children) {
    var parent = kvp.Key;
    int parentId;
    if (!nodeIds.TryGetValue(parent, out parentId)) {
      // first occurrence of this node: assign the next id and write its node statement
      parentId = nodeIds.Count + 1;
      var parentQ = parent.actionStatistics.AverageQuality;
      var parentTries = parent.actionStatistics.Tries;
      if (double.IsNaN(parentQ)) parentQ = 0.0;
      var parentHue = (1 - parentQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
      if (parentTries > threshold)
        sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, parentQ, parentTries, parentHue);
      nodeIds.Add(parent, parentId);
    }

    foreach (var child in kvp.Value) {
      int childId;
      if (!nodeIds.TryGetValue(child, out childId)) {
        childId = nodeIds.Count + 1;
        nodeIds.Add(child, childId);
      }
      var childQ = child.actionStatistics.AverageQuality;
      var childTries = child.actionStatistics.Tries;
      if (childTries < 1) continue;
      if (double.IsNaN(childQ)) childQ = 0.0;
      var childHue = (1 - childQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
      if (childTries > threshold) {
        sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, childQ, childTries, childHue);
        var edgeLabel = child.expr;
        // "->" is the edge operator of directed graphs in DOT
        sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
      }
    }
  }

  sb.Write("}");
  return sb.ToString();
}


963  }


964  }

