#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
  public static class MctsSymbolicRegressionStatic {
    // OBJECTIVES:
    // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
    //    e.g. Keijzer, Nguyen ... where no numeric constants are involved
    //    assumptions:
    //    we don't know the necessary operations or functions -> all available functions could be necessary
    //    but we do not need to tune numeric constants -> no scaling of input variables x!
    // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
    //    This is important for real-world applications.
    //    e.g. Korns or Vladislavleva problems where numeric constants are involved
    //    assumptions:
    //    any numeric constant is possible (a priori we might assume that small abs. constants are more likely)
    //    standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
    //    to simplify the problem we can restrict the set of functions, e.g. we assume which functions are necessary for the problem instance
    //    -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
    // 3) efficiency and effectiveness for real-world problems
    //    e.g. Tower problem
    //    (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
    //

    // TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general!
    //       -> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show
    //       -> Therefore, looking at rollout statistics for arm selection is useless in the general case!
    //       -> It is necessary to rely on other features for the arm selection.
    //       -> TODO: Which heuristics can we apply?
    // TODO: solve Poly-10
    // TODO: rename everything as this is not MCTS anymore
    // TODO: when a path to an expression is explored first (e.g. x1 + x2)
    //       and later we find a longer form x1 + x1 + x2 where the number of variable references
    //       exceeds the maximum in the automaton this leads to an error (see unit tests)
    // ~~obsolete TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
    // ~~obsolete TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
    // ~~obsolete TODO: check if we can use a quality measure with range [-1..1] in policies
    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
    // TODO: check if transformation of y is correct and works (Obj 2)
    // TODO: The algorithm is not invariant to location and scale of variables.
    //       Include offset for variables as parameter (for Objective 2)
    // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
    // TODO: support e(-x) and possibly (1/-x) (Obj 1)
    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
    // TODO: improve memory usage
    // TODO: analyze / improve perf of ExprHashing (canonical form for expressions)
    // TODO: support empty test partition
    // TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax )) for invertible transformations A -> unit tests
    #region static API

    public interface IState {
      bool Done { get; }
      ISymbolicRegressionModel BestModel { get; }
      double BestSolutionTrainingQuality { get; }
      double BestSolutionTestQuality { get; }
      IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
      int TotalRollouts { get; }
      int EffectiveRollouts { get; }
      int FuncEvaluations { get; }
      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
      // TODO: other stats on the LM optimizer might be interesting here
    }



    // created through factory method
    private class State : IState {
      private const int MaxParams = 100;

      // state variables used by MCTS
      internal readonly Automaton automaton;
      internal IRandom random { get; private set; }
      internal readonly Tree tree;
      internal readonly Func<byte[], int, double> evalFun;
      // MCTS might get stuck. Track statistics on the number of effective rollouts
      internal int totalRollouts;
      internal int effectiveRollouts;


      // state variables used only internally (for eval function)
      private readonly IRegressionProblemData problemData;
      private readonly double[][] x;
      private readonly double[] y;
      private readonly double[][] testX;
      private readonly double[] testY;
      private readonly double[] scalingFactor;
      private readonly double[] scalingOffset;
      private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
      private readonly int constOptIterations;
      private readonly double lambda; // weight of penalty term for regularization
      private readonly double lowerEstimationLimit, upperEstimationLimit;
      private readonly bool collectParetoOptimalModels;
      private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
      private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models

      private readonly ExpressionEvaluator evaluator, testEvaluator;

      internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
      internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
      internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();

      // values for best solution
      private double bestR;
      private byte[] bestCode;
      private int bestNParams;
      private double[] bestConsts;

      // stats
      private int funcEvaluations;
      private int gradEvaluations;

      // buffers
      private readonly double[] ones; // vector of ones (as default params)
      private readonly double[] constsBuf;
      private readonly double[] predBuf, testPredBuf;
      private readonly double[][] gradBuf;

      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
        int constOptIterations, double lambda,
        bool collectParetoOptimalModels = false,
        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
        bool allowProdOfVars = true,
        bool allowExp = true,
        bool allowLog = true,
        bool allowInv = true,
        bool allowMultipleTerms = false) {

        if (lambda < 0) throw new ArgumentException("Lambda must be larger than or equal to zero", "lambda");

        this.problemData = problemData;
        this.constOptIterations = constOptIterations;
        this.lambda = lambda;
        this.evalFun = this.Eval;
        this.lowerEstimationLimit = lowerEstimationLimit;
        this.upperEstimationLimit = upperEstimationLimit;
        this.collectParetoOptimalModels = collectParetoOptimalModels;

        random = new MersenneTwister(randSeed);

        // prepare data for evaluation
        double[][] x;
        double[] y;
        double[][] testX;
        double[] testY;
        double[] scalingFactor;
        double[] scalingOffset;
        // get training and test datasets (scale linearly based on training set if required)
        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
        this.x = x;
        this.y = y;
        this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
        this.testX = testX;
        this.testY = testY;
        this.scalingFactor = scalingFactor;
        this.scalingOffset = scalingOffset;
        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
        // we need a separate evaluator because the vector length for the test dataset might differ
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables);
        this.tree = new Tree() {
          state = automaton.CurrentState,
          expr = "",
          level = 0
        };

        // reset best solution
        this.bestR = 0;
        // code for default solution (constant model)
        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
        this.bestNParams = 0;
        this.bestConsts = null;

        // init buffers
        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
        constsBuf = new double[MaxParams];
        this.predBuf = new double[y.Length];
        this.testPredBuf = new double[testY.Length];

        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
      }



      #region IState interface
      public bool Done { get { return tree != null && tree.Done; } }

      public double BestSolutionTrainingQuality {
        get {
          evaluator.Exec(bestCode, x, bestConsts, predBuf);
          return Rho(y, predBuf);
        }
      }

      public double BestSolutionTestQuality {
        get {
          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
          return Rho(testY, testPredBuf);
        }
      }

      // takes the code of the best solution and creates an equivalent symbolic regression model
      public ISymbolicRegressionModel BestModel {
        get {
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
          model.Scale(problemData); // apply linear scaling
          return model;
        }
      }
      public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
        get { return paretoBestModels; }
      }

      public int TotalRollouts { get { return totalRollouts; } }
      public int EffectiveRollouts { get { return effectiveRollouts; } }
      public int FuncEvaluations { get { return funcEvaluations; } }
      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations

      #endregion


      // note: must not be DEBUG-only, it is also called unconditionally from TryTreeSearchRec2
      public string ExprStr(Automaton automaton) {
        byte[] code;
        int nParams;
        automaton.GetCode(out code, out nParams);
        var generator = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
        var @params = Enumerable.Repeat(1.0, nParams).ToArray();
        var root = generator.Exec(code, @params, nParams, null, null);
        var formatter = new InfixExpressionFormatter();
        return formatter.Format(new SymbolicExpressionTree(root));
      }

      private double Eval(byte[] code, int nParams) {
        double[] optConsts;
        double q;
        Eval(code, nParams, out q, out optConsts);

        // single objective best
        if (q > bestR) {
          bestR = q;
          bestNParams = nParams;
          this.bestCode = new byte[code.Length];
          this.bestConsts = new double[bestNParams];

          Array.Copy(code, bestCode, code.Length);
          Array.Copy(optConsts, bestConsts, bestNParams);
        }
        if (collectParetoOptimalModels) {
          // multi-objective best
          var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
            Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit); // use length of expression as surrogate for complexity
          UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
        }
        return q;
      }

      private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
        // we make a first pass to determine a valid starting configuration for all constants
        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
        // scale and offset are set to optimal starting configuration
        // assumes scale is the first param and offset is the last param

        // reset constants
        Array.Copy(ones, constsBuf, nParams);
        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
        funcEvaluations++;

        if (nParams == 0 || constOptIterations < 0) {
          // if we don't need to optimize parameters then we are done
          // changing scale and offset does not influence r²
          rho = Rho(y, predBuf);
          optConsts = constsBuf;
        } else {
          // optimize constants using the starting point calculated above
          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

          evaluator.Exec(code, x, constsBuf, predBuf);
          funcEvaluations++;

          rho = Rho(y, predBuf);
          optConsts = constsBuf;
        }
      }





      #region helpers
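      // Note: the quality measure used throughout is Pearson's correlation coefficient R between
      // predictions and target values; calculator errors (e.g. for constant predictions) map to 0.0.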


      private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
        return error == OnlineCalculatorError.None ? r : 0.0;
      }


      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO: perf?)
        Array.Copy(consts, optConsts, nParams);

        // direct usage of LM is recommended in the alglib manual for better performance than the lsfit interface (which uses LM internally)
        alglib.minlmstate state;
        alglib.minlmreport rep = null;
        alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
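        // residual layout: fi[0..n-1] hold (prediction - target), the extra entry fi[n] holds the
        // regularization penalty term (see Func and FuncAndJacobian below)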


        // Using the change of the gradient as stopping criterion is recommended in the alglib manual.
        // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion.
        alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
        // alglib.minlmsetgradientcheck(state, 1E-5);
        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
        alglib.minlmresults(state, out optConsts, out rep);
        funcEvaluations += rep.nfunc;
        gradEvaluations += rep.njac * nParams;

        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

        // only use optimized constants if successful
        if (rep.terminationtype >= 0) {
          Array.Copy(optConsts, consts, optConsts.Length);
        }
      }



      private void Func(double[] arg, double[] fi, object obj) {
        var code = (byte[])obj;
        int n = predBuf.Length;
        evaluator.Exec(code, x, arg, predBuf);
        for (int r = 0; r < n; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;
        }

        var penaltyIdx = fi.Length - 1;
        fi[penaltyIdx] = 0.0;
        // calc length of parameter vector for regularization
        var aa = 0.0;
        for (int i = 0; i < arg.Length; i++) {
          aa += arg[i] * arg[i];
        }
        if (lambda > 0 && aa > 0) {
          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
          // scale lambda using n to make the parameter independent of the number of training points
          // take the root because LM squares the result
          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
        }
      }



      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
        int n = predBuf.Length;
        int nParams = arg.Length;
        var code = (byte[])obj;
        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
        for (int r = 0; r < n; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;

          for (int k = 0; k < nParams; k++) {
            jac[r, k] = gradBuf[k][r];
          }
        }
        // calc length of parameter vector for regularization
        double aa = 0.0;
        for (int i = 0; i < arg.Length; i++) {
          aa += arg[i] * arg[i];
        }

        var penaltyIdx = fi.Length - 1;
        if (lambda > 0 && aa > 0) {
          fi[penaltyIdx] = 0.0;
          // scale lambda using stdDev(y) to make the parameter independent of the scale of y
          // scale lambda using n to make the parameter independent of the number of training points
          // take the root because alglib LM squares the result
          fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
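
          // gradient of the penalty term: with p = sqrt(c * aa), c = n * lambda / yStdDev and
          // aa = sum(arg_i^2), we get dp/darg_i = 0.5 / p * c * 2 * arg_i, which is computed below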


          for (int i = 0; i < arg.Length; i++) {
            jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
          }
        } else {
          fi[penaltyIdx] = 0.0;
          for (int i = 0; i < arg.Length; i++) {
            jac[penaltyIdx, i] = 0.0;
          }
        }
      }




      private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
        double[] scalingFactor, double[] scalingOffset) {
        double[] best = new double[2];
        double[] cur = new double[2] { q, complexity };
        bool[] max = new[] { true, false };
        var isNonDominated = true;
        foreach (var e in paretoFront) {
          var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
          if (domRes == DominationResult.IsDominated) {
            isNonDominated = false;
            break;
          }
        }
        if (isNonDominated) {
          paretoFront.Add(cur);

          // create model
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

          var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
          model.Scale(problemData); // apply linear scaling

          var sol = model.CreateRegressionSolution(this.problemData);
          sol.Name = string.Format("{0:N5} {1}", q, complexity);

          paretoBestModels.Add(sol);
        }
        for (int i = paretoFront.Count - 2; i >= 0; i--) {
          var @ref = paretoFront[i];
          var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
          if (domRes == DominationResult.Dominates) {
            paretoFront.RemoveAt(i);
            paretoBestModels.RemoveAt(i);
          }
        }
      }

      #endregion


    }




    /// <summary>
    /// Static method to initialize a state for the algorithm.
    /// </summary>
    /// <param name="problemData">The problem data.</param>
    /// <param name="randSeed">Random seed.</param>
    /// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
    /// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended).</param>
    /// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt).</param>
    /// <param name="lambda">Penalty factor for regularization (0..inf.); a lambda of zero disables regularization.</param>
    /// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
    /// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
    /// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
    /// <param name="allowProdOfVars">Allow products of expressions.</param>
    /// <param name="allowExp">Allow expressions with exponentials.</param>
    /// <param name="allowLog">Allow expressions with logarithms.</param>
    /// <param name="allowInv">Allow expressions with 1/x.</param>
    /// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
    /// <returns>An initialized state for the tree search.</returns>


    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
      bool scaleVariables = true, int constOptIterations = 1, double lambda = 0.0,
      bool collectParameterOptimalModels = false,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
        collectParameterOptimalModels,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }

    // returns the quality of the evaluated solution
    public static double MakeStep(IState state) {
      var mctsState = state as State;
      if (mctsState == null) throw new ArgumentException("state");
      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

      return TreeSearch(mctsState);
    }
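
    // Usage sketch (hypothetical caller code; the seed value is arbitrary):
    //   var state = MctsSymbolicRegressionStatic.CreateState(problemData, randSeed: 1234);
    //   while (!state.Done) MctsSymbolicRegressionStatic.MakeStep(state);
    //   var model = state.BestModel;
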


    #endregion

    private static double TreeSearch(State mctsState) {
      var automaton = mctsState.automaton;
      var tree = mctsState.tree;
      var eval = mctsState.evalFun;
      var rand = mctsState.random;
      double q = 0;
      bool success = false;
      do {
        automaton.Reset();
        success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q);
        mctsState.totalRollouts++;
      } while (!success && !tree.Done);
      if (success) {
        mctsState.effectiveRollouts++;

#if DEBUG
        Console.WriteLine(mctsState.ExprStr(automaton));
#endif

        return q;
      } else return 0.0;
    }

    // search forward
    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton,
      Func<byte[], int, double> eval,
      State state,
      out double q) {


      // ROLLOUT AND EXPANSION
      // We are navigating a graph (states might be reached via different paths) instead of a tree.
      // State equivalence is checked through ExprHash (based on the code generated along the path).

      // We switch between rollout mode and expansion mode.
      // Rollout mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB).
      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression).
      // In expansion mode we might re-enter the graph and switch back to rollout mode.
      // We do this until we reach a complete expression (final state).

      // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent.
      // Subgraphs which have been completely searched are marked as done.
      // A rollout could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

      while (!automaton.IsFinalState(automaton.CurrentState)) {
        // Console.WriteLine(automaton.stateNames[automaton.CurrentState]);
        if (state.children.ContainsKey(tree)) {
          if (state.children[tree].All(ch => ch.Done)) {
            tree.Done = true;
            break;
          }
          // ROLLOUT INSIDE TREE
          // UCT selection within tree
          int selectedIdx = 0;
          if (state.children[tree].Count > 1) {
            selectedIdx = SelectInternal(state.children[tree], rand);
          }

          tree = state.children[tree][selectedIdx];

          // all steps where no alternatives could be taken immediately (without expanding the tree)
          // TODO: simplification of the automaton
          int[] possibleFollowStates = new int[1000];
          int nFs;
          automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);
          Debug.Assert(possibleFollowStates.Contains(tree.state));
          automaton.Goto(tree.state);
        } else {
          // EXPAND
          int[] possibleFollowStates = new int[1000];
          int nFs;
          string actionString = "";
          automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);

          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }
          var newChildren = new List<Tree>(nFs);
          state.children.Add(tree, newChildren);
          for (int i = 0; i < nFs; i++) {
            Tree child = null;
            // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
            if (automaton.IsEvalState(possibleFollowStates[i])) {
              var hc = Hashcode(automaton);
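              // combine the expression hash with the current node's state id (djb2-style shift-add-xor)
              // so that equal expressions reached in different automaton states are not unified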


              hc = ((hc << 5) + hc) ^ (ulong)tree.state; // TODO: fix unit test for structure enumeration
              if (!state.nodes.TryGetValue(hc, out child)) {
                // Console.WriteLine("New expression (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
                child = new Tree() {
                  state = possibleFollowStates[i],
                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                  level = tree.level + 1
                };
                state.nodes.Add(hc, child);
              }
              // only allow forward edges (don't add the child if we would go back in the graph)
              else if (child.level > tree.level) {
                // Console.WriteLine("Existing expression (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
                // whenever we join paths we need to propagate the statistics of the existing node back
                // through the newly created link to all parents
                BackpropagateStatistics(tree, state, child.visits);
              } else {
                // Console.WriteLine("Cycle (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
                // prevent cycles
                Debug.Assert(child.level <= tree.level);
                child = null;
              }
            } else {
              child = new Tree() {
                state = possibleFollowStates[i],
                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
                level = tree.level + 1
              };
            }
            if (child != null)
              newChildren.Add(child);
          }

          if (!newChildren.Any()) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }

          foreach (var ch in newChildren) {
            if (!state.parents.ContainsKey(ch)) {
              state.parents.Add(ch, new List<Tree>());
            }
            state.parents[ch].Add(tree);
          }


          // follow one of the children
          tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
          automaton.Goto(tree.state);
        }
      }

      bool success;

      // EVALUATE TREE
      if (!tree.Done && automaton.IsFinalState(automaton.CurrentState)) {
        tree.Done = true;
        tree.expr = state.ExprStr(automaton);
        byte[] code; int nParams;
        automaton.GetCode(out code, out nParams);
        q = eval(code, nParams);
        success = true;
        BackpropagateQuality(tree, q, state);
      } else {
        // we got stuck in the rollout (no evaluation necessary!)
        q = 0.0;
        success = false;
      }

      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
      // Update statistics.
      // Set branch to done if all children are done.
      BackpropagateDone(tree, state);
      BackpropagateDebugStats(tree, q, state);


      return success;
    }



    private static int SelectInternal(List<Tree> list, IRandom rand) {
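      // Selection policy: prefer a child that has never been visited; otherwise pick uniformly
      // at random among the children that are not done (no UCB-style exploitation term is used).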


      Debug.Assert(list.Any(t => !t.Done));

      // check if there is any node which has not been visited
      for (int i = 0; i < list.Count; i++) {
        if (!list[i].Done && list[i].visits == 0) return i;
      }

      // choose a random node
      var idx = rand.Next(list.Count);
      while (list[idx].Done) { idx = rand.Next(list.Count); }
      return idx;
    }



    // backpropagate existing statistics to all parents
    private static void BackpropagateStatistics(Tree tree, State state, int numVisits) {
      tree.visits += numVisits;

      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateStatistics(parent, state, numVisits);
        }
      }
    }

    private static ulong Hashcode(Automaton automaton) {
      byte[] code;
      int nParams;
      automaton.GetCode(out code, out nParams);
      return (ulong)ExprHashSymbolic.GetHash(code, nParams);
    }

    private static void BackpropagateQuality(Tree tree, double q, State state) {
      tree.visits++;
      // TODO: q is ignored for now

      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateQuality(parent, q, state);
        }
      }
    }

    private static void BackpropagateDone(Tree tree, State state) {
      if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) {
        tree.Done = true;
        // children[tree] = null; keep all nodes
      }

      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateDone(parent, state);
        }
      }
    }

    private static void BackpropagateDebugStats(Tree tree, double q, State state) {
      if (state.parents.ContainsKey(tree)) {
        foreach (var parent in state.parents[tree]) {
          BackpropagateDebugStats(parent, q, state);
        }
      }
    }



    private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
      // find the child with the smallest state value (smaller values are closer to the final state)
      int selectedChildIdx = 0;
      var children = state.children[tree];
      Tree minChild = children.First();
      for (int i = 1; i < children.Count; i++) {
        if (children[i].state < minChild.state) {
          selectedChildIdx = i;
          minChild = children[i]; // keep the running minimum up to date
        }
      }
      return children[selectedChildIdx];
    }



    // scales data and extracts values from dataset into arrays
    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      var i = 0;
      if (scaleVariables) {
        scalingFactor = new double[xs.Length + 1];
        scalingOffset = new double[xs.Length + 1];
      } else {
        scalingFactor = null;
        scalingOffset = null;
      }
      foreach (var var in problemData.AllowedInputVariables) {
        if (scaleVariables) {
          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
          var range = maxX - minX;

          // scaledX = (x - min) / range
          var sf = 1.0 / range;
          var offset = -minX / range;
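          // i.e. each variable is later evaluated as x * sf + offset, mapping the training range [min..max] to [0..1]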


          scalingFactor[i] = sf;
          scalingOffset[i] = offset;
          i++;
        }
      }

      if (scaleVariables) {
        // transform target variable to zero mean
        scalingFactor[i] = 1.0;
        scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
      }

      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
    }



    // extract values from dataset into arrays
    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
      out double[][] xs, out double[] y) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      int i = 0;
      foreach (var var in problemData.AllowedInputVariables) {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
        xs[i++] =
          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
      }

      {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * sf + offset).ToArray();
      }
    }



    // for debugging only


    private static string TraceTree(Tree tree, State state) {
      var sb = new StringBuilder();
      sb.Append(
@"digraph {
  ratio = fill;
  node [style=filled];
");
      int nodeId = 0;

      TraceTreeRec(tree, 0, sb, ref nodeId, state);
      sb.Append("}");
      return sb.ToString();
    }

    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
      var tries = tree.visits;

      sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tries).AppendLine();

      var list = new List<Tuple<int, int, Tree>>();
      if (state.children.ContainsKey(tree)) {
        foreach (var ch in state.children[tree]) {
          nextId++;
          tries = ch.visits;
          sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
          sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine();
          list.Add(Tuple.Create(tries, nextId, ch));
        }

        foreach (var tup in list) {
          var ch = tup.Item3;
          var chId = tup.Item2;
          if (state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
            var chch = state.children[ch].First();
            nextId++;
            tries = chch.visits;
            sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
            sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", chId, nextId, chch.expr).AppendLine();
          }
        }

        foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
          TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state);
        }
      }
    }

    private static string WriteTree(Tree tree, State state) {
      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
      var nodeIds = new Dictionary<Tree, int>();
      sb.Write(
@"digraph {
  ratio = fill;
  node [style=filled];
");
      int threshold = /* state.nodes.Count > 500 ? 10 : */ 0;
      foreach (var kvp in state.children) {
        var parent = kvp.Key;
        int parentId;
        if (!nodeIds.TryGetValue(parent, out parentId)) {
          parentId = nodeIds.Count + 1;
          var tries = parent.visits;
          if (tries > threshold)
            sb.Write("{0} [label=\"{1}\"]; ", parentId, tries);
          nodeIds.Add(parent, parentId);
        }
        foreach (var child in kvp.Value) {
          int childId;
          if (!nodeIds.TryGetValue(child, out childId)) {
            childId = nodeIds.Count + 1;
            nodeIds.Add(child, childId);
          }
          var tries = child.visits;
          if (tries < 1) continue;
          if (tries > threshold) {
            sb.Write("{0} [label=\"{1}\"]; ", childId, tries);
            var edgeLabel = child.expr;
            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
            sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
          }
        }
      }

      sb.Write("}");
      return sb.ToString();
    }
  }
}