#region License Information /* HeuristicLab * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL) * * This file is part of HeuristicLab. * * HeuristicLab is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * HeuristicLab is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with HeuristicLab. If not, see . */ #endregion using System; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.Contracts; using System.Linq; using System.Text; using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies; using HeuristicLab.Core; using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; using HeuristicLab.Optimization; using HeuristicLab.Problems.DataAnalysis; using HeuristicLab.Problems.DataAnalysis.Symbolic; using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression; using HeuristicLab.Random; namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression { public static class MctsSymbolicRegressionStatic { // OBJECTIVES: // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient) // - e.g. Keijzer, Nguyen ... where no numeric constants are involved // - assumptions: // - we don't know the necessary operations or functions -> all available functions could be necessary // - but we do not need to tune numeric constants -> no scaling of input variables x! // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale. // This is important for real world applications. // - e.g. 
Korns or Vladislavleva problems where numeric constants are involved // - assumptions: // - any numeric constant is possible (a priori we might assume that small abs. constants are more likely) // - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway // - to simplify the problem we can restrict the set of functions e.g. we assume we know which functions are necessary for the problem instance // -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts // 3) efficiency and effectiveness for real-world problems // - e.g. Tower problem // - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants // // TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general! // --> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show // --> Therefore, looking at rollout statistics for arm selection is useless in the general case! // --> It is necessary to rely on other features for the arm selection. // --> TODO: Which heuristics can we apply? // TODO: Solve Poly-10 // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved? // TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)? // TODO: check if we can use a quality measure with range [-1..1] in policies // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants // TODO: check if transformation of y is correct and works (Obj 2) // TODO: The algorithm is not invariant to location and scale of variables. // Include offset for variables as parameter (for Objective 2) // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
// TODO: support e(-x) and possibly (1/-x) (Obj 1)
// TODO: is it OK to initialize all constants to 1 (Obj 2)?
// TODO: improve memory usage
// TODO: support empty test partition
// TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax ) ) for invertible transformations A --> unit tests

#region static API

/// <summary>
/// Public view of a running MCTS symbolic regression search: the best solution found so far and effort statistics.
/// Instances are created via CreateState and advanced via MakeStep.
/// </summary>
public interface IState {
  /// <summary>True once the root of the search tree has been marked as done (all solutions enumerated).</summary>
  bool Done { get; }
  /// <summary>Linearly scaled symbolic regression model equivalent to the best expression found so far.</summary>
  ISymbolicRegressionModel BestModel { get; }
  /// <summary>Pearson's correlation coefficient of the best expression on the training partition.</summary>
  double BestSolutionTrainingQuality { get; }
  /// <summary>Pearson's correlation coefficient of the best expression on the test partition.</summary>
  double BestSolutionTestQuality { get; }
  /// <summary>Solutions that are non-dominated w.r.t. (quality, expression length); only filled when Pareto collection is enabled.</summary>
  IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
  int TotalRollouts { get; }
  int EffectiveRollouts { get; }
  int FuncEvaluations { get; }
  int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
  // TODO other stats on LM optimizer might be interesting here
}

// created through factory method
private class State : IState {
  // upper bound for the number of numeric parameters of an expression (size of the parameter/gradient buffers)
  private const int MaxParams = 100;

  // state variables used by MCTS
  internal readonly Automaton automaton;
  internal IRandom random { get; private set; }
  internal readonly Tree tree;
  internal readonly Func<byte[], int, double> evalFun;
  internal readonly IPolicy treePolicy;
  // MCTS might get stuck. Track statistics on the number of effective rollouts
  internal int totalRollouts;
  internal int effectiveRollouts;

  // state variables used only internally (for eval function)
  private readonly IRegressionProblemData problemData;
  private readonly double[][] x;
  private readonly double[] y;
  private readonly double[][] testX;
  private readonly double[] testY;
  private readonly double[] scalingFactor;
  private readonly double[] scalingOffset;
  private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
  private readonly int constOptIterations;
  private readonly double lambda; // weight of penalty term for regularization
  private readonly double lowerEstimationLimit, upperEstimationLimit;
  private readonly bool collectParetoOptimalModels;
  private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
  // (quality, complexity) vectors, index-aligned with paretoBestModels
  private readonly List<double[]> paretoFront = new List<double[]>();

  private readonly ExpressionEvaluator evaluator, testEvaluator;

  // search graph: equivalent states are unified, so a node may have several parents
  internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
  internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
  // expression hash -> unified node
  internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();

  // values for best solution
  private double bestR;
  private byte[] bestCode;
  private int bestNParams;
  private double[] bestConsts;

  // stats
  private int funcEvaluations;
  private int gradEvaluations;

  // buffers
  private readonly double[] ones; // vector of ones (as default params)
  private readonly double[] constsBuf;
  private readonly double[] predBuf, testPredBuf;
  private readonly double[][] gradBuf;

