[13645] | 1 | #region License Information
|
---|
| 2 | /* HeuristicLab
|
---|
[14185] | 3 | * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
[13645] | 4 | *
|
---|
| 5 | * This file is part of HeuristicLab.
|
---|
| 6 | *
|
---|
| 7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
| 8 | * it under the terms of the GNU General Public License as published by
|
---|
| 9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
| 10 | * (at your option) any later version.
|
---|
| 11 | *
|
---|
| 12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
| 13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
| 14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
| 15 | * GNU General Public License for more details.
|
---|
| 16 | *
|
---|
| 17 | * You should have received a copy of the GNU General Public License
|
---|
| 18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
| 19 | */
|
---|
| 20 | #endregion
|
---|
| 21 |
|
---|
| 22 | using System;
|
---|
| 23 | using System.Collections.Generic;
|
---|
[15410] | 24 | using System.Diagnostics;
|
---|
[13645] | 25 | using System.Diagnostics.Contracts;
|
---|
| 26 | using System.Linq;
|
---|
[15414] | 27 | using System.Text;
|
---|
[13658] | 28 | using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
|
---|
[13645] | 29 | using HeuristicLab.Core;
|
---|
| 30 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
[15360] | 31 | using HeuristicLab.Optimization;
|
---|
[13645] | 32 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 33 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 34 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
| 35 | using HeuristicLab.Random;
|
---|
| 36 |
|
---|
| 37 | namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
|
---|
| 38 | public static class MctsSymbolicRegressionStatic {
|
---|
[15403] | 39 | // OBJECTIVES:
|
---|
| 40 | // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
|
---|
| 41 | // - e.g. Keijzer, Nguyen ... where no numeric constants are involved
|
---|
| 42 | // - assumptions:
|
---|
| 43 | // - we don't know the necessary operations or functions -> all available functions could be necessary
|
---|
| 44 | // - but we do not need to tune numeric constants -> no scaling of input variables x!
|
---|
| 45 | // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
|
---|
| 46 | // This is important for real world applications.
|
---|
| 47 | // - e.g. Korns or Vladislavleva problems where numeric constants are involved
|
---|
| 48 | // - assumptions:
|
---|
| 49 | // - any numeric constant is possible (a-priori we might assume that small abs. constants are more likely)
|
---|
| 50 | // - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
|
---|
| 51 | // - to simplify the problem we can restrict the set of functions e.g. we assume which functions are necessary for the problem instance
|
---|
| 52 | // -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
|
---|
| 53 | // 3) efficiency and effectiveness for real-world problems
|
---|
| 54 | // - e.g. Tower problem
|
---|
| 55 | // - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
|
---|
| 56 | //
|
---|
| 57 |
|
---|
[15410] | 58 | // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should
|
---|
| 59 | // weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
|
---|
| 60 | // branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
|
---|
[15414] | 61 | // TODO: Solve Poly-10
|
---|
[15410] | 62 | // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
|
---|
[15416] | 63 | // TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
|
---|
| 64 | // TODO: check if we can use a quality measure with range [-1..1] in policies
|
---|
[15414] | 65 | // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
|
---|
[15404] | 66 | // TODO: check if transformation of y is correct and works (Obj 2)
|
---|
| 67 | // TODO: The algorithm is not invariant to location and scale of variables.
|
---|
| 68 | // Include offset for variables as parameter (for Objective 2)
|
---|
| 69 | // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
|
---|
| 70 | // TODO: support e(-x) and possibly (1/-x) (Obj 1)
|
---|
| 71 | // TODO: is it OK to initialize all constants to 1 (Obj 2)?
|
---|
[15414] | 72 | // TODO: improve memory usage
|
---|
[15420] | 73 | // TODO: support empty test partition
|
---|
[13645] | 74 | #region static API
|
---|
| 75 |
|
---|
/// <summary>
/// Read-only view on the state of a running tree search. Instances are created by
/// <c>CreateState</c> and advanced by <c>MakeStep</c>.
/// </summary>
public interface IState {
  /// <summary>True when the root of the search tree has been marked as done.</summary>
  bool Done { get; }

  /// <summary>Symbolic regression model generated from the best solution found so far.</summary>
  ISymbolicRegressionModel BestModel { get; }

  /// <summary>Quality of the best solution on the training partition.</summary>
  double BestSolutionTrainingQuality { get; }

  /// <summary>Quality of the best solution on the test partition.</summary>
  double BestSolutionTestQuality { get; }

  /// <summary>Solutions currently stored on the Pareto front (only filled when collection is enabled).</summary>
  IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }

  /// <summary>Number of rollouts that have been started (including ineffective ones).</summary>
  int TotalRollouts { get; }

  /// <summary>Number of completed search steps.</summary>
  int EffectiveRollouts { get; }

  /// <summary>Number of function evaluations.</summary>
  int FuncEvaluations { get; }

  /// <summary>
  /// Number of gradient evaluations (* num parameters) to get a value representative of the effort
  /// comparable to the number of function evaluations.
  /// </summary>
  int GradEvaluations { get; }
  // TODO other stats on LM optimizer might be interesting here
}
|
---|
| 88 |
|
---|
| 89 | // created through factory method
|
---|
| 90 | private class State : IState {
|
---|
// capacity of the parameter buffers (ones, constsBuf, gradBuf) allocated in the constructor
private const int MaxParams = 100;

// ---- state shared with the MCTS routines ----
internal readonly Automaton automaton;
internal IRandom random { get; private set; }
internal readonly Tree tree; // root of the search tree
internal readonly Func<byte[], int, double> evalFun; // evaluates (code, nParams) and returns the quality
internal readonly IPolicy treePolicy;
// MCTS might get stuck, therefore we keep statistics on the number of effective rollouts
internal int totalRollouts;
internal int effectiveRollouts;


// ---- state used only internally (by the eval function) ----
private readonly IRegressionProblemData problemData;
private readonly double[][] x; // training inputs
private readonly double[] y; // training target
private readonly double[][] testX; // test inputs
private readonly double[] testY; // test target
private readonly double[] scalingFactor;
private readonly double[] scalingOffset;
private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
private readonly int constOptIterations;
private readonly double lambda; // weight of penalty term for regularization
private readonly double lowerEstimationLimit;
private readonly double upperEstimationLimit;
private readonly bool collectParetoOptimalModels;
private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
private readonly List<double[]> paretoFront = new List<double[]>(); // entry i belongs to paretoBestModels[i]

// separate evaluators because the vector lengths of training and test data might differ
private readonly ExpressionEvaluator evaluator;
private readonly ExpressionEvaluator testEvaluator;

// search graph: edges in both directions and a node lookup
// NOTE(review): the ulong key of 'nodes' is presumably the expression hash - confirm against TryTreeSearchRec2
internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();

// ---- values for the best solution ----
private double bestR;
private byte[] bestCode;
private int bestNParams;
private double[] bestConsts;

// ---- stats ----
private int funcEvaluations;
private int gradEvaluations;

// ---- buffers (re-used for each evaluation) ----
private readonly double[] ones; // vector of ones (as default params)
private readonly double[] constsBuf;
private readonly double[] predBuf;
private readonly double[] testPredBuf;
private readonly double[][] gradBuf;

// ---- debugging stats ----
// calculate for each level the number of alternatives the average 'inequality' of tries and 'inequality' of quality over the alternatives for each trie
// inequality can be calculated using the Gini coefficient
internal readonly double[] giniCoeffs = new double[100];
|
---|
| 146 |
|
---|
| 147 |
|
---|
/// <summary>
/// Prepares everything the search needs: training/test data, expression evaluators, the automaton,
/// the root node of the search tree, the default (constant) best solution and the evaluation buffers.
/// </summary>
/// <param name="problemData">The regression problem (training and test partitions are read from it).</param>
/// <param name="randSeed">Seed for the Mersenne Twister used by the search.</param>
/// <param name="maxVariables">Passed to the constraint handler of the automaton.</param>
/// <param name="scaleVariables">Whether the input variables are scaled based on the training set.</param>
/// <param name="constOptIterations">Maximum number of LM iterations; a negative value disables constants optimization.</param>
/// <param name="lambda">Weight of the regularization penalty; must not be negative.</param>
/// <param name="treePolicy">Tree policy; UCB is used when null.</param>
/// <param name="collectParetoOptimalModels">Whether a Pareto front (quality vs. expression length) is maintained.</param>
/// <param name="lowerEstimationLimit">Lower limit handed to the evaluators and generated models.</param>
/// <param name="upperEstimationLimit">Upper limit handed to the evaluators and generated models.</param>
/// <param name="allowProdOfVars">Passed to the automaton.</param>
/// <param name="allowExp">Passed to the automaton.</param>
/// <param name="allowLog">Passed to the automaton.</param>
/// <param name="allowInv">Passed to the automaton.</param>
/// <param name="allowMultipleTerms">Passed to the automaton.</param>
/// <exception cref="ArgumentException">Thrown when <paramref name="lambda"/> is negative.</exception>
public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
  int constOptIterations, double lambda,
  IPolicy treePolicy = null,
  bool collectParetoOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false) {

  if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");

  this.problemData = problemData;
  this.constOptIterations = constOptIterations;
  this.lambda = lambda;
  this.evalFun = this.Eval;
  this.lowerEstimationLimit = lowerEstimationLimit;
  this.upperEstimationLimit = upperEstimationLimit;
  this.collectParetoOptimalModels = collectParetoOptimalModels;

  random = new MersenneTwister(randSeed);

  // prepare data for evaluation
  double[][] x;
  double[] y;
  double[][] testX;
  double[] testY;
  double[] scalingFactor;
  double[] scalingOffset;
  // get training and test datasets (scale linearly based on training set if required)
  GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
  GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
  this.x = x;
  this.y = y;
  this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
  this.testX = testX;
  this.testY = testY;
  this.scalingFactor = scalingFactor;
  this.scalingOffset = scalingOffset;
  this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
  // we need a separate evaluator because the vector length for the test dataset might differ
  this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

  this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
  this.treePolicy = treePolicy ?? new Ucb();
  this.tree = new Tree() {
    state = automaton.CurrentState,
    // use the resolved policy here: the treePolicy parameter is null when the caller relies on the default
    actionStatistics = this.treePolicy.CreateActionStatistics(),
    expr = "",
    level = 0
  };

  // reset best solution
  this.bestR = 0;
  // code for default solution (constant model)
  this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
  this.bestNParams = 0;
  this.bestConsts = null;

  // init buffers
  this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
  constsBuf = new double[MaxParams];
  this.predBuf = new double[y.Length];
  this.testPredBuf = new double[testY.Length];

  // one gradient vector (length = number of training rows) per parameter
  this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
}
|
---|
| 216 |
|
---|
| 217 | #region IState interface
|
---|
// true when a search tree exists and its root has been marked as done
public bool Done {
  get {
    if (tree == null) return false;
    return tree.Done;
  }
}
|
---|
[13645] | 219 |
|
---|
// Re-evaluates the best solution on the training data (overwrites predBuf)
// and returns the correlation of target and prediction as calculated by Rho.
public double BestSolutionTrainingQuality {
  get {
    evaluator.Exec(bestCode, x, bestConsts, predBuf);
    double trainingRho = Rho(y, predBuf);
    return trainingRho;
  }
}
|
---|
| 226 |
|
---|
// Re-evaluates the best solution on the test data (overwrites testPredBuf)
// and returns the correlation of target and prediction as calculated by Rho.
public double BestSolutionTestQuality {
  get {
    testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
    double testRho = Rho(testY, testPredBuf);
    return testRho;
  }
}
|
---|
| 233 |
|
---|
// takes the code of the best solution and creates an equivalent symbolic regression model
public ISymbolicRegressionModel BestModel {
  get {
    var inputVariables = problemData.AllowedInputVariables.ToArray();
    var generator = new SymbolicExpressionTreeGenerator(inputVariables);

    // translate code and constants (incl. the variable scaling) into a symbolic expression tree
    var root = generator.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset);
    var bestModel = new SymbolicRegressionModel(
      problemData.TargetVariable,
      new SymbolicExpressionTree(root),
      new SymbolicDataAnalysisExpressionTreeLinearInterpreter(),
      lowerEstimationLimit, upperEstimationLimit);

    bestModel.Scale(problemData); // apply linear scaling
    return bestModel;
  }
}
|
---|
// solutions on the current Pareto front (the internal list itself is returned, no copy is made)
public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get { return paretoBestModels; } }
|
---|
[13651] | 249 |
|
---|
// number of rollouts that have been started
public int TotalRollouts {
  get { return totalRollouts; }
}

// number of completed search steps
public int EffectiveRollouts {
  get { return effectiveRollouts; }
}

// number of function evaluations
public int FuncEvaluations {
  get { return funcEvaluations; }
}

// number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
public int GradEvaluations {
  get { return gradEvaluations; }
}
|
---|
| 254 |
|
---|
[13645] | 255 | #endregion
|
---|
| 256 |
|
---|
// Evaluates the expression given as code, keeps track of the best solution found so far
// (and optionally of the Pareto front) and returns the quality of the expression.
private double Eval(byte[] code, int nParams) {
  double quality;
  double[] tunedConsts;
  Eval(code, nParams, out quality, out tunedConsts);

  // single objective best
  bool improved = quality > bestR;
  if (improved) {
    bestR = quality;
    bestNParams = nParams;
    // store copies, tunedConsts refers to a buffer that is overwritten by the next evaluation
    this.bestCode = (byte[])code.Clone();
    this.bestConsts = new double[bestNParams];
    Array.Copy(tunedConsts, bestConsts, bestNParams);
  }

  if (!collectParetoOptimalModels) return quality;

  // multi-objective best
  // TODO: implement Kommenda's tree complexity directly in the evaluator
  //       (SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity())
  // use length of expression (position of the first Exit op code) as surrogate for complexity
  int exprLength = Array.FindIndex(code, opc => opc == (byte)OpCodes.Exit);
  UpdateParetoFront(quality, exprLength, code, tunedConsts, nParams, scalingFactor, scalingOffset);
  return quality;
}
|
---|
| 280 |
|
---|
// Evaluates the expression on the training data and returns the correlation (rho) of target and prediction.
// optConsts refers to the shared buffer constsBuf, callers must copy the values if they want to keep them.
private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
  // we make a first pass to determine a valid starting configuration for all constants
  // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
  // scale and offset are set to optimal starting configuration
  // assumes scale is the first param and offset is the last param

  // reset constants
  Array.Copy(ones, constsBuf, nParams);
  evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
  funcEvaluations++;

  // without parameters (or with constants optimization disabled) we are done after the first pass
  // changing scale and offset does not influence r²
  bool tuneConstants = nParams != 0 && constOptIterations >= 0;
  if (tuneConstants) {
    // optimize constants using the starting point calculated above
    OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

    // update the predictions using the optimized constants
    evaluator.Exec(code, x, constsBuf, predBuf);
    funcEvaluations++;
  }

  rho = Rho(y, predBuf);
  optConsts = constsBuf;
}
|
---|
| 308 |
|
---|
| 309 |
|
---|
| 310 |
|
---|
| 311 | #region helpers
|
---|
// Pearson's R of the two sequences; 0.0 is returned when the calculator reports an error
private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
  OnlineCalculatorError calcError;
  double pearsonR = OnlinePearsonsRCalculator.Calculate(x, y, out calcError);
  if (calcError != OnlineCalculatorError.None) return 0.0;
  return pearsonR;
}
|
---|
| 317 |
|
---|
| 318 |
|
---|
// Optimizes the first nParams elements of consts in place using Levenberg-Marquardt (alglib minlm).
// Throws an ArgumentException when alglib reports a negative termination type.
private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
  // allocate a smaller buffer for constants opt (TODO perf?)
  var startPoint = new double[nParams];
  Array.Copy(consts, startPoint, nParams);

  // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
  alglib.minlmstate lmState;
  alglib.minlmcreatevj(y.Length + 1, startPoint, out lmState); // +1 for penalty term
  // Using the change of the gradient as stopping criterion is recommended in alglib manual.
  // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
  alglib.minlmsetcond(lmState, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
  // alglib.minlmsetgradientcheck(lmState, 1E-5);
  alglib.minlmoptimize(lmState, Func, FuncAndJacobian, null, code);

  double[] solution;
  alglib.minlmreport report;
  alglib.minlmresults(lmState, out solution, out report);

  // keep track of the effort
  funcEvaluations += report.nfunc;
  gradEvaluations += report.njac * nParams;

  if (report.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + report.terminationtype);

  // the optimization was successful (termination type >= 0) -> use the optimized constants
  Array.Copy(solution, consts, solution.Length);
}
|
---|
| 343 |
|
---|
// Callback for alglib LM: writes the residuals for parameter vector arg to fi.
// The last element of fi holds the regularization penalty; obj is the code of the expression.
private void Func(double[] arg, double[] fi, object obj) {
  var code = (byte[])obj;
  int rows = predBuf.Length;
  evaluator.Exec(code, x, arg, predBuf);
  for (int row = 0; row < rows; row++) {
    fi[row] = predBuf[row] - y[row];
  }

  int last = fi.Length - 1;
  fi[last] = 0.0;

  // calc (squared) length of parameter vector for regularization
  double sqrNorm = 0.0;
  foreach (double a in arg) {
    sqrNorm += a * a;
  }

  if (lambda > 0 && sqrNorm > 0) {
    // scale lambda using stdDev(y) to make the parameter independent of the scale of y
    // scale lambda using the number of rows to make parameter independent of the number of training points
    // take the root because LM squares the result
    fi[last] = Math.Sqrt(rows * lambda / yStdDev * sqrNorm);
  }
}
|
---|
[15403] | 367 |
|
---|
// Callback for alglib LM: writes the residuals for parameter vector arg to fi and their partial derivatives to jac.
// The last row of fi / jac belongs to the regularization penalty; obj is the code of the expression.
private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
  int rows = predBuf.Length;
  int nParams = arg.Length;
  var code = (byte[])obj;
  evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
  for (int row = 0; row < rows; row++) {
    fi[row] = predBuf[row] - y[row];
    // transpose: gradBuf is indexed [param][row], jac is indexed [row, param]
    for (int p = 0; p < nParams; p++) {
      jac[row, p] = gradBuf[p][row];
    }
  }

  // calc (squared) length of parameter vector for regularization
  double sqrNorm = 0.0;
  for (int p = 0; p < nParams; p++) {
    sqrNorm += arg[p] * arg[p];
  }

  int last = fi.Length - 1;
  bool penalize = lambda > 0 && sqrNorm > 0;
  if (!penalize) {
    // no penalty term
    fi[last] = 0.0;
    for (int p = 0; p < nParams; p++) {
      jac[last, p] = 0.0;
    }
    return;
  }

  // scale lambda using stdDev(y) to make the parameter independent of the scale of y
  // scale lambda using the number of rows to make parameter independent of the number of training points
  // take the root because alglib LM squares the result
  fi[last] = Math.Sqrt(rows * lambda / yStdDev * sqrNorm);

  // derivative of the penalty term w.r.t. each parameter
  for (int p = 0; p < nParams; p++) {
    jac[last, p] = 0.5 / fi[last] * 2 * rows * lambda / yStdDev * arg[p];
  }
}
|
---|
[15403] | 405 |
|
---|
| 406 |
|
---|
// Inserts the candidate (quality q, complexity) into the Pareto front if no stored entry dominates it.
// Quality is maximized, complexity is minimized. Entries dominated by the candidate are removed.
// paretoFront[i] and paretoBestModels[i] always describe the same solution.
private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
  double[] scalingFactor, double[] scalingOffset) {
  double[] cur = new double[2] { q, complexity };
  bool[] max = new[] { true, false };

  // a candidate that is dominated by any entry cannot change the front -> nothing to do
  foreach (var e in paretoFront) {
    if (DominationCalculator<int>.Dominates(cur, e, max, true) == DominationResult.IsDominated) return;
  }

  // create the solution first, so that a failure here leaves both lists untouched and aligned
  var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
  var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

  var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
  var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
  model.Scale(problemData); // apply linear scaling

  var sol = model.CreateRegressionSolution(this.problemData);
  sol.Name = string.Format("{0:N5} {1}", q, complexity);

  // remove all entries which are dominated by the candidate (backwards, because we remove by index)
  for (int i = paretoFront.Count - 1; i >= 0; i--) {
    if (DominationCalculator<int>.Dominates(cur, paretoFront[i], max, true) == DominationResult.Dominates) {
      paretoFront.RemoveAt(i);
      paretoBestModels.RemoveAt(i);
    }
  }

  paretoFront.Add(cur);
  paretoBestModels.Add(sol);
}
|
---|
[15420] | 445 |
|
---|
[13645] | 446 | #endregion
|
---|
[15420] | 447 |
|
---|
#if DEBUG
// marks all entries of the debugging stats as unused (-1)
internal void ClearStats() {
  int idx = 0;
  while (idx < giniCoeffs.Length) {
    giniCoeffs[idx] = -1;
    idx++;
  }
}

// writes the leading non-negative entries of the debugging stats as one tab-separated line to the console
internal void WriteStats() {
  var usedCoeffs = giniCoeffs.TakeWhile(g => g >= 0);
  var formatted = usedCoeffs.Select(g => string.Format("{0:N3}", g));
  Console.WriteLine(string.Join("\t", formatted));
}

#endif
|
---|
| 457 |
|
---|
[13645] | 458 | }
|
---|
| 459 |
|
---|
[15403] | 460 |
|
---|
/// <summary>
/// Factory method which creates the initial state for the algorithm.
/// The returned state is advanced by calling MakeStep.
/// </summary>
/// <param name="problemData">The problem data</param>
/// <param name="randSeed">Random seed.</param>
/// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
/// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended)</param>
/// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt). A negative value disables constants optimization.</param>
/// <param name="lambda">Penalty factor for regularization (0..inf.), a value of zero disables regularization.</param>
/// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...)</param>
/// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions having minimal length and error.</param>
/// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
/// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
/// <param name="allowProdOfVars">Allow products of expressions.</param>
/// <param name="allowExp">Allow expressions with exponentials.</param>
/// <param name="allowLog">Allow expressions with logarithms</param>
/// <param name="allowInv">Allow expressions with 1/x</param>
/// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
/// <returns>The initial search state.</returns>
public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
  bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
  IPolicy policy = null,
  bool collectParameterOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false
  ) {
  IState initialState = new State(
    problemData: problemData,
    randSeed: randSeed,
    maxVariables: maxVariables,
    scaleVariables: scaleVariables,
    constOptIterations: constOptIterations,
    lambda: lambda,
    treePolicy: policy,
    collectParetoOptimalModels: collectParameterOptimalModels,
    lowerEstimationLimit: lowerEstimationLimit,
    upperEstimationLimit: upperEstimationLimit,
    allowProdOfVars: allowProdOfVars,
    allowExp: allowExp,
    allowLog: allowLog,
    allowInv: allowInv,
    allowMultipleTerms: allowMultipleTerms);
  return initialState;
}
|
---|
| 497 |
|
---|
/// <summary>
/// Performs one step of the tree search and returns the quality of the evaluated solution.
/// </summary>
/// <param name="state">A state which has been created by CreateState.</param>
/// <returns>The quality of the solution evaluated in this step.</returns>
/// <exception cref="ArgumentException">Thrown when the state is null or has not been created by CreateState.</exception>
/// <exception cref="NotSupportedException">Thrown when the search has already enumerated all possible solutions.</exception>
public static double MakeStep(IState state) {
  var mctsState = state as State;
  // the single-argument constructor of ArgumentException takes the message (not the parameter name)
  if (mctsState == null) throw new ArgumentException("The state must be a non-null state created by CreateState.", "state");
  if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

  return TreeSearch(mctsState);
}
|
---|
| 506 | #endregion
|
---|
| 507 |
|
---|
// Repeats rollouts (each starting from the reset automaton) until one of them is successful
// or the whole tree is done. Returns the quality produced by the last rollout.
private static double TreeSearch(State mctsState) {
  double quality = 0;
  bool effective;
  do {
#if DEBUG
    mctsState.ClearStats();
#endif
    mctsState.automaton.Reset();
    effective = TryTreeSearchRec2(mctsState.random, mctsState.tree, mctsState.automaton, mctsState.evalFun,
      mctsState.treePolicy, mctsState, out quality);
    mctsState.totalRollouts++; // every rollout counts, even an ineffective one
  } while (!(effective || mctsState.tree.Done));
  // NOTE(review): this is also incremented when the loop ends because the tree is done - confirm that this is intended
  mctsState.effectiveRollouts++;

#if DEBUG
  mctsState.WriteStats();
#endif
  //if (mctsState.effectiveRollouts % 100 == 1) {
  //  Console.WriteLine(WriteTree(tree, mctsState));
  //  Console.WriteLine(TraceTree(tree, mctsState));
  //}
  return quality;
}
|
---|
| 535 |
|
---|
// Forward search: performs one pass from the given node until the automaton reaches a final state
// or the search gets stuck.
//
// The search navigates a graph rather than a tree, because unified states can be reached via different paths.
// Equivalence of eval states is detected through the hash of the code generated so far (see Hashcode).
//
// Inside the loop two modes alternate:
//  - rollout:   the current node already has registered children; one of them is chosen by the tree policy
//  - expansion: the current node has no registered children yet; they are created (or re-used when an
//               equivalent node is already known) and one of them is followed
// Expansion may re-enter the known part of the graph, in which case the next iteration is a rollout step again.
//
// Cycles are prevented by only linking to known nodes whose level is larger than the level of the parent.
// Nodes whose children are all done are marked as done. A pass that ends without reaching a final state
// is reported as unsuccessful and yields q = 0.0.
//
// Returns true when a final state was reached and evaluated; q then holds the transformed quality.
// In both cases q is backpropagated from the last visited node to all of its parents.
private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  State state,
  out double q) {
  while (!automaton.IsFinalState(automaton.CurrentState)) {
    if (state.children.ContainsKey(tree)) {
      // ROLLOUT INSIDE THE KNOWN GRAPH
      var knownChildren = state.children[tree];
      if (knownChildren.All(c => c.Done)) {
        // everything below this node has been searched already
        tree.Done = true;
        break;
      }

      // policy-based selection (only necessary when there is a choice)
      int chosenIdx = knownChildren.Count > 1
        ? treePolicy.Select(knownChildren.Select(c => c.actionStatistics), rand)
        : 0;

      // STATS: inequality of the average child qualities, recorded per level
      state.giniCoeffs[tree.level] = InequalityCoefficient(knownChildren.Select(c => (double)c.actionStatistics.AverageQuality));

      tree = knownChildren[chosenIdx];

      // move the automaton forward until it can step into the selected node;
      // steps without alternatives are taken immediately
      // TODO: simplification of the automaton
      int[] followStates;
      int nFollow;
      automaton.FollowStates(automaton.CurrentState, out followStates, out nFollow);
      while (automaton.CurrentState != tree.state && nFollow == 1) {
        int forcedState = followStates[0];
        if (automaton.IsEvalState(forcedState) || automaton.IsFinalState(forcedState)) break;
        automaton.Goto(forcedState);
        automaton.FollowStates(automaton.CurrentState, out followStates, out nFollow);
      }
      Debug.Assert(followStates.Contains(tree.state));
      automaton.Goto(tree.state);
    } else {
      // EXPAND
      int[] followStates;
      int nFollow;
      string forcedActions = "";
      automaton.FollowStates(automaton.CurrentState, out followStates, out nFollow);
      while (nFollow == 1) {
        int forcedState = followStates[0];
        if (automaton.IsEvalState(forcedState) || automaton.IsFinalState(forcedState)) break;
        // no alternatives -> remember the action text and just go to the next state
        forcedActions += " " + automaton.GetActionString(automaton.CurrentState, forcedState);
        automaton.Goto(forcedState);
        automaton.FollowStates(automaton.CurrentState, out followStates, out nFollow);
      }
      if (nFollow == 0) {
        // dead end: not a final state and no allowed follow states
        tree.Done = true;
        break;
      }

      var newChildren = new List<Tree>(nFollow);
      state.children.Add(tree, newChildren);
      for (int i = 0; i < nFollow; i++) {
        int followState = followStates[i];
        Tree child;
        if (!automaton.IsEvalState(followState)) {
          child = CreateChildNode(automaton, treePolicy, tree, followState, forcedActions);
        } else {
          // eval states are unified: equivalent states are detected via the hash of the generated code
          var hash = Hashcode(automaton);
          if (!state.nodes.TryGetValue(hash, out child)) {
            child = CreateChildNode(automaton, treePolicy, tree, followState, forcedActions);
            state.nodes.Add(hash, child);
          } else if (child.level > tree.level) {
            // joining an existing path (forward edge): the statistics of the existing node have to be
            // propagated back through the new link to all parents
            BackpropagateStatistics(child.actionStatistics, tree, state);
          } else {
            // the link would point backwards in the graph -> drop it to prevent cycles
            Debug.Assert(child.level <= tree.level);
            child = null;
          }
        }
        if (child != null) {
          newChildren.Add(child);
        }
      }

      if (newChildren.Count == 0) {
        // dead end: none of the follow states produced a usable child
        tree.Done = true;
        break;
      }

      // register the expanded node as parent of each of its children
      foreach (var added in newChildren) {
        if (!state.parents.ContainsKey(added)) {
          state.parents.Add(added, new List<Tree>());
        }
        state.parents[added].Add(tree);
      }

      // follow one of the children
      tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
      automaton.Goto(tree.state);
    }
  }

  // EVALUATE
  bool reachedFinalState = automaton.IsFinalState(automaton.CurrentState);
  if (reachedFinalState) {
    tree.Done = true;
    tree.expr = ExprStr(automaton);
    byte[] code;
    int nParams;
    automaton.GetCode(out code, out nParams);
    q = TransformQuality(eval(code, nParams));
  } else {
    // the pass got stuck, no evaluation necessary
    q = 0.0;
  }

  // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
  // (updates statistics and marks branches as done when all of their children are done)
  BackpropagateQuality(tree, q, treePolicy, state);

  return reachedFinalState;
}

// Creates a new graph node for followState one level below parent.
// The edge label (expr) is the text of the forced actions taken so far followed by the action leading to followState.
private static Tree CreateChildNode(Automaton automaton, IPolicy treePolicy, Tree parent, int followState, string forcedActions) {
  return new Tree() {
    children = null,
    state = followState,
    actionStatistics = treePolicy.CreateActionStatistics(),
    expr = forcedActions + automaton.GetActionString(automaton.CurrentState, followState),
    level = parent.level + 1
  };
}
|
---|
| 688 |
|
---|
// Inequality coefficient of the given values:
// half of the summed absolute differences over all ordered pairs, divided by (number of values * sum of values).
// The result is not finite when the values sum to zero; for an empty sequence it is NaN.
private static double InequalityCoefficient(IEnumerable<double> xs) {
  double[] values = xs.ToArray();
  double absDiffTotal = 0.0;
  double valueTotal = 0.0; // accumulates every value once per outer iteration, i.e. n * sum(values)

  foreach (double left in values) {
    foreach (double right in values) {
      absDiffTotal += Math.Abs(left - right);
      valueTotal += right;
    }
  }
  return 0.5 * absDiffTotal / valueTotal;
}
|
---|
| 702 |
|
---|
// EXPERIMENTAL! Transforms the raw quality before it is backpropagated.
// Applies the Fisher transformation 0.5 * ln((1 + q) / (1 - q)), which assumes that q is Correl(pred, target).
// q is clamped to [-0.99999999, 0.99999999] first so that the logarithm stays finite.
// Alternatives that were tried: no transformation (return q), or the number of 9s in R² (-log10(1 - q)).
private static double TransformQuality(double q) {
  const double limit = 0.99999999;
  double clamped = Math.Max(Math.Min(q, limit), -limit);
  return 0.5 * Math.Log((1 + clamped) / (1 - clamped));
}
|
---|
| 721 |
|
---|
// Adds the given (already existing) statistics to the node and, recursively, to every registered parent.
// A node that is reachable via several parent paths is visited once per path.
private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
  tree.actionStatistics.Add(stats);

  if (!state.parents.ContainsKey(tree)) return; // no registered parents -> nothing left to update

  foreach (var ancestor in state.parents[tree]) {
    BackpropagateStatistics(stats, ancestor, state);
  }
}
|
---|
| 731 |
|
---|
// Hash of the code the automaton has generated so far (used to detect equivalent states).
private static ulong Hashcode(Automaton automaton) {
  byte[] generatedCode;
  int parameterCount;
  automaton.GetCode(out generatedCode, out parameterCount);

  var hash = ExprHash.GetHash(generatedCode, parameterCount);
  return hash;
}
|
---|
| 738 |
|
---|
// Propagates the quality q from the given node upwards to all registered parents (recursively).
// The statistics of a node are only updated for q > 0.
// A node with registered children is marked as done as soon as none of them is still open;
// the children themselves are kept.
private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
  if (q > 0) {
    policy.Update(tree.actionStatistics, q);
  }

  if (state.children.ContainsKey(tree)) {
    var kids = state.children[tree];
    bool anyOpen = kids.Any(k => !k.Done);
    if (!anyOpen) {
      tree.Done = true;
    }
  }

  if (!state.parents.ContainsKey(tree)) return;

  foreach (var ancestor in state.parents[tree]) {
    BackpropagateQuality(ancestor, q, policy, state);
  }
}
|
---|
| 752 |
|
---|
// Returns the registered child of the given node that has the smallest state value
// (smaller values are closer to the final state). Ties are resolved in favor of the first such child.
// The node must have at least one registered child in state.children.
// automaton and rand are currently not used.
private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
  var children = state.children[tree];
  Tree minChild = children.First();
  for (int i = 1; i < children.Count; i++) {
    // the running minimum has to be updated, otherwise every child is only compared to the first one
    if (children[i].state < minChild.state) {
      minChild = children[i];
    }
  }
  return minChild;
}
|
---|
| 764 |
|
---|
// Recursive tree search starting at the given node (which must match the current automaton state and must not be done).
// The search might fail because of constraints for expressions (see ConstraintHandler.cs for more info);
// when it gets stuck in a dead end it returns false with q = 0 and the caller simply restarts.
// Returns true when a final state was reached; q is then the evaluated quality and the statistics
// of all nodes on the path have been updated with it.
private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  out double q) {
  Contract.Assert(tree.state == automaton.CurrentState);
  Contract.Assert(!tree.Done);

  Tree next;
  if (tree.children != null) {
    // node is already expanded -> selection by the tree policy (only when there is a choice)
    int pick = tree.children.Length > 1
      ? treePolicy.Select(tree.children.Select(c => c.actionStatistics), rand)
      : 0;
    next = tree.children[pick];
  } else if (automaton.IsFinalState(tree.state)) {
    // final state -> EVALUATE
    tree.Done = true;

    byte[] code;
    int nParams;
    automaton.GetCode(out code, out nParams);
    q = eval(code, nParams);

    treePolicy.Update(tree.actionStatistics, q);
    return true;
  } else {
    // EXPAND
    int[] followStates;
    int nFollow;
    automaton.FollowStates(automaton.CurrentState, out followStates, out nFollow);
    if (nFollow == 0) {
      // dead end: not a final state and no allowed follow states
      q = 0;
      tree.Done = true;
      tree.children = null;
      return false;
    }

    tree.children = new Tree[nFollow];
    for (int k = 0; k < tree.children.Length; k++) {
      tree.children[k] = new Tree() {
        children = null,
        state = followStates[k],
        actionStatistics = treePolicy.CreateActionStatistics()
      };
    }

    next = nFollow > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
  }

  // make the selected step and recurse
  automaton.Goto(next.state);
  bool reachedFinal = TryTreeSearchRec(rand, next, automaton, eval, treePolicy, out q);
  if (reachedFinal) {
    // statistics are only updated for successful searches
    treePolicy.Update(tree.actionStatistics, q);
  }

  tree.Done = tree.children.All(c => c.Done);
  if (tree.Done) {
    // the sub-branch has been fully explored -> cut it off
    tree.children = null;
  }
  return reachedFinal;
}
|
---|
| 830 |
|
---|
// Returns the first child that leads directly to a final state.
// If there is no such child the first child is returned (no random choice is made; rand is not used).
private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
  foreach (var candidate in tree.children) {
    if (automaton.IsFinalState(candidate.state)) {
      return candidate;
    }
  }
  // no final state among the children -> take the first child
  return tree.children[0];
}
|
---|
| 848 |
|
---|
// Determines the scaling parameters (optional) and extracts the values from the dataset into arrays.
//
// With scaleVariables == true:
//  - each allowed input variable is scaled to [0..1] over the given rows: scaledX = (x - min) / range,
//    stored as scalingFactor[i] = 1 / range and scalingOffset[i] = -min / range
//  - a constant input variable (range == 0) is only shifted to zero (factor 1, offset -min),
//    because dividing by the zero range would turn the whole column into NaN
//  - the last entry of both arrays belongs to the target variable, which is shifted to zero mean (factor 1)
// With scaleVariables == false both scaling arrays are null and the data is extracted unchanged.
//
// xs receives one array per allowed input variable, y the values of the target variable.
private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
  out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
  if (scaleVariables) {
    int nInputs = problemData.AllowedInputVariables.Count();
    scalingFactor = new double[nInputs + 1];
    scalingOffset = new double[nInputs + 1];

    int i = 0;
    foreach (var variableName in problemData.AllowedInputVariables) {
      var minX = problemData.Dataset.GetDoubleValues(variableName, rows).Min();
      var maxX = problemData.Dataset.GetDoubleValues(variableName, rows).Max();
      var range = maxX - minX;

      if (range == 0.0) {
        // constant variable: avoid the division by zero
        scalingFactor[i] = 1.0;
        scalingOffset[i] = -minX;
      } else {
        // scaledX = (x - min) / range
        scalingFactor[i] = 1.0 / range;
        scalingOffset[i] = -minX / range;
      }
      i++;
    }

    // transform target variable to zero-mean
    scalingFactor[i] = 1.0;
    scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
  } else {
    scalingFactor = null;
    scalingOffset = null;
  }

  GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
}
|
---|
| 885 |
|
---|
// Extracts the values of all allowed input variables (xs) and of the target variable (y) for the given rows.
// When scalingFactor is not null every value v is transformed to v * factor + offset, using entry i for the
// i-th input variable and the entry following the last input variable for the target.
// Note that only scalingFactor is checked for null; scalingOffset has to be given whenever scalingFactor is given.
private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
  out double[][] xs, out double[] y) {
  bool useScaling = scalingFactor != null;
  xs = new double[problemData.AllowedInputVariables.Count()][];

  int column = 0;
  foreach (var variableName in problemData.AllowedInputVariables) {
    double factor = useScaling ? scalingFactor[column] : 1.0;
    double shift = useScaling ? scalingOffset[column] : 0.0;
    var rawValues = problemData.Dataset.GetDoubleValues(variableName, rows);
    xs[column] = rawValues.Select(v => v * factor + shift).ToArray();
    column++;
  }

  // column now points to the entry of the target variable
  double targetFactor = useScaling ? scalingFactor[column] : 1.0;
  double targetShift = useScaling ? scalingOffset[column] : 0.0;
  var rawTarget = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
  y = rawTarget.Select(t => t * targetFactor + targetShift).ToArray();
}
|
---|
[15410] | 905 |
|
---|
| 906 | // for debugging only
|
---|
| 907 |
|
---|
| 908 |
|
---|
// Disassembles the code generated by the automaton so far into a string (the parameter count is not needed here).
private static string ExprStr(Automaton automaton) {
  byte[] generatedCode;
  int parameterCount;
  automaton.GetCode(out generatedCode, out parameterCount);

  var text = Disassembler.CodeToString(generatedCode);
  return text;
}
|
---|
| 915 |
|
---|
[15420] | 916 |
|
---|
// Produces one line "<tries> <average quality>" for the given node followed by one line
// for each of its registered children (if there are any).
private static string WriteStatistics(Tree tree, State state) {
  IEnumerable<Tree> reported = new[] { tree };
  if (state.children.ContainsKey(tree)) {
    reported = reported.Concat(state.children[tree]);
  }

  var writer = new System.IO.StringWriter();
  foreach (var node in reported) {
    writer.WriteLine("{0} {1:N5}", node.actionStatistics.Tries, node.actionStatistics.AverageQuality);
  }
  return writer.ToString();
}
|
---|
[15414] | 927 |
|
---|
// Produces a graphviz (dot) digraph for the part of the graph that TraceTreeRec emits
// when starting from the given node; the start node gets id 0.
private static string TraceTree(Tree tree, State state) {
  var dot = new StringBuilder(
@"digraph {
ratio = fill;
node [style=filled];
");
  int lastUsedId = 0;
  TraceTreeRec(tree, 0, dot, ref lastUsedId, state);
  return dot.Append("}").ToString();
}
|
---|
| 941 |
|
---|
// Appends dot statements for the given node (using parentId as its id), all of its registered children and,
// for children with exactly one registered child, that single grandchild.
// Afterwards the trace continues recursively with the child that has the most tries only.
// nextId holds the last id that has been handed out and is incremented for every additional node.
// NaN qualities are written as 0. The node color is fixed (hue 0); the quality-dependent hue is switched off.
private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
  const string nodeFormat = "{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ";
  const string edgeFormat = "{0} -> {1} [label=\"{3}\"]";
  const double hue = 0.0;

  double quality = tree.actionStatistics.AverageQuality;
  int tries = tree.actionStatistics.Tries;
  if (double.IsNaN(quality)) quality = 0.0;
  sb.AppendFormat(nodeFormat, parentId, quality, tries, hue).AppendLine();

  if (!state.children.ContainsKey(tree)) return;

  // (tries, id, node) for each child
  var emitted = new List<Tuple<int, int, Tree>>();
  foreach (var child in state.children[tree]) {
    nextId++;
    quality = child.actionStatistics.AverageQuality;
    tries = child.actionStatistics.Tries;
    if (double.IsNaN(quality)) quality = 0.0;
    sb.AppendFormat(nodeFormat, nextId, quality, tries, hue).AppendLine();
    sb.AppendFormat(edgeFormat, parentId, nextId, quality, child.expr).AppendLine();
    emitted.Add(Tuple.Create(tries, nextId, child));
  }

  // children with a single successor: also show that successor
  foreach (var entry in emitted) {
    var node = entry.Item3;
    int nodeId = entry.Item2;
    if (!state.children.ContainsKey(node)) continue;
    var successors = state.children[node];
    if (successors.Count != 1) continue;

    var only = successors.First();
    nextId++;
    quality = only.actionStatistics.AverageQuality;
    tries = only.actionStatistics.Tries;
    if (double.IsNaN(quality)) quality = 0.0;
    sb.AppendFormat(nodeFormat, nextId, quality, tries, hue).AppendLine();
    sb.AppendFormat(edgeFormat, nodeId, nextId, quality, only.expr).AppendLine();
  }

  // continue with the most frequently tried child (the first one in case of ties)
  var mostTried = emitted.OrderByDescending(t => t.Item1).FirstOrDefault();
  if (mostTried != null) {
    TraceTreeRec(mostTried.Item3, mostTried.Item2, sb, ref nextId, state);
  }
}
|
---|
| 986 |
|
---|
// Produces a graphviz (dot) digraph of all parent/child links registered in state.children
// (the tree argument is not used). Numbers are formatted with the invariant culture.
// Node ids are handed out in the order in which nodes are encountered, starting with 1.
// Children without any tries are left out, NaN qualities are written as 0, and the node color
// is fixed (hue 0); the quality-dependent hue is switched off.
private static string WriteTree(Tree tree, State state) {
  const string nodeFormat = "{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ";
  const string edgeFormat = "{0} -> {1} [label=\"{3}\"]";
  const double hue = 0.0;
  // nodes are only written when they have more tries than this (e.g. use 10 for large graphs)
  int minTries = 0;

  var writer = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
  writer.Write(
@"digraph {
ratio = fill;
node [style=filled];
");

  var idByNode = new Dictionary<Tree, int>();
  foreach (var link in state.children) {
    var parent = link.Key;
    int parentId;
    if (!idByNode.TryGetValue(parent, out parentId)) {
      // first encounter of this node -> hand out an id and write the node statement
      parentId = idByNode.Count + 1;
      double parentQ = parent.actionStatistics.AverageQuality;
      int parentTries = parent.actionStatistics.Tries;
      if (double.IsNaN(parentQ)) parentQ = 0.0;
      if (parentTries > minTries) {
        writer.Write(nodeFormat, parentId, parentQ, parentTries, hue);
      }
      idByNode.Add(parent, parentId);
    }

    foreach (var child in link.Value) {
      int childId;
      if (!idByNode.TryGetValue(child, out childId)) {
        childId = idByNode.Count + 1;
        idByNode.Add(child, childId);
      }

      double childQ = child.actionStatistics.AverageQuality;
      int childTries = child.actionStatistics.Tries;
      if (childTries < 1) continue; // never tried -> not shown
      if (double.IsNaN(childQ)) childQ = 0.0;
      if (childTries <= minTries) continue;

      writer.Write(nodeFormat, childId, childQ, childTries, hue);
      writer.Write(edgeFormat, parentId, childId, childQ, child.expr);
    }
  }

  writer.Write("}");
  return writer.ToString();
}
|
---|
[13645] | 1034 | }
|
---|
| 1035 | }
|
---|