21 


22  using System;


23  using System.Collections.Generic;


24  using System.Diagnostics;


25  using System.Diagnostics.Contracts;


26  using System.Linq;


27  using System.Text;


28  using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;


29  using HeuristicLab.Core;


30  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


31  using HeuristicLab.Optimization;


32  using HeuristicLab.Problems.DataAnalysis;


33  using HeuristicLab.Problems.DataAnalysis.Symbolic;


34  using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;


35  using HeuristicLab.Random;


36 


37  namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {


38  public static class MctsSymbolicRegressionStatic {


// OBJECTIVES:
// 1) Solve toy problems without numeric constants (to show that structure search is effective / efficient)
//    - e.g. Keijzer, Nguyen ... where no numeric constants are involved
//    - assumptions:
//      - we don't know the necessary operations or functions -> all available functions could be necessary
//      - but we do not need to tune numeric constants -> no scaling of input variables x!
// 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
//    This is important for real-world applications.
//    - e.g. Korns or Vladislavleva problems where numeric constants are involved
//    - assumptions:
//      - any numeric constant is possible (a priori we might assume that small abs. constants are more likely)
//      - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
//      - to simplify the problem we can restrict the set of functions, e.g. we assume which functions are necessary for the problem instance
//        -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
// 3) Efficiency and effectiveness for real-world problems
//    - e.g. Tower problem
//    - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
//


57 


// TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general!
//       -> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show
//       -> Therefore, looking at rollout statistics for arm selection is useless in the general case!
//       -> It is necessary to rely on other features for the arm selection.
//       -> TODO: Which heuristics can we apply?
// TODO: Solve Poly-10
// TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
// TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
// TODO: check if we can use a quality measure with range [-1..1] in policies
// TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
// TODO: check if transformation of y is correct and works (Obj 2)
// TODO: The algorithm is not invariant to location and scale of variables.
//       Include offset for variables as parameter (for Objective 2)
// TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
// TODO: support exp(-x) and possibly (1/-x) (Obj 1)
// TODO: is it OK to initialize all constants to 1 (Obj 2)?
// TODO: improve memory usage
// TODO: support empty test partition
// TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax ) ) for invertible transformations A -> unit tests


77  #region static API


78 


/// <summary>
/// Read-only view of the search state that is returned by CreateState and passed to MakeStep.
/// </summary>
public interface IState {
  // true once the search tree has been marked as done
  bool Done { get; }
  // model built from the best solution found so far
  ISymbolicRegressionModel BestModel { get; }
  // quality of the best solution on the training data
  double BestSolutionTrainingQuality { get; }
  // quality of the best solution on the test data
  double BestSolutionTestQuality { get; }
  // solutions collected for the Pareto front (quality vs. expression length)
  IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
  // number of rollouts including the ones that got stuck
  int TotalRollouts { get; }
  // number of calls of the tree search that have completed
  int EffectiveRollouts { get; }
  // number of function evaluations
  int FuncEvaluations { get; }
  int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
  // TODO other stats on LM optimizer might be interesting here
}


91 


92  // created through factory method


93  private class State : IState {


// size of the parameter buffers (ones, constsBuf, gradBuf)
private const int MaxParams = 100;

// state variables used by MCTS
internal readonly Automaton automaton;
internal IRandom random { get; private set; }
// root node of the search tree
internal readonly Tree tree;
// evaluation function handed to the tree search (bound to Eval in the constructor)
internal readonly Func<byte[], int, double> evalFun;
internal readonly IPolicy treePolicy;
// MCTS might get stuck. Track statistics on the number of effective rollouts
internal int totalRollouts;
internal int effectiveRollouts;


// state variables used only internally (for eval function)
private readonly IRegressionProblemData problemData;
// training inputs and target
private readonly double[][] x;
private readonly double[] y;
// test inputs and target
private readonly double[][] testX;
private readonly double[] testY;
// scaling of the input variables as produced by GenerateData for the training partition
private readonly double[] scalingFactor;
private readonly double[] scalingOffset;
private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
private readonly int constOptIterations;
private readonly double lambda; // weight of penalty term for regularization
private readonly double lowerEstimationLimit, upperEstimationLimit;
private readonly bool collectParetoOptimalModels;
private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models

// separate evaluators for the training and the test partition
private readonly ExpressionEvaluator evaluator, testEvaluator;

// edges of the search graph: child nodes per node and parent nodes per node
internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
// nodes of eval states by hash code (used for state unification in the tree search)
internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();

// values for best solution
private double bestR;
private byte[] bestCode;
private int bestNParams;
private double[] bestConsts;

// stats
private int funcEvaluations;
private int gradEvaluations;

// buffers
private readonly double[] ones; // vector of ones (as default params)
private readonly double[] constsBuf;
private readonly double[] predBuf, testPredBuf;
private readonly double[][] gradBuf; // MaxParams rows, one entry per training row each

// debugging stats
// calculate for each level the number of alternatives the average 'inequality' of tries and 'inequality' of quality over the alternatives for each trie
// inequality can be calculated using the Gini coefficient
// NOTE(review): all arrays below have a fixed length of 100 and are indexed by tree level
internal readonly double[] pathGiniCoeffs = new double[100];
internal readonly double[] pathQs = new double[100];
internal readonly double[] levelBestQ = new double[100];
// internal readonly double[] levelMaxTries = new double[100];
internal readonly double[] pathBestQ = new double[100]; // as long as pathBestQs = levelBestQs we are following the correct path
internal readonly string[] levelBestAction = new string[100];
internal readonly string[] curAction = new string[100];
internal readonly double[] pathSelectedQ = new double[100];


156 


// Creates the search state: prepares training and test data (optionally scaled based on the
// training partition), the evaluators, the automaton, the root node of the search tree and
// all working buffers.
// Throws ArgumentNullException when problemData is null and ArgumentException when lambda is negative.
// When treePolicy is null a Ucb policy is used.
public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
  int constOptIterations, double lambda,
  IPolicy treePolicy = null,
  bool collectParetoOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false) {

  if (problemData == null) throw new ArgumentNullException("problemData");
  if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");

  this.problemData = problemData;
  this.constOptIterations = constOptIterations;
  this.lambda = lambda;
  this.evalFun = this.Eval;
  this.lowerEstimationLimit = lowerEstimationLimit;
  this.upperEstimationLimit = upperEstimationLimit;
  this.collectParetoOptimalModels = collectParetoOptimalModels;

  random = new MersenneTwister(randSeed);

  // prepare data for evaluation
  double[][] trainX;
  double[] trainY;
  double[][] tstX;
  double[] tstY;
  double[] factors;
  double[] offsets;
  // get training and test datasets (scale linearly based on training set if required)
  GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out trainX, out trainY, out factors, out offsets);
  GenerateData(problemData, problemData.TestIndices, factors, offsets, out tstX, out tstY);
  this.x = trainX;
  this.y = trainY;
  this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(trainY);
  this.testX = tstX;
  this.testY = tstY;
  this.scalingFactor = factors;
  this.scalingOffset = offsets;
  this.evaluator = new ExpressionEvaluator(trainY.Length, lowerEstimationLimit, upperEstimationLimit);
  // we need a separate evaluator because the vector length for the test dataset might differ
  this.testEvaluator = new ExpressionEvaluator(tstY.Length, lowerEstimationLimit, upperEstimationLimit);

  this.automaton = new Automaton(trainX, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
  this.treePolicy = treePolicy ?? new Ucb();
  this.tree = new Tree() {
    state = automaton.CurrentState,
    // use the field (never null) instead of the parameter, which is null by default
    actionStatistics = this.treePolicy.CreateActionStatistics(),
    expr = "",
    level = 0
  };

  // reset best solution
  this.bestR = 0;
  // code for default solution (constant model)
  this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
  this.bestNParams = 0;
  this.bestConsts = null;

  // init buffers
  this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
  constsBuf = new double[MaxParams];
  this.predBuf = new double[trainY.Length];
  this.testPredBuf = new double[tstY.Length];

  // one gradient vector (length = number of training rows) per parameter
  this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[trainY.Length]).ToArray();
}


225 


226  #region IState inferface


// The search is done when the root of the search tree has been marked as done.
public bool Done {
  get {
    if (tree == null) return false;
    return tree.Done;
  }
}

// Re-evaluates the best solution on the training data and returns Rho(y, prediction).
public double BestSolutionTrainingQuality {
  get {
    evaluator.Exec(bestCode, x, bestConsts, predBuf);
    double trainingRho = Rho(y, predBuf);
    return trainingRho;
  }
}

// Re-evaluates the best solution on the test data and returns Rho(testY, prediction).
public double BestSolutionTestQuality {
  get {
    testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
    double testRho = Rho(testY, testPredBuf);
    return testRho;
  }
}

// Translates the code of the best solution into an equivalent symbolic regression model.
public ISymbolicRegressionModel BestModel {
  get {
    var generator = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
    var linearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    var root = generator.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset);
    var exprTree = new SymbolicExpressionTree(root);
    var bestModel = new SymbolicRegressionModel(problemData.TargetVariable, exprTree, linearInterpreter, lowerEstimationLimit, upperEstimationLimit);
    // apply linear scaling
    bestModel.Scale(problemData);
    return bestModel;
  }
}

// Solutions collected by UpdateParetoFront (the internal list is returned directly).
public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
  get { return paretoBestModels; }
}

public int TotalRollouts {
  get { return totalRollouts; }
}

public int EffectiveRollouts {
  get { return effectiveRollouts; }
}

public int FuncEvaluations {
  get { return funcEvaluations; }
}

// number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
public int GradEvaluations {
  get { return gradEvaluations; }
}


263 


264  #endregion


265 


// Evaluates the expression given as code, keeps track of the best solution found so far
// (and of the Pareto front if enabled) and returns the quality of the expression.
private double Eval(byte[] code, int nParams) {
  double quality;
  double[] tunedConsts;
  Eval(code, nParams, out quality, out tunedConsts);

  // single objective best
  if (quality > bestR) {
    bestR = quality;
    bestNParams = nParams;

    // keep private copies, code and constants may be changed by the caller later on
    var codeCopy = new byte[code.Length];
    var constsCopy = new double[nParams];
    Array.Copy(code, codeCopy, code.Length);
    Array.Copy(tunedConsts, constsCopy, nParams);
    this.bestCode = codeCopy;
    this.bestConsts = constsCopy;
  }

  if (!collectParetoOptimalModels) return quality;

  // multiobjective best
  // TODO: implement Kommenda's tree complexity directly in the evaluator
  // (SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity())
  // use length of expression (position of the first Exit op code) as surrogate for complexity
  int exprLength = Array.IndexOf(code, (byte)OpCodes.Exit);
  UpdateParetoFront(quality, exprLength, code, tunedConsts, nParams, scalingFactor, scalingOffset);
  return quality;
}


289 


// Evaluates the expression on the training data and returns Rho(y, prediction) in rho.
// Constants are optimized with Levenberg-Marquardt only if the expression has parameters
// and constOptIterations is not negative.
// NOTE: optConsts is the shared buffer constsBuf (not a copy); callers must copy the first
// nParams entries if they want to keep them.
private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
  // we make a first pass to determine a valid starting configuration for all constants
  // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
  // scale and offset are set to optimal starting configuration
  // assumes scale is the first param and offset is the last param

  // reset constants
  Array.Copy(ones, constsBuf, nParams);
  evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
  funcEvaluations++;

  // if we don't need to optimize parameters then we are done after the first pass
  // (changing scale and offset does not influence r²)
  bool skipOptimization = nParams == 0 || constOptIterations < 0;
  if (!skipOptimization) {
    // optimize constants using the starting point calculated above
    OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

    evaluator.Exec(code, x, constsBuf, predBuf);
    funcEvaluations++;
  }

  rho = Rho(y, predBuf);
  optConsts = constsBuf;
}


317 


318 


319 


320  #region helpers


// Pearson's R of the two sequences; 0.0 when the calculator reports an error.
private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
  OnlineCalculatorError status;
  double correlation = OnlinePearsonsRCalculator.Calculate(x, y, out status);
  if (status != OnlineCalculatorError.None) return 0.0;
  return correlation;
}


326 


327 


// Optimizes the first nParams entries of consts in place using ALGLIB's Levenberg-Marquardt
// implementation (residuals from Func, Jacobian from FuncAndJacobian).
// Updates funcEvaluations and gradEvaluations with the effort reported by the optimizer.
// Throws ArgumentException when the optimizer reports a negative termination type; consts is
// left unchanged in this case.
private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
  double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
  Array.Copy(consts, optConsts, nParams);

  // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
  alglib.minlmstate state;
  alglib.minlmreport rep;
  alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
  // Using the change of the gradient as stopping criterion is recommended in alglib manual.
  // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
  // NOTE(review): the tolerance is 1E-6 * stdDev(y); the pasted source showed 1E6 with the minus sign lost
  alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
  // alglib.minlmsetgradientcheck(state, 1E-5);
  alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
  alglib.minlmresults(state, out optConsts, out rep);
  funcEvaluations += rep.nfunc;
  gradEvaluations += rep.njac * nParams;

  if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

  // the optimization was successful -> use the optimized constants
  Array.Copy(optConsts, consts, optConsts.Length);
}


352 


// Residual callback for the LM optimizer.
// arg: current parameter vector, fi: output vector of residuals, obj: the code of the expression (byte[]).
// fi[0..n-1] are the residuals (prediction - target) for the n training rows,
// the last element of fi is the regularization penalty.
private void Func(double[] arg, double[] fi, object obj) {
  var code = (byte[])obj;
  int n = predBuf.Length;
  evaluator.Exec(code, x, arg, predBuf);
  for (int r = 0; r < n; r++) {
    fi[r] = predBuf[r] - y[r];
  }

  var penaltyIdx = fi.Length - 1;
  fi[penaltyIdx] = 0.0;
  // calc squared length of parameter vector for regularization
  var aa = 0.0;
  for (int i = 0; i < arg.Length; i++) {
    aa += arg[i] * arg[i];
  }
  if (lambda > 0 && aa > 0) {
    // scale lambda using stdDev(y) to make the parameter independent of the scale of y
    // scale lambda using n to make parameter independent of the number of training points
    // take the root because LM squares the result
    fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
  }
}


376 


// Residual and Jacobian callback for the LM optimizer.
// arg: current parameter vector, fi: output vector of residuals, jac: output Jacobian
// (one row per residual, one column per parameter), obj: the code of the expression (byte[]).
// The last row of fi / jac holds the regularization penalty and its partial derivatives.
private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
  int n = predBuf.Length;
  int nParams = arg.Length;
  var code = (byte[])obj;
  evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
  for (int r = 0; r < n; r++) {
    fi[r] = predBuf[r] - y[r];

    // gradBuf is indexed [param][row], the Jacobian is indexed [row, param]
    for (int k = 0; k < nParams; k++) {
      jac[r, k] = gradBuf[k][r];
    }
  }
  // calc squared length of parameter vector for regularization
  double aa = 0.0;
  for (int i = 0; i < arg.Length; i++) {
    aa += arg[i] * arg[i];
  }

  var penaltyIdx = fi.Length - 1;
  if (lambda > 0 && aa > 0) {
    // scale lambda using stdDev(y) to make the parameter independent of the scale of y
    // scale lambda using n to make parameter independent of the number of training points
    // take the root because alglib LM squares the result
    fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);

    // derivative of sqrt(c * aa) w.r.t. arg[i] with c = n * lambda / yStdDev
    for (int i = 0; i < arg.Length; i++) {
      jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
    }
  } else {
    // no penalty
    fi[penaltyIdx] = 0.0;
    for (int i = 0; i < arg.Length; i++) {
      jac[penaltyIdx, i] = 0.0;
    }
  }
}


414 


415 


// Updates the Pareto front (maximize quality q, minimize complexity) with the given solution.
// paretoFront and paretoBestModels are kept in sync (same index = same solution).
// If the solution is not dominated by an element of the front, a regression solution is
// created for it and appended. Afterwards the elements in front of the last position which
// are dominated by the solution are removed.
private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
  double[] scalingFactor, double[] scalingOffset) {
  double[] cur = new double[2] { q, complexity };
  bool[] max = new[] { true, false };
  var isNonDominated = true;
  foreach (var e in paretoFront) {
    var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
    if (domRes == DominationResult.IsDominated) {
      isNonDominated = false;
      break;
    }
  }
  if (isNonDominated) {
    paretoFront.Add(cur);

    // create model
    var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
    var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
    var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
    model.Scale(problemData); // apply linear scaling

    var sol = model.CreateRegressionSolution(this.problemData);
    sol.Name = string.Format("{0:N5} {1}", q, complexity);

    paretoBestModels.Add(sol);
  }
  // iterate backwards because elements are removed while iterating;
  // the last element is skipped (it is the solution itself when it has just been added)
  for (int i = paretoFront.Count - 2; i >= 0; i--) {
    var @ref = paretoFront[i];
    var domRes = DominationCalculator<int>.Dominates(cur, @ref, max, true);
    if (domRes == DominationResult.Dominates) {
      paretoFront.RemoveAt(i);
      paretoBestModels.RemoveAt(i);
    }
  }
}


454 


455  #endregion


456 


457  #if DEBUG


// marker for entries of the per-level quality arrays which have not been written since the last ClearStats()
private const double UnsetQuality = -99;

// Resets the per-level debugging statistics: Gini coefficients to -1, qualities to UnsetQuality.
internal void ClearStats() {
  for (int i = 0; i < pathGiniCoeffs.Length; i++) pathGiniCoeffs[i] = -1;
  for (int i = 0; i < pathQs.Length; i++) pathQs[i] = UnsetQuality;
  for (int i = 0; i < pathBestQ.Length; i++) pathBestQ[i] = UnsetQuality;
  for (int i = 0; i < pathSelectedQ.Length; i++) pathSelectedQ[i] = UnsetQuality;
}

// Writes the Gini coefficients of the levels visited in the current rollout (tab separated).
internal void WriteGiniStats() {
  var cells = pathGiniCoeffs.TakeWhile(g => g >= 0).Select(g => string.Format("{0:N3}", g));
  Console.WriteLine(string.Join("\t", cells));
}

// Writes five tab separated rows followed by an empty line:
// (1) number of leading levels where pathBestQ equals levelBestQ, followed by pathSelectedQ
// (2) pathBestQ, (3) levelBestQ, (4) curAction, (5) levelBestAction (actions cut to 5 characters).
// Rows (2) to (5) end at the first level for which pathBestQ has not been set.
internal void WriteQs() {
  // Console.WriteLine(string.Join("\t", pathQs.TakeWhile(x => x >= -100).Select(x => string.Format("{0:N3}", x))));

  // number of levels for which pathBestQ has been set
  int depth = 0;
  while (depth < pathBestQ.Length && pathBestQ[depth] > UnsetQuality) depth++;

  // number of leading levels on which the best quality of the path equals the best quality of the level
  int agreeing = 0;
  while (agreeing < depth && pathBestQ[agreeing] == levelBestQ[agreeing]) agreeing++;

  var sb = new StringBuilder();
  sb.AppendFormat("{0,3}", agreeing);
  for (int i = 0; i < pathSelectedQ.Length && pathSelectedQ[i] > UnsetQuality; i++) {
    sb.AppendFormat("\t{0:N3}", pathSelectedQ[i]);
  }
  Console.WriteLine(sb.ToString());

  sb.Clear();
  for (int i = 0; i < depth; i++) sb.AppendFormat("\t{0:N3}", pathBestQ[i]);
  Console.WriteLine(sb.ToString());

  sb.Clear();
  for (int i = 0; i < depth; i++) sb.AppendFormat("\t{0:N3}", levelBestQ[i]);
  Console.WriteLine(sb.ToString());

  sb.Clear();
  for (int i = 0; i < depth; i++) sb.AppendFormat("\t{0,5}", AbbreviateAction(curAction[i]));
  Console.WriteLine(sb.ToString());

  sb.Clear();
  for (int i = 0; i < depth; i++) sb.AppendFormat("\t{0,5}", AbbreviateAction(levelBestAction[i]));
  Console.WriteLine(sb.ToString());

  Console.WriteLine();
}

// Cuts an action string to at most 5 characters (null is passed through).
private static string AbbreviateAction(string action) {
  return (action != null && action.Length > 5) ? action.Substring(0, 5) : action;
}


517 


518 


519  #endif


520 


521  }


522 


523 


/// <summary>
/// Factory method which creates the initial state for the algorithm.
/// The returned state is advanced by calling MakeStep.
/// </summary>
/// <param name="problemData">The problem data</param>
/// <param name="randSeed">Random seed.</param>
/// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
/// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended)</param>
/// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt).
/// Constants are not optimized when the value is negative.</param>
/// <param name="lambda">Penalty factor for regularization (0..inf.), zero disables regularization. Must not be negative.</param>
/// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...). Ucb is used when null.</param>
/// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
/// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
/// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
/// <param name="allowProdOfVars">Allow products of expressions.</param>
/// <param name="allowExp">Allow expressions with exponentials.</param>
/// <param name="allowLog">Allow expressions with logarithms</param>
/// <param name="allowInv">Allow expressions with 1/x</param>
/// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
/// <returns>The initial state of the search.</returns>
// NOTE(review): the default of constOptIterations is shown as 1 here; minus signs were lost elsewhere
// in this listing, so it may originally have been -1 (no optimization) - confirm against the repository.
public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
  bool scaleVariables = true, int constOptIterations = 1, double lambda = 0.0,
  IPolicy policy = null,
  bool collectParameterOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false
  ) {
  IState initialState = new State(
    problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
    policy, collectParameterOptimalModels,
    lowerEstimationLimit, upperEstimationLimit,
    allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
  return initialState;
}


560 


// Performs one step of the tree search and returns the quality of the evaluated solution.
// Throws ArgumentException when state has not been created by CreateState (or is null) and
// NotSupportedException when the search has already enumerated all possible solutions.
public static double MakeStep(IState state) {
  var mctsState = state as State;
  if (mctsState == null) throw new ArgumentException("The state must be an instance created by CreateState.", "state");
  if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

  return TreeSearch(mctsState);
}


569  #endregion


570 


// Repeats rollouts from the root until one of them evaluates a solution or the whole tree
// is done and returns the quality of the last rollout. Every attempt is counted in
// totalRollouts, each call of this method is counted once in effectiveRollouts.
private static double TreeSearch(State mctsState) {
  var fsm = mctsState.automaton;
  var root = mctsState.tree;
  var evalFun = mctsState.evalFun;
  var rng = mctsState.random;
  var policy = mctsState.treePolicy;

  double quality = 0;
  for (; ; ) {
#if DEBUG
    mctsState.ClearStats();
#endif
    fsm.Reset();
    bool evaluated = TryTreeSearchRec2(rng, root, fsm, evalFun, policy, mctsState, out quality);
    mctsState.totalRollouts++;
    if (evaluated || root.Done) break;
  }
  mctsState.effectiveRollouts++;

#if DEBUG
  // mctsState.WriteGiniStats();
  Console.WriteLine(ExprStr(fsm));
  mctsState.WriteQs();
  // Console.WriteLine(WriteStatistics(tree, mctsState));

#endif
  //if (mctsState.effectiveRollouts % 100 == 1) {
  //  Console.WriteLine(WriteTree(tree, mctsState));
  //  Console.WriteLine(TraceTree(tree, mctsState));
  //}
  return quality;
}


602 


// search forward
// Performs a single rollout starting at the given node. Returns true (and the transformed
// quality in q) when a final state has been reached and evaluated, false (q = 0.0) when the
// rollout got stuck in a part of the graph which is already done or in a dead end.
private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  State state,
  out double q) {
  // ROLLOUT AND EXPANSION
  // We are navigating a graph (states might be reached via different paths) instead of a tree.
  // State equivalence is checked through ExprHash (based on the generated code through the path).

  // We switch between rollout-mode and expansion mode
  // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
  // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
  // In expansion mode we might re-enter the graph and switch back to rollout-mode
  // We do this until we reach a complete expression (final state)

  // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
  // Sub-graphs which have been completely searched are marked as done.
  // Rollout could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

  while (!automaton.IsFinalState(automaton.CurrentState)) {
    List<Tree> knownChildren;
    if (state.children.TryGetValue(tree, out knownChildren)) {
      if (knownChildren.All(ch => ch.Done)) {
        tree.Done = true;
        break;
      }
      // ROLLOUT INSIDE TREE
      // UCT selection within tree (nothing to select if there is only one child)
      int selectedIdx = knownChildren.Count > 1
        ? treePolicy.Select(knownChildren.Select(ch => ch.actionStatistics), rand)
        : 0;

      // STATS
      state.pathGiniCoeffs[tree.level] = InequalityCoefficient(knownChildren.Select(ch => (double)ch.actionStatistics.AverageQuality));
      state.pathQs[tree.level] = tree.actionStatistics.AverageQuality;

      tree = knownChildren[selectedIdx];

      // move the automaton forward until reaching the state
      // all steps where no alternatives are possible are immediately taken
      // TODO: simplification of the automaton
      int[] followStates;
      int nFollowStates;
      automaton.FollowStates(automaton.CurrentState, out followStates, out nFollowStates);
      // TODO!
      // while (possibleFollowStates[0] != tree.state && nFs == 1 &&
      //   !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
      //   automaton.Goto(possibleFollowStates[0]);
      //   automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      // }
      Debug.Assert(followStates.Contains(tree.state));
      automaton.Goto(tree.state);
      continue;
    }

    // EXPAND
    int[] possibleFollowStates;
    int nFs;
    string actionString = "";
    automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    // TODO
    // while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
    //   actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]);
    //   // no alternatives -> just go to the next state
    //   automaton.Goto(possibleFollowStates[0]);
    //   automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    // }
    if (nFs == 0) {
      // stuck in a dead end (no final state and no allowed follow states)
      tree.Done = true;
      break;
    }

    var newChildren = new List<Tree>(nFs);
    state.children.Add(tree, newChildren);
    for (int i = 0; i < nFs; i++) {
      Tree child;
      if (!automaton.IsEvalState(possibleFollowStates[i])) {
        // ordinary state: always create a fresh node
        child = new Tree() {
          children = null,
          state = possibleFollowStates[i],
          actionStatistics = treePolicy.CreateActionStatistics(),
          expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
          level = tree.level + 1
        };
      } else {
        // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
        var hc = Hashcode(automaton);
        if (!state.nodes.TryGetValue(hc, out child)) {
          child = new Tree() {
            children = null,
            state = possibleFollowStates[i],
            actionStatistics = treePolicy.CreateActionStatistics(),
            expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
            level = tree.level + 1
          };
          state.nodes.Add(hc, child);
        } else if (child.level > tree.level) {
          // only allow forward edges
          // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
          // to all parents
          BackpropagateStatistics(child.actionStatistics, tree, state);
        } else {
          // prevent cycles (don't add the child if we would go back in the graph)
          Debug.Assert(child.level <= tree.level);
          child = null;
        }
      }
      if (child != null) {
        newChildren.Add(child);
      }
    }

    if (newChildren.Count == 0) {
      // stuck in a dead end (no final state and no allowed follow states)
      tree.Done = true;
      break;
    }

    // register the current node as parent of all new children
    foreach (var ch in newChildren) {
      List<Tree> parentsOfChild;
      if (!state.parents.TryGetValue(ch, out parentsOfChild)) {
        parentsOfChild = new List<Tree>();
        state.parents.Add(ch, parentsOfChild);
      }
      parentsOfChild.Add(tree);
    }

    // follow one of the children
    tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
    automaton.Goto(tree.state);
  }

  // EVALUATE TREE
  bool success = automaton.IsFinalState(automaton.CurrentState);
  if (success) {
    tree.Done = true;
    tree.expr = ExprStr(automaton);
    byte[] code;
    int nParams;
    automaton.GetCode(out code, out nParams);
    q = eval(code, nParams);
    // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);
    q = TransformQuality(q);
    BackpropagateQuality(tree, q, treePolicy, state);
  } else {
    // we got stuck in roll-out (not evaluation necessary!)
    // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");
    q = 0.0;
  }

  // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
  // Update statistics
  // Set branch to done if all children are done.
  BackpropagateDone(tree, state);
  BackpropagateDebugStats(tree, q, state);

  return success;
}


761 


// Computes an inequality coefficient of the given values:
//   0.5 * sum_i sum_j |x_i - x_j| / (n * sum_j x_j)
// i.e. half of the relative mean absolute difference.
// The result is 0 when all values are equal.
// For an empty sequence, or when the values sum to zero, the division is 0/0 or x/0
// and the result is NaN or +/-Infinity; callers must handle this themselves.
private static double InequalityCoefficient(IEnumerable<double> xs) {
  var values = xs.ToArray();
  var sumAbsDiff = 0.0;
  var sum = 0.0;

  for (int i = 0; i < values.Length; i++) {
    for (int j = 0; j < values.Length; j++) {
      sumAbsDiff += Math.Abs(values[i] - values[j]);
      // accumulated once per (i, j) pair, so after both loops sum == n * sum(values)
      sum += values[j];
    }
  }
  return 0.5 * sumAbsDiff / sum;
}


775 


// Transforms the quality returned by the evaluator before it is backpropagated.
// Currently this is the identity (no transformation).
//
// EXPERIMENTAL alternatives that are intentionally disabled (kept for reference only):
//
//   Fisher transformation (assumes q is Correl(pred, target)):
//     q = Math.Min(q, 0.99999999);
//     q = Math.Max(q, -0.99999999);
//     return 0.5 * Math.Log((1 + q) / (1 - q));
//
//   Number of 9s in R² (optimal result q = 1 -> return huge value):
//     if (q >= 1.0) return 1E16;
//     return -Math.Log10(1 - q);
private static double TransformQuality(double q) {
  // no transformation
  return q;
}


794 


795  // backpropagate existing statistics to all parents


// Merges the given statistics into this node and then, recursively, into every node
// registered as a parent of it in state.parents.
// A node that is reachable over several parent links receives the statistics once per link.
private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
  tree.actionStatistics.Add(stats);

  // nodes without registered parents end the recursion
  if (!state.parents.ContainsKey(tree)) return;

  foreach (var ancestor in state.parents[tree]) {
    BackpropagateStatistics(stats, ancestor, state);
  }
}


804 


// Returns the expression hash of the code currently held by the automaton.
// The hash is computed by ExprHash.GetHash from the code and its number of parameters.
private static ulong Hashcode(Automaton automaton) {
  byte[] exprCode;
  int paramCount;
  automaton.GetCode(out exprCode, out paramCount);

  var hash = ExprHash.GetHash(exprCode, paramCount);
  return hash;
}


811 


// Updates the statistics of this node with quality q via the policy and then does the same,
// recursively, for every node registered as a parent of it in state.parents.
private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
  policy.Update(tree.actionStatistics, q);

  // nodes without registered parents end the recursion
  if (!state.parents.ContainsKey(tree)) return;

  foreach (var ancestor in state.parents[tree]) {
    BackpropagateQuality(ancestor, q, policy, state);
  }
}


821 


// Marks this node as done when it has a registered child list and none of those children
// is still open, then repeats the check recursively for all registered parents.
// Fully explored nodes are kept in the graph (their child lists are not removed).
private static void BackpropagateDone(Tree tree, State state) {
  if (state.children.ContainsKey(tree)) {
    var hasOpenChild = state.children[tree].Any(child => !child.Done);
    if (!hasOpenChild) {
      tree.Done = true;
      // children[tree] = null; keep all nodes
    }
  }

  if (!state.parents.ContainsKey(tree)) return;

  foreach (var ancestor in state.parents[tree]) {
    BackpropagateDone(ancestor, state);
  }
}


834 


// Records per-level debugging information for this node and all of its registered ancestors.
// Ancestors are visited first; afterwards the entries for this node's own level are written.
// The parameter q is only passed on to the recursive calls and is not used otherwise.
private static void BackpropagateDebugStats(Tree tree, double q, State state) {
  if (state.parents.ContainsKey(tree)) {
    foreach (var ancestor in state.parents[tree]) {
      BackpropagateDebugStats(ancestor, q, state);
    }
  }

  var level = tree.level;
  var stats = tree.actionStatistics;

  // statistics of the node on the current path at this level
  state.pathSelectedQ[level] = stats.AverageQuality;
  state.pathBestQ[level] = stats.BestQuality;
  state.curAction[level] = tree.expr;

  // remember the best quality (and the corresponding expression) recorded so far for this level
  if (state.levelBestQ[level] < stats.BestQuality) {
    state.levelBestQ[level] = stats.BestQuality;
    state.levelBestAction[level] = tree.expr;
  }
}


850 


// Returns the child of the given node that has the smallest state value
// (smaller values are closer to the final state). On ties the first such child wins.
// The node must have a non-empty child list registered in state.children.
// The parameters automaton and rand are not used.
private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
  var children = state.children[tree];
  Tree minChild = children.First();
  for (int i = 1; i < children.Count; i++) {
    // the running minimum must be updated; comparing only against the first child
    // would return the last child that is smaller than the first one instead of the minimum
    if (children[i].state < minChild.state) {
      minChild = children[i];
    }
  }
  return minChild;
}


862 


863  // tree search might fail because of constraints for expressions


// if we get stuck in this way, we just restart the search


865  // see ConstraintHandler.cs for more info


// One recursive step of the tree search.
// - An unexpanded node in a final state is evaluated with eval; q receives the result and true is returned.
// - An unexpanded node without follow states is a dead end: q is 0 and false is returned.
// - Otherwise the node is expanded (if necessary), one child is selected, the automaton
//   moves to that child's state and the search continues recursively.
// The statistics of this node are only updated when the recursive call was successful.
// When all children are done the node itself is marked as done and its children are released.
private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  out double q) {
  Contract.Assert(tree.state == automaton.CurrentState);
  Contract.Assert(!tree.Done);

  Tree next;
  if (tree.children != null) {
    // already expanded: selection within the tree is done by the policy
    // (a single child needs no selection)
    int pick = 0;
    if (tree.children.Length > 1) {
      pick = treePolicy.Select(tree.children.Select(c => c.actionStatistics), rand);
    }
    next = tree.children[pick];
  } else if (automaton.IsFinalState(tree.state)) {
    // final state: evaluate the expression and stop descending
    tree.Done = true;

    byte[] code;
    int nParams;
    automaton.GetCode(out code, out nParams);
    q = eval(code, nParams);

    treePolicy.Update(tree.actionStatistics, q);
    return true; // we reached a final state
  } else {
    // expand the node
    int[] followStates;
    int followCount;
    automaton.FollowStates(automaton.CurrentState, out followStates, out followCount);
    if (followCount == 0) {
      // stuck in a dead end (no final state and no allowed follow states)
      q = 0;
      tree.Done = true;
      tree.children = null;
      return false;
    }

    tree.children = new Tree[followCount];
    for (int k = 0; k < tree.children.Length; k++) {
      tree.children[k] = new Tree() {
        children = null,
        state = followStates[k],
        actionStatistics = treePolicy.CreateActionStatistics()
      };
    }

    if (followCount > 1) {
      next = SelectFinalOrRandom(automaton, tree, rand);
    } else {
      next = tree.children[0];
    }
  }

  // make the selected step and recurse
  automaton.Goto(next.state);
  var reachedFinal = TryTreeSearchRec(rand, next, automaton, eval, treePolicy, out q);
  if (reachedFinal) {
    // only update if successful
    treePolicy.Update(tree.actionStatistics, q);
  }

  tree.Done = tree.children.All(c => c.Done);
  if (tree.Done) {
    // cut off the subbranch if it has been fully explored
    tree.children = null;
  }
  return reachedFinal;
}


928 


// Returns the first child whose state is a final state of the automaton.
// If no child leads to a final state the first child is returned.
// NOTE: despite the method name the fallback is deterministic; the parameter rand is not used.
// The node must have at least one child.
private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
  // returning directly from the loop avoids a sentinel index that could be confused
  // with a valid child index
  for (int i = 0; i < tree.children.Length; i++) {
    if (automaton.IsFinalState(tree.children[i].state)) {
      return tree.children[i];
    }
  }
  // no final state -> select the first child
  return tree.children[0];
}


946 


947  // scales data and extracts values from dataset into arrays


// Scales data and extracts values from the dataset into arrays.
//
// When scaleVariables is true, each allowed input variable is scaled linearly to the range [0..1]
// based on its minimum and maximum over the given rows, and the target variable is shifted to zero mean.
// scalingFactor and scalingOffset then hold one entry per input variable (in the order of
// problemData.AllowedInputVariables) followed by one last entry for the target variable,
// so that scaledValue = value * scalingFactor[i] + scalingOffset[i].
// When scaleVariables is false both arrays are null and the values are extracted unchanged.
private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
  out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
  if (scaleVariables) {
    var inputVariables = problemData.AllowedInputVariables.ToArray();
    scalingFactor = new double[inputVariables.Length + 1];
    scalingOffset = new double[inputVariables.Length + 1];

    for (int i = 0; i < inputVariables.Length; i++) {
      // fetch the values once for both the minimum and the maximum
      var values = problemData.Dataset.GetDoubleValues(inputVariables[i], rows).ToArray();
      var minX = values.Min();
      var maxX = values.Max();
      var range = maxX - minX;

      if (range == 0.0) {
        // constant variable: dividing by the range would produce Infinity / NaN,
        // therefore only shift the variable (all scaled values become 0)
        scalingFactor[i] = 1.0;
        scalingOffset[i] = -minX;
      } else {
        // scaledX = (x - min) / range = x * (1 / range) + (-min / range)
        scalingFactor[i] = 1.0 / range;
        scalingOffset[i] = -minX / range;
      }
    }

    // transform target variable to zero-mean
    scalingFactor[inputVariables.Length] = 1.0;
    scalingOffset[inputVariables.Length] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
  } else {
    scalingFactor = null;
    scalingOffset = null;
  }

  GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
}


983 


984  // extract values from dataset into arrays


// Extracts the values of all allowed input variables and of the target variable for the
// given rows into arrays and applies the linear transformation value * factor + offset.
// scalingFactor / scalingOffset contain one entry per input variable followed by one entry
// for the target variable. If scalingFactor is null no transformation is applied
// (factor 1.0, offset 0.0); scalingOffset is only read when scalingFactor is not null.
private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
  out double[][] xs, out double[] y) {
  bool scaled = scalingFactor != null;
  var inputVariables = problemData.AllowedInputVariables.ToArray();
  var dataset = problemData.Dataset;

  xs = new double[inputVariables.Length][];
  for (int k = 0; k < inputVariables.Length; k++) {
    double factor = scaled ? scalingFactor[k] : 1.0;
    double shift = scaled ? scalingOffset[k] : 0.0;
    xs[k] = dataset.GetDoubleValues(inputVariables[k], rows)
      .Select(v => v * factor + shift)
      .ToArray();
  }

  // the entry after the last input variable belongs to the target variable
  int targetIdx = inputVariables.Length;
  double targetFactor = scaled ? scalingFactor[targetIdx] : 1.0;
  double targetShift = scaled ? scalingOffset[targetIdx] : 0.0;
  y = dataset.GetDoubleValues(problemData.TargetVariable, rows)
    .Select(v => v * targetFactor + targetShift)
    .ToArray();
}


1003 


1004  // for debugging only


1005 


1006 


// For debugging only: returns the disassembled code currently held by the automaton as a string.
private static string ExprStr(Automaton automaton) {
  byte[] exprCode;
  int paramCount; // required by GetCode but not needed for the string representation
  automaton.GetCode(out exprCode, out paramCount);

  var text = Disassembler.CodeToString(exprCode);
  return text;
}


1013 


1014 


// For debugging only: produces one tab-separated line with the number of tries and the
// average quality of the given node, followed by the same two values for each of its
// registered children (if any).
private static string WriteStatistics(Tree tree, State state) {
  var writer = new System.IO.StringWriter();

  var nodeStats = tree.actionStatistics;
  writer.Write("{0}\t{1:N5}\t", nodeStats.Tries, nodeStats.AverageQuality);

  if (state.children.ContainsKey(tree)) {
    foreach (var child in state.children[tree]) {
      var childStats = child.actionStatistics;
      writer.Write("{0}\t{1:N5}\t", childStats.Tries, childStats.AverageQuality);
    }
  }

  writer.WriteLine();
  return writer.ToString();
}


1026 


// For debugging only: renders a trace through the search graph, starting at the given node,
// as a Graphviz digraph. The body of the graph is produced by TraceTreeRec.
private static string TraceTree(Tree tree, State state) {
  var dot = new StringBuilder();
  dot.Append(
@"digraph {
  ratio = fill;
  node [style=filled];
");

  // the start node gets id 0; further ids are handed out by TraceTreeRec
  int lastUsedId = 0;
  TraceTreeRec(tree, 0, dot, ref lastUsedId, state);

  return dot.Append("}").ToString();
}


1040 


// For debugging only: writes the given node (with id parentId), all of its registered children
// and the edges to them in Graphviz DOT syntax. For children that have exactly one child
// themselves, that single grandchild is written as well. Afterwards the trace continues
// recursively with the child that has the largest number of tries.
// nextId holds the last id that has been handed out and is incremented for every new node.
private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
  AppendTraceNode(sb, parentId, tree);

  if (!state.children.ContainsKey(tree)) return;

  // (tries, node id, node) for all children
  var list = new List<Tuple<int, int, Tree>>();
  foreach (var ch in state.children[tree]) {
    nextId++;
    AppendTraceNode(sb, nextId, ch);
    AppendTraceEdge(sb, parentId, nextId, ch.expr);
    list.Add(Tuple.Create(ch.actionStatistics.Tries, nextId, ch));
  }

  foreach (var tup in list) {
    var ch = tup.Item3;
    var chId = tup.Item2;
    if (state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
      var chch = state.children[ch].First();
      nextId++;
      AppendTraceNode(sb, nextId, chch);
      AppendTraceEdge(sb, chId, nextId, chch.expr);
    }
  }

  // only follow the most frequently tried child
  foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
    TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state);
  }
}

// Writes the DOT statement for one node, labelled with its average quality and number of tries.
private static void AppendTraceNode(StringBuilder sb, int id, Tree node) {
  var avgNodeQ = node.actionStatistics.AverageQuality;
  var tries = node.actionStatistics.Tries;
  if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
  // the quality-dependent hue ((1 - avgNodeQ) / 360.0 * 240.0; 0 equals red, 240 equals blue)
  // is disabled, all nodes use hue 0
  var hue = 0.0;
  // invariant culture: DOT requires '.' as decimal separator
  sb.AppendFormat(System.Globalization.CultureInfo.InvariantCulture,
    "{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", id, avgNodeQ, tries, hue).AppendLine();
}

// Writes the DOT statement for a directed edge with the given label.
private static void AppendTraceEdge(StringBuilder sb, int fromId, int toId, string label) {
  sb.AppendFormat(System.Globalization.CultureInfo.InvariantCulture,
    "{0} -> {1} [label=\"{2}\"]", fromId, toId, label).AppendLine();
}


1085 


// For debugging only: renders all parent/child relations stored in state.children as a
// Graphviz digraph. Nodes are labelled with their average quality and number of tries,
// edges with the expression of the child node. Children that have never been tried and
// nodes with no more tries than the threshold are not written.
// Numbers are formatted with the invariant culture. The parameter tree is not used.
private static string WriteTree(Tree tree, State state) {
  var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
  var nodeIds = new Dictionary<Tree, int>();
  sb.Write(
@"digraph {
  ratio = fill;
  node [style=filled];
");
  int threshold = /* state.nodes.Count > 500 ? 10 : */ 0;
  // the quality-dependent hue ((1 - avgNodeQ) / 360.0 * 240.0; 0 equals red, 240 equals blue)
  // is disabled, all nodes use hue 0
  const double hue = 0.0;

  foreach (var kvp in state.children) {
    var parent = kvp.Key;
    int parentId;
    if (!nodeIds.TryGetValue(parent, out parentId)) {
      // first occurrence of this node: assign the next id and write the node
      parentId = nodeIds.Count + 1;
      var avgNodeQ = parent.actionStatistics.AverageQuality;
      var tries = parent.actionStatistics.Tries;
      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
      if (tries > threshold)
        sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
      nodeIds.Add(parent, parentId);
    }

    foreach (var child in kvp.Value) {
      int childId;
      if (!nodeIds.TryGetValue(child, out childId)) {
        childId = nodeIds.Count + 1;
        nodeIds.Add(child, childId);
      }
      var avgNodeQ = child.actionStatistics.AverageQuality;
      var tries = child.actionStatistics.Tries;
      if (tries < 1) continue;
      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
      if (tries > threshold) {
        sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
        var edgeLabel = child.expr;
        // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
        sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
      }
    }
  }

  sb.Write("}");
  return sb.ToString();
}


1133  }


1134  }