  // debugging stats
  // calculate for each level the number of alternatives, the average 'inequality' of tries and the 'inequality' of quality over the alternatives
  // inequality can be calculated using the Gini coefficient
  internal readonly double[] pathGiniCoeffs = new double[100];
  internal readonly double[] pathQs = new double[100];
  internal readonly double[] levelBestQ = new double[100];
  // internal readonly double[] levelMaxTries = new double[100];
  internal readonly double[] pathBestQ = new double[100]; // as long as pathBestQs = levelBestQs we are following the correct path
  internal readonly string[] levelBestAction = new string[100];
  internal readonly string[] curAction = new string[100];
  internal readonly double[] pathSelectedQ = new double[100];

  /// <summary>
  /// Prepares the training/test data (optionally scaled based on the training partition), the evaluators,
  /// the automaton that generates expressions and the root of the search tree.
  /// </summary>
  /// <param name="treePolicy">Tree search policy; if null a Ucb policy is used.</param>
  /// <exception cref="ArgumentException">If lambda is negative.</exception>
  public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
    int constOptIterations, double lambda,
    IPolicy treePolicy = null,
    bool collectParetoOptimalModels = false,
    double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
    bool allowProdOfVars = true,
    bool allowExp = true,
    bool allowLog = true,
    bool allowInv = true,
    bool allowMultipleTerms = false
    ) {

    if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");

    this.problemData = problemData;
    this.constOptIterations = constOptIterations;
    this.lambda = lambda;
    this.evalFun = this.Eval;
    this.lowerEstimationLimit = lowerEstimationLimit;
    this.upperEstimationLimit = upperEstimationLimit;
    this.collectParetoOptimalModels = collectParetoOptimalModels;

    random = new MersenneTwister(randSeed);

    // prepare data for evaluation
    double[][] x;
    double[] y;
    double[][] testX;
    double[] testY;
    double[] scalingFactor;
    double[] scalingOffset;
    // get training and test datasets (scale linearly based on training set if required)
    GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
    GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
    this.x = x;
    this.y = y;
    this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
    this.testX = testX;
    this.testY = testY;
    this.scalingFactor = scalingFactor;
    this.scalingOffset = scalingOffset;

    this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
    // we need a separate evaluator because the vector length for the test dataset might differ
    this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

    this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    this.treePolicy = treePolicy ?? new Ucb();
    this.tree = new Tree() {
      state = automaton.CurrentState,
      // use the field: the treePolicy parameter may be null (replaced by the default Ucb policy above)
      actionStatistics = this.treePolicy.CreateActionStatistics(),
      expr = "",
      level = 0
    };

    // reset best solution
    this.bestR = 0;
    // code for default solution (constant model)
    this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
    this.bestNParams = 0;
    this.bestConsts = null;

    // init buffers
    this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
    constsBuf = new double[MaxParams];
    this.predBuf = new double[y.Length];
    this.testPredBuf = new double[testY.Length];

    this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
  }

  #region IState interface
  public bool Done { get { return tree != null && tree.Done; } }

  public double BestSolutionTrainingQuality {
    get {
      evaluator.Exec(bestCode, x, bestConsts, predBuf);
      return Rho(y, predBuf);
    }
  }

  public double BestSolutionTestQuality {
    get {
      testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
      return Rho(testY, testPredBuf);
    }
  }

  // takes the code of the best solution and creates an equivalent symbolic regression model
  public ISymbolicRegressionModel BestModel {
    get { return CreateModel(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset); }
  }

  public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
    get { return paretoBestModels; }
  }

  public int TotalRollouts { get { return totalRollouts; } }
  public int EffectiveRollouts { get { return effectiveRollouts; } }
  public int FuncEvaluations { get { return funcEvaluations; } }
  public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations

  #endregion

  // Evaluation function handed to the tree search (evalFun): evaluates the expression, remembers it if it is
  // the best so far and optionally updates the Pareto front. Returns Pearson's r on the training data.
  private double Eval(byte[] code, int nParams) {
    double[] optConsts;
    double q;
    Eval(code, nParams, out q, out optConsts);

    // single objective best
    if (q > bestR) {
      bestR = q;
      bestNParams = nParams;
      // copy: code and optConsts are reused buffers
      this.bestCode = new byte[code.Length];
      this.bestConsts = new double[bestNParams];

      Array.Copy(code, bestCode, code.Length);
      Array.Copy(optConsts, bestConsts, bestNParams);
    }
    if (collectParetoOptimalModels) {
      // multi-objective best
      var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
        Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit); // use length of expression as surrogate for complexity
      UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
    }
    return q;
  }

  // Evaluates code on the training data, optionally optimizing its numeric parameters with LM first.
  // rho receives Pearson's r; optConsts receives the parameters used. optConsts is the shared buffer constsBuf,
  // so its content is only valid until the next evaluation.
  private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
    // we make a first pass to determine a valid starting configuration for all constants
    // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
    // scale and offset are set to optimal starting configuration
    // assumes scale is the first param and offset is the last param

    // reset constants
    Array.Copy(ones, constsBuf, nParams);
    evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
    funcEvaluations++;

    // a negative number of iterations disables parameter optimization
    if (nParams > 0 && constOptIterations >= 0) {
      // optimize constants using the starting point calculated above
      OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

      evaluator.Exec(code, x, constsBuf, predBuf);
      funcEvaluations++;
    }

    // without optimization we are done already: changing scale and offset does not influence r
    rho = Rho(y, predBuf);
    optConsts = constsBuf;
  }

  #region helpers
  // Pearson's correlation coefficient r between x and y; 0.0 if the calculator reports an error.
  private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
    OnlineCalculatorError error;
    double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
    return error == OnlineCalculatorError.None ? r : 0.0;
  }

  // Squared Euclidean norm of the parameter vector (regularization term).
  private static double SquaredNorm(double[] arg) {
    var aa = 0.0;
    for (int i = 0; i < arg.Length; i++) {
      aa += arg[i] * arg[i];
    }
    return aa;
  }

  // Optimizes the first nParams entries of consts in place with alglib's Levenberg-Marquardt solver
  // (residuals on the training data plus one penalty residual for regularization).
  // Throws ArgumentException if the optimizer reports failure (negative termination type).
  private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
    double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
    Array.Copy(consts, optConsts, nParams);

    // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
    alglib.minlmstate state;
    alglib.minlmreport rep = null;
    alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
    // Using the change of the gradient as stopping criterion is recommended in alglib manual.
    // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
    alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
    // alglib.minlmsetgradientcheck(state, 1E-5);
    alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
    alglib.minlmresults(state, out optConsts, out rep);
    funcEvaluations += rep.nfunc;
    gradEvaluations += rep.njac * nParams;

    if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

    // only use optimized constants if successful (failures were rejected above)
    Array.Copy(optConsts, consts, optConsts.Length);
  }

  // LM residual vector: prediction errors for all training rows, last entry is the regularization penalty.
  private void Func(double[] arg, double[] fi, object obj) {
    var code = (byte[])obj;
    int n = predBuf.Length;
    evaluator.Exec(code, x, arg, predBuf);
    for (int r = 0; r < n; r++) {
      var res = predBuf[r] - y[r];
      fi[r] = res;
    }

    var penaltyIdx = fi.Length - 1;
    fi[penaltyIdx] = 0.0;
    // calc length of parameter vector for regularization
    var aa = SquaredNorm(arg);
    if (lambda > 0 && aa > 0) {
      // scale lambda using stdDev(y) to make the parameter independent of the scale of y
      // scale lambda using n to make parameter independent of the number of training points
      // take the root because LM squares the result
      fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
    }
  }

  // Residual vector (as in Func) plus its Jacobian w.r.t. the parameters.
  private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
    int n = predBuf.Length;
    int nParams = arg.Length;
    var code = (byte[])obj;
    evaluator.ExecGradient(code, x, arg, predBuf, gradBuf);
    // gradients are nParams x vLen
    for (int r = 0; r < n; r++) {
      var res = predBuf[r] - y[r];
      fi[r] = res;

      for (int k = 0; k < nParams; k++) {
        jac[r, k] = gradBuf[k][r];
      }
    }
    // calc length of parameter vector for regularization
    double aa = SquaredNorm(arg);

    var penaltyIdx = fi.Length - 1;
    if (lambda > 0 && aa > 0) {
      // scale lambda using stdDev(y) to make the parameter independent of the scale of y
      // scale lambda using n to make parameter independent of the number of training points
      // take the root because alglib LM squares the result
      fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);

      // d/da_i sqrt(c * |a|^2) = 0.5 / sqrt(c * |a|^2) * 2 * c * a_i
      for (int i = 0; i < arg.Length; i++) {
        jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
      }
    } else {
      fi[penaltyIdx] = 0.0;
      for (int i = 0; i < arg.Length; i++) {
        jac[penaltyIdx, i] = 0.0;
      }
    }
  }

  // Creates a linearly scaled symbolic regression model from evaluator code and its parameters.
  private SymbolicRegressionModel CreateModel(byte[] code, double[] consts, int nParams, double[] scalingFactor, double[] scalingOffset) {
    var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
    var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    var t = new SymbolicExpressionTree(treeGen.Exec(code, consts, nParams, scalingFactor, scalingOffset));
    var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
    model.Scale(problemData); // apply linear scaling
    return model;
  }

  // Adds the solution (q, complexity) to the Pareto front (maximize q, minimize complexity) if it is not dominated
  // and removes all front members it dominates. paretoFront and paretoBestModels are kept index-aligned.
  private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
    double[] scalingFactor, double[] scalingOffset) {
    double[] cur = new double[2] { q, complexity };
    bool[] max = new[] { true, false };
    var isNonDominated = true;
    foreach (var e in paretoFront) {
      var domRes = DominationCalculator.Dominates(cur, e, max, true);
      if (domRes == DominationResult.IsDominated) {
        isNonDominated = false;
        break;
      }
    }
    // a dominated solution is not added and cannot dominate any member of the front
    if (!isNonDominated) return;

    paretoFront.Add(cur);
    // create model
    var model = CreateModel(code, param, nParam, scalingFactor, scalingOffset);
    var sol = model.CreateRegressionSolution(this.problemData);
    sol.Name = string.Format("{0:N5} {1}", q, complexity);

    paretoBestModels.Add(sol);

    // remove previous members dominated by the new solution; start at Count - 2 to skip the element just added,
    // iterate backwards so that RemoveAt does not shift indices that are still to be visited
    for (int i = paretoFront.Count - 2; i >= 0; i--) {
      var @ref = paretoFront[i];
      var domRes = DominationCalculator.Dominates(cur, @ref, max, true);
      if (domRes == DominationResult.Dominates) {
        paretoFront.RemoveAt(i);
        paretoBestModels.RemoveAt(i);
      }
    }
  }

  #endregion

#if DEBUG
  // Resets the per-path debug statistics before a rollout.
  internal void ClearStats() {
    for (int i = 0; i < pathGiniCoeffs.Length; i++) pathGiniCoeffs[i] = -1;
    for (int i = 0; i < pathQs.Length; i++) pathQs[i] = -99;
    for (int i = 0; i < pathBestQ.Length; i++) pathBestQ[i] = -99;
    for (int i = 0; i < pathSelectedQ.Length; i++) pathSelectedQ[i] = -99;
  }

  internal void WriteGiniStats() {
    Console.WriteLine(string.Join("\t", pathGiniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x))));
  }

  internal void WriteQs() {
    // Console.WriteLine(string.Join("\t", pathQs.TakeWhile(x => x >= -100).Select(x => string.Format("{0:N3}", x))));
    var sb = new StringBuilder();
    // length of the prefix of the path that follows the best action of each level
    int i = 0;
    while (i < pathBestQ.Length && pathBestQ[i] > -99 && pathBestQ[i] == levelBestQ[i]) {
      i++;
    }
    sb.AppendFormat("{0,-3}", i);

    i = 0;
    // sb.AppendFormat("{0:N3}", levelBestQ[0]);
    while (i < pathSelectedQ.Length && pathSelectedQ[i] > -99) {
      sb.AppendFormat("\t{0:N3}", pathSelectedQ[i]);
      i++;
    }
    Console.WriteLine(sb.ToString());
    sb.Clear();

    i = 0;
    // sb.AppendFormat("{0:N3}", levelBestQ[0]);
    while (i < pathBestQ.Length && pathBestQ[i] > -99) {
      sb.AppendFormat("\t{0:N3}", pathBestQ[i]);
      i++;
    }
    Console.WriteLine(sb.ToString());
    sb.Clear();

    i = 0;
    while (i < pathBestQ.Length && pathBestQ[i] > -99) {
      sb.AppendFormat("\t{0:N3}", levelBestQ[i]);
      i++;
    }
    Console.WriteLine(sb.ToString());
    sb.Clear();

    i = 0;
    while (i < pathBestQ.Length && pathBestQ[i] > -99) {
      sb.AppendFormat("\t{0,-5}", (curAction[i] != null && curAction[i].Length > 5) ? curAction[i].Substring(0, 5) : curAction[i]);
      i++;
    }
    Console.WriteLine(sb.ToString());
    sb.Clear();

    i = 0;
    while (i < pathBestQ.Length && pathBestQ[i] > -99) {
      sb.AppendFormat("\t{0,-5}", (levelBestAction[i] != null && levelBestAction[i].Length > 5) ? levelBestAction[i].Substring(0, 5) : levelBestAction[i]);
      i++;
    }
    Console.WriteLine(sb.ToString());

    Console.WriteLine();
  }
#endif
}

/// <summary>
/// Static method to initialize a state for the algorithm
/// </summary>
/// <param name="problemData">The problem data</param>
/// <param name="randSeed">Random seed.</param>
/// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
/// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended)</param>
/// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt); a negative value disables constants optimization.</param>
/// <param name="lambda">Penalty factor for regularization (0..inf.); zero disables regularization, negative values are rejected.</param>
/// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...); if null, a Ucb policy is used.</param>
/// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
/// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
/// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
/// <param name="allowProdOfVars">Allow products of expressions.</param>
/// <param name="allowExp">Allow expressions with exponentials.</param>
/// <param name="allowLog">Allow expressions with logarithms</param>
/// <param name="allowInv">Allow expressions with 1/x</param>
/// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
/// <returns>The newly created search state.</returns>
public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
  bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
  IPolicy policy = null,
  bool collectParameterOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false
  ) {
  return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
    policy, collectParameterOptimalModels,
    lowerEstimationLimit, upperEstimationLimit,
    allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
}

// Performs one search step (retrying rollouts that got stuck) and
// returns the quality of the evaluated solution.
// Throws ArgumentException if state is not the internal State implementation (i.e. not obtained from CreateState)
// and NotSupportedException once the whole search space has been enumerated.
public static double MakeStep(IState state) {
  var mctsState = state as State;
  if (mctsState == null) throw new ArgumentException("state");
  if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

  return TreeSearch(mctsState);
}
#endregion

// Repeats rollouts from the root until one reaches and evaluates a final state or the root is marked done.
// Returns the (transformed) quality of the evaluated expression; 0 if the last rollout got stuck.
private static double TreeSearch(State mctsState) {
  var automaton = mctsState.automaton;
  var tree = mctsState.tree;
  var eval = mctsState.evalFun;
  var rand = mctsState.random;
  var treePolicy = mctsState.treePolicy;
  double q = 0;
  bool success = false;
  do {
#if DEBUG
    mctsState.ClearStats();
#endif
    automaton.Reset();
    success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q);
    mctsState.totalRollouts++;
  } while (!success && !tree.Done);
  // NOTE(review): this is also incremented when the loop ends because tree.Done became true without a
  // successful rollout (q is 0 then) -- confirm whether such a step should count as effective.
  mctsState.effectiveRollouts++;
#if DEBUG
  // mctsState.WriteGiniStats();
  Console.WriteLine(ExprStr(automaton));
  mctsState.WriteQs();
  // Console.WriteLine(WriteStatistics(tree, mctsState));

#endif
  //if (mctsState.effectiveRollouts % 100 == 1) {
  // Console.WriteLine(WriteTree(tree, mctsState));
  // Console.WriteLine(TraceTree(tree, mctsState));
  //}
  return q;
}

// search forward
// Performs a single rollout starting at 'tree' (despite the name the traversal is iterative, not recursive).
// Returns true if a final state was reached and evaluated; q then holds the transformed quality, otherwise 0.0.
private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func eval, IPolicy treePolicy, State state,
  out double q) {
  // ROLLOUT AND EXPANSION
  // We are navigating a graph (states might be reached via different paths) instead of a tree.
  // State equivalence is checked through ExprHash (based on the generated code through the path).

  // We switch between rollout-mode and expansion mode
  // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
  // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
  // In expansion mode we might re-enter the graph and switch back to rollout-mode
  // We do this until we reach a complete expression (final state)

  // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
  // Sub-graphs which have been completely searched are marked as done.
  // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

  while (!automaton.IsFinalState(automaton.CurrentState)) {
    if (state.children.ContainsKey(tree)) {
      if (state.children[tree].All(ch => ch.Done)) {
        tree.Done = true;
        break;
      }
      // ROLLOUT INSIDE TREE
      // UCT selection within tree
      int selectedIdx = 0;
      if (state.children[tree].Count > 1) {
        selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);
      }

      // STATS
      // NOTE(review): these debug statistics are collected in release builds too (only ClearStats/WriteQs are
      // DEBUG-only); the stats arrays have 100 entries, so this assumes tree.level stays below 100.
      state.pathGiniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality));
      state.pathQs[tree.level] = tree.actionStatistics.AverageQuality;

      tree = state.children[tree][selectedIdx];

      // move the automaton forward until reaching the state
      // all steps where no alternatives are possible are immediately taken
      // TODO: simplification of the automaton
      int[] possibleFollowStates;
      int nFs;
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      // TODO!
      // while (possibleFollowStates[0] != tree.state && nFs == 1 &&
      //   !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
      //   automaton.Goto(possibleFollowStates[0]);
      //   automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      // }
      Debug.Assert(possibleFollowStates.Contains(tree.state));
      automaton.Goto(tree.state);
    } else {
      // EXPAND
      int[] possibleFollowStates;
      int nFs;
      string actionString = "";
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      // TODO
      // while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
      //   actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]);
      //   // no alternatives -> just go to the next state
      //   automaton.Goto(possibleFollowStates[0]);
      //   automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      // }
      if (nFs == 0) {
        // stuck in a dead end (no final state and no allowed follow states)
        tree.Done = true;
        break;
      }
      var newChildren = new List(nFs);
      state.children.Add(tree, newChildren);
      for (int i = 0; i < nFs; i++) {
        Tree child = null;
        // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
        if (automaton.IsEvalState(possibleFollowStates[i])) {
          // NOTE(review): the hash is computed before the automaton moves to possibleFollowStates[i], i.e. from
          // the code of the current state; confirm that it identifies the follow state as intended (several
          // eval-state children of one node would otherwise share a hash).
          var hc = Hashcode(automaton);
          if (!state.nodes.TryGetValue(hc, out child)) {
            child = new Tree() {
              children = null,
              state = possibleFollowStates[i],
              actionStatistics = treePolicy.CreateActionStatistics(),
              expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
              level = tree.level + 1
            };
            state.nodes.Add(hc, child);
          }
          // only allow forward edges (don't add the child if we would go back in the graph)
          else if (child.level > tree.level) {
            // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
            // to all parents
            BackpropagateStatistics(child.actionStatistics, tree, state);
          } else {
            // prevent cycles
            Debug.Assert(child.level <= tree.level);
            child = null;
          }
        } else {
          child = new Tree() {
            children = null,
            state = possibleFollowStates[i],
            actionStatistics = treePolicy.CreateActionStatistics(),
            expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
            level = tree.level + 1
          };
        }
        if (child != null)
          newChildren.Add(child);
      }

      if (!newChildren.Any()) {
        // stuck in a dead end (no final state and no allowed follow states)
        tree.Done = true;
        break;
      }

      // register the back links needed for backpropagation
      foreach (var ch in newChildren) {
        if (!state.parents.ContainsKey(ch)) {
          state.parents.Add(ch, new List());
        }
        state.parents[ch].Add(tree);
      }

      // follow one of the children
      tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
      automaton.Goto(tree.state);
    }
  }

  bool success;

  // EVALUATE TREE
  if (automaton.IsFinalState(automaton.CurrentState)) {
    tree.Done = true;
    tree.expr = ExprStr(automaton);
    byte[] code;
    int nParams;
    automaton.GetCode(out code, out nParams);
    q = eval(code, nParams);
    // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);
    q = TransformQuality(q);
    success = true;
    BackpropagateQuality(tree, q, treePolicy, state);
  } else {
    // we got stuck in roll-out (no evaluation necessary!)
    // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");
    q = 0.0;
    success = false;
  }

  // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
  // Update statistics
  // Set branch to done if all children are done.
  BackpropagateDone(tree, state);
  BackpropagateDebugStats(tree, q, state);

  return success;
}

// Inequality measure over the average qualities of a node's children; only used for the pathGiniCoeffs debug
// statistic. Presumably the Gini coefficient (sad = sum of absolute differences) -- confirm once the body is restored.
private static double InequalityCoefficient(IEnumerable xs) {
  var arr = xs.ToArray();
  var sad = 0.0;
  var sum = 0.0;
  // NOTE(review): the source text between "i" and "return huge value" below appears to be missing: the rest of
  // InequalityCoefficient and, presumably, the declaration and live body of TransformQuality (called above, not
  // otherwise visible). The lines that follow look like commented-out alternatives of TransformQuality.
  // Restore this section from version control before relying on it.
  for(int i=0;i return huge value
  // if (q >= 1.0) return 1E16;
  // // return number of 9s in R²
  // return -Math.Log10(1 - q);
}

// backpropagate existing statistics to all parents
// Adds stats to tree and recursively to every ancestor in state.parents. There is no visited set, so an
// ancestor reachable over several paths receives stats once per path.
private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
  tree.actionStatistics.Add(stats);
  if (state.parents.ContainsKey(tree)) {
    foreach (var parent in state.parents[tree]) {
      BackpropagateStatistics(stats, parent, state);
    }
  }
}

// Hash of the code the automaton has generated so far; used as key for state unification (state.nodes).
private static ulong Hashcode(Automaton automaton) {
  byte[] code;
  int nParams;
  automaton.GetCode(out code, out nParams);
  return ExprHash.GetHash(code, nParams);
}

// Updates the statistics of tree and recursively of all its ancestors with quality q
// (an ancestor reachable over several paths is updated once per path).
private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
  policy.Update(tree.actionStatistics, q);

  if (state.parents.ContainsKey(tree)) {
    foreach (var parent in state.parents[tree]) {
      BackpropagateQuality(parent, q, policy, state);
    }
  }
}

// Marks tree as done if it has been expanded and all its children are done, then re-checks every ancestor.
private static void BackpropagateDone(Tree tree, State state) {
  if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) {
    tree.Done = true;
    // children[tree] = null; keep all nodes
  }

  if (state.parents.ContainsKey(tree)) {
    foreach (var parent in state.parents[tree]) {
      BackpropagateDone(parent, state);
    }
  }
}

// Records per-level debug statistics for tree and its ancestors (ancestors are processed first) and tracks
// the best quality and action seen on each level.
private static void BackpropagateDebugStats(Tree tree, double q, State state) {
  if (state.parents.ContainsKey(tree)) {
    foreach (var parent in state.parents[tree]) {
      BackpropagateDebugStats(parent, q, state);
    }
  }

  state.pathSelectedQ[tree.level] = tree.actionStatistics.AverageQuality;
  state.pathBestQ[tree.level] = tree.actionStatistics.BestQuality;
  state.curAction[tree.level] = tree.expr;
  if (state.levelBestQ[tree.level] < tree.actionStatistics.BestQuality) {
    state.levelBestQ[tree.level] = tree.actionStatistics.BestQuality;
    state.levelBestAction[tree.level] = tree.expr;
  }
}
// Returns the child of `tree` (as recorded in state.children) with the smallest automaton state value;
// per the original author's note, smaller state values are closer to the final state.
// Ties resolve to the earliest child. `automaton` and `rand` are unused (kept for signature compatibility).
private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
  var children = state.children[tree];
  Tree minChild = children.First();
  for (int i = 1; i < children.Count; i++) {
    // keep the running minimum up to date; comparing every child against the first child only
    // would return the last child smaller than the first one instead of the minimum
    if (children[i].state < minChild.state) minChild = children[i];
  }
  return minChild;
}

// tree search might fail because of constraints for expressions
// in this case we get stuck we just restart
// see ConstraintHandler.cs for more info
//
// One descent of the tree search starting at `tree` (whose state must equal automaton.CurrentState):
// - unexpanded node in a final state: evaluate the expression built so far with `eval`, mark the node done,
//   update its statistics and return true (q = evaluation result)
// - unexpanded node in a non-final state: create one child per allowed follow state and descend into
//   a final-state child if there is one, otherwise into the first child; if there are no follow states
//   (dead end) the node is marked done and false is returned with q = 0
// - expanded node: descend into the child chosen by treePolicy.Select
// On success q is also folded into this node's statistics. A node whose children are all done is marked
// done and its children are dropped.
private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, out double q) {
  Tree selectedChild = null;
  Contract.Assert(tree.state == automaton.CurrentState);
  Contract.Assert(!tree.Done);
  if (tree.children == null) {
    if (automaton.IsFinalState(tree.state)) {
      // final state
      tree.Done = true;

      // EVALUATE
      byte[] code;
      int nParams;
      automaton.GetCode(out code, out nParams);
      q = eval(code, nParams);

      treePolicy.Update(tree.actionStatistics, q);
      return true; // we reached a final state
    } else {
      // EXPAND
      int[] possibleFollowStates;
      int nFs;
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      if (nFs == 0) {
        // stuck in a dead end (no final state and no allowed follow states)
        q = 0;
        tree.Done = true;
        tree.children = null;
        return false;
      }
      tree.children = new Tree[nFs];
      for (int i = 0; i < tree.children.Length; i++) {
        tree.children[i] = new Tree() {
          children = null,
          state = possibleFollowStates[i],
          actionStatistics = treePolicy.CreateActionStatistics()
        };
      }

      selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
    }
  } else {
    // tree.children != null
    // selection within the tree is delegated to the tree policy
    int selectedIdx = 0;
    if (tree.children.Length > 1) {
      selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
    }
    selectedChild = tree.children[selectedIdx];
  }
  // make selected step and recurse
  automaton.Goto(selectedChild.state);
  var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
  if (success) {
    // only update if successful
    treePolicy.Update(tree.actionStatistics, q);
  }

  tree.Done = tree.children.All(ch => ch.Done);
  if (tree.Done) {
    tree.children = null; // cut off the sub-branch if it has been fully explored
  }
  return success;
}

// Picks the child to descend into right after `tree` has been expanded: the first child whose state is
// final, otherwise the first child. NOTE: despite the method name no random choice is made; `rand` is unused.
private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
  foreach (var child in tree.children) {
    if (automaton.IsFinalState(child.state)) return child;
  }
  // no final state -> select the first child
  return tree.children[0];
}

// Scales data and extracts values from the dataset into arrays.
// xs[k] receives the values of the k-th allowed input variable for `rows`, y the target values.
// When scaleVariables is true, scalingFactor/scalingOffset get one entry per input variable followed by
// one entry for the target; values are transformed as v * scalingFactor[k] + scalingOffset[k]:
// - inputs are min-max scaled, scaledX = (x - min) / range (min/max taken over `rows`);
//   a constant input (range == 0) is only shifted (factor 1, offset -min), since 1 / 0 would make it NaN
// - the target is shifted to zero mean (factor 1)
// When scaleVariables is false both arrays are null and the values are returned unscaled.
private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows, out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
  if (scaleVariables) {
    var nVars = problemData.AllowedInputVariables.Count();
    scalingFactor = new double[nVars + 1];
    scalingOffset = new double[nVars + 1];

    var i = 0;
    foreach (var variableName in problemData.AllowedInputVariables) {
      // read the column once and reuse it for min and max
      var values = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
      var minX = values.Min();
      var maxX = values.Max();
      var range = maxX - minX;
      if (range == 0.0) {
        // constant input: shift to zero instead of dividing by a zero range
        scalingFactor[i] = 1.0;
        scalingOffset[i] = -minX;
      } else {
        // scaledX = (x - min) / range
        scalingFactor[i] = 1.0 / range;
        scalingOffset[i] = -minX / range;
      }
      i++;
    }

    // transform target variable to zero-mean
    scalingFactor[i] = 1.0;
    scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
  } else {
    scalingFactor = null;
    scalingOffset = null;
  }
  GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
}

// Extracts values from the dataset into arrays: xs[k] for the k-th allowed input variable and y for the
// target, each value transformed as v * scalingFactor[k] + scalingOffset[k] (the target uses the entry
// after the last input). A null scalingFactor means factor 1, a null scalingOffset means offset 0.
private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset, out double[][] xs, out double[] y) {
  xs = new double[problemData.AllowedInputVariables.Count()][];
  int i = 0;
  foreach (var variableName in problemData.AllowedInputVariables) {
    var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
    var offset = scalingOffset == null ? 0.0 : scalingOffset[i];
    xs[i++] = problemData.Dataset.GetDoubleValues(variableName, rows).Select(xi => xi * sf + offset).ToArray();
  }

  var targetSf = scalingFactor == null ? 1.0 : scalingFactor[i];
  var targetOffset = scalingOffset == null ? 0.0 : scalingOffset[i];
  y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * targetSf + targetOffset).ToArray();
}

// for debugging only: disassembled code of the expression currently held by the automaton
private static string ExprStr(Automaton automaton) {
  byte[] code;
  int nParams;
  automaton.GetCode(out code, out nParams);
  return Disassembler.CodeToString(code);
}

// for debugging only: one tab-separated line with "tries<TAB>avgQuality<TAB>" for `tree`, followed by the
// same pair for each of its children in state.children (numbers are formatted culture-invariantly)
private static string WriteStatistics(Tree tree, State state) {
  using (var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture)) {
    sb.Write("{0}\t{1:N5}\t", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
    if (state.children.ContainsKey(tree)) {
      foreach (var ch in state.children[tree]) {
        sb.Write("{0}\t{1:N5}\t", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
      }
    }
    sb.WriteLine();
    return sb.ToString();
  }
}

// for debugging only: Graphviz DOT text tracing the path of most-tried children starting at `tree`
// (see TraceTreeRec for what is emitted per level)
private static string TraceTree(Tree tree, State state) {
  var sb = new StringBuilder();
  sb.Append(@"digraph { ratio = fill; node [style=filled]; ");
  int nodeId = 0;
  TraceTreeRec(tree, 0, sb, ref nodeId, state);
  sb.Append("}");
  return sb.ToString();
}

// Emits `tree` as DOT node `parentId` together with all of its children, additionally emits the single
// child of every child that has exactly one child, and then recurses into the most-tried child only
// (the first one on ties). `nextId` is the last DOT node id handed out and is advanced for each new node.
private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
  var invariant = System.Globalization.CultureInfo.InvariantCulture;
  AppendTraceNode(sb, parentId, tree);
  if (!state.children.ContainsKey(tree)) return;

  // (DOT node id, child) for every child emitted below
  var emitted = new List<Tuple<int, Tree>>();
  foreach (var ch in state.children[tree]) {
    nextId++;
    AppendTraceNode(sb, nextId, ch);
    sb.AppendFormat(invariant, "{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine();
    emitted.Add(Tuple.Create(nextId, ch));
  }

  foreach (var entry in emitted) {
    var ch = entry.Item2;
    if (state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
      var chch = state.children[ch].First();
      nextId++;
      AppendTraceNode(sb, nextId, chch);
      sb.AppendFormat(invariant, "{0} -> {1} [label=\"{2}\"]", entry.Item1, nextId, chch.expr).AppendLine();
    }
  }

  foreach (var entry in emitted.OrderByDescending(t => t.Item2.actionStatistics.Tries).Take(1)) {
    TraceTreeRec(entry.Item2, entry.Item1, sb, ref nextId, state);
  }
}

// Appends one DOT node statement (plus newline) for `node` with DOT id `id`, labelled with its average
// quality (NaN shown as 0) and number of tries. Formatting is culture-invariant so that the HSV color
// triple and the labels stay valid DOT on any locale.
private static void AppendTraceNode(StringBuilder sb, int id, Tree node) {
  var avgNodeQ = node.actionStatistics.AverageQuality;
  if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
  // coloring by quality ((1 - avgQ) / 360 * 240; 0 = red, 240 = blue) is disabled: every node gets hue 0
  const double hue = 0.0;
  sb.AppendFormat(System.Globalization.CultureInfo.InvariantCulture,
    "{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", id, avgNodeQ, node.actionStatistics.Tries, hue).AppendLine();
}

// for debugging only: Graphviz DOT text of every parent -> child relation recorded in state.children
// (`tree` itself is not used). Children that were never tried are left out; node ids are assigned in
// order of first appearance, starting at 1.
private static string WriteTree(Tree tree, State state) {
  using (var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture)) {
    var nodeIds = new Dictionary<Tree, int>();
    sb.Write(@"digraph { ratio = fill; node [style=filled]; ");
    // minimum number of tries for a node to be drawn (alternative: state.nodes.Count > 500 ? 10 : 0)
    const int threshold = 0;
    foreach (var kvp in state.children) {
      var parent = kvp.Key;
      int parentId;
      if (!nodeIds.TryGetValue(parent, out parentId)) {
        parentId = nodeIds.Count + 1;
        var parentQ = parent.actionStatistics.AverageQuality;
        if (double.IsNaN(parentQ)) parentQ = 0.0;
        const double parentHue = 0.0; // coloring by quality is disabled
        if (parent.actionStatistics.Tries > threshold)
          sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, parentQ, parent.actionStatistics.Tries, parentHue);
        nodeIds.Add(parent, parentId);
      }

      foreach (var child in kvp.Value) {
        int childId;
        if (!nodeIds.TryGetValue(child, out childId)) {
          childId = nodeIds.Count + 1;
          nodeIds.Add(child, childId);
        }
        var tries = child.actionStatistics.Tries;
        if (tries < 1) continue;

        var childQ = child.actionStatistics.AverageQuality;
        if (double.IsNaN(childQ)) childQ = 0.0;
        const double childHue = 0.0; // coloring by quality is disabled
        if (tries > threshold) {
          sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, childQ, tries, childHue);
          var edgeLabel = child.expr;
          sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
        }
      }
    }
    sb.Write("}");
    return sb.ToString();
  }
}
} // class MctsSymbolicRegressionStatic
} // namespace