1 | #region License Information
|
---|
2 | /* HeuristicLab
|
---|
3 | * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
4 | *
|
---|
5 | * This file is part of HeuristicLab.
|
---|
6 | *
|
---|
7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
8 | * it under the terms of the GNU General Public License as published by
|
---|
9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
10 | * (at your option) any later version.
|
---|
11 | *
|
---|
12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
15 | * GNU General Public License for more details.
|
---|
16 | *
|
---|
17 | * You should have received a copy of the GNU General Public License
|
---|
18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
19 | */
|
---|
20 | #endregion
|
---|
21 |
|
---|
22 | using System;
|
---|
23 | using System.Collections.Generic;
|
---|
24 | using System.Diagnostics;
|
---|
25 | using System.Diagnostics.Contracts;
|
---|
26 | using System.Linq;
|
---|
27 | using System.Text;
|
---|
28 | using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
|
---|
29 | using HeuristicLab.Core;
|
---|
30 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
31 | using HeuristicLab.Optimization;
|
---|
32 | using HeuristicLab.Problems.DataAnalysis;
|
---|
33 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
34 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
35 | using HeuristicLab.Random;
|
---|
36 |
|
---|
37 | namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
|
---|
38 | public static class MctsSymbolicRegressionStatic {
|
---|
39 | // OBJECTIVES:
|
---|
40 | // 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
|
---|
41 | // - e.g. Keijzer, Nguyen ... where no numeric constants are involved
|
---|
42 | // - assumptions:
|
---|
43 | // - we don't know the necessary operations or functions -> all available functions could be necessary
|
---|
44 | // - but we do not need to tune numeric constants -> no scaling of input variables x!
|
---|
45 | // 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
|
---|
46 | // This is important for real world applications.
|
---|
47 | // - e.g. Korns or Vladislavleva problems where numeric constants are involved
|
---|
48 | // - assumptions:
|
---|
49 | // - any numeric constant is possible (a-priori we might assume that small abs. constants are more likely)
|
---|
50 | // - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
|
---|
51 | // - to simplify the problem we can restrict the set of functions e.g. we assume which functions are necessary for the problem instance
|
---|
52 | // -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
|
---|
53 | // 3) efficiency and effectiveness for real-world problems
|
---|
54 | // - e.g. Tower problem
|
---|
55 | // - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
|
---|
56 | //
|
---|
57 |
|
---|
58 | // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should
|
---|
59 | // weigh more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
|
---|
60 | // branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
|
---|
61 | // TODO: Solve Poly-10
|
---|
62 | // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
|
---|
63 | // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
|
---|
64 | // TODO: check if transformation of y is correct and works (Obj 2)
|
---|
65 | // TODO: The algorithm is not invariant to location and scale of variables.
|
---|
66 | // Include offset for variables as parameter (for Objective 2)
|
---|
67 | // TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
|
---|
68 | // TODO: support exp(-x) and possibly (1/-x) (Obj 1)
|
---|
69 | // TODO: is it OK to initialize all constants to 1 (Obj 2)?
|
---|
70 | // TODO: improve memory usage
|
---|
71 | #region static API
|
---|
72 |
|
---|
/// <summary>
/// Read-only view on a running MCTS symbolic regression search.
/// Instances are created by MctsSymbolicRegressionStatic.CreateState and advanced by MakeStep.
/// </summary>
public interface IState {
  /// <summary>True when the search tree has been exhausted, i.e. no further step can be made.</summary>
  bool Done { get; }

  // ---- best solution found so far ----

  /// <summary>Symbolic regression model equivalent to the best expression found so far.</summary>
  ISymbolicRegressionModel BestModel { get; }

  /// <summary>Quality of the best expression on the training partition.</summary>
  double BestSolutionTrainingQuality { get; }

  /// <summary>Quality of the best expression on the test partition.</summary>
  double BestSolutionTestQuality { get; }

  /// <summary>Solutions on the quality/complexity Pareto front (only filled when Pareto collection is enabled).</summary>
  IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }

  // ---- search statistics ----

  /// <summary>Number of all rollouts, including rollouts that got stuck without evaluating an expression.</summary>
  int TotalRollouts { get; }

  /// <summary>Number of effective rollouts.</summary>
  int EffectiveRollouts { get; }

  /// <summary>Number of expression evaluations (including those made during constant optimization).</summary>
  int FuncEvaluations { get; }

  /// <summary>
  /// Number of gradient evaluations multiplied by the number of parameters, so that the value is
  /// representative of an effort comparable to the number of function evaluations.
  /// </summary>
  int GradEvaluations { get; }

  // TODO other stats on LM optimizer might be interesting here
}
|
---|
85 |
|
---|
// created through factory method
/// <summary>
/// Implementation of <see cref="IState"/>. Holds the root of the search graph, the automaton that generates
/// expressions, the (optionally scaled) training and test data, the best solution found so far and all
/// buffers needed to evaluate expressions and to optimize their numeric constants.
/// </summary>
private class State : IState {
  // size of the preallocated parameter and gradient buffers (upper bound for the number of numeric parameters)
  private const int MaxParams = 100;

  // state variables used by MCTS
  internal readonly Automaton automaton;
  internal IRandom random { get; private set; }
  internal readonly Tree tree;
  internal readonly Func<byte[], int, double> evalFun;
  internal readonly IPolicy treePolicy;
  // MCTS might get stuck. Track statistics on the number of effective rollouts
  internal int totalRollouts;
  internal int effectiveRollouts;

  // state variables used only internally (for eval function)
  private readonly IRegressionProblemData problemData;
  private readonly double[][] x;
  private readonly double[] y;
  private readonly double[][] testX;
  private readonly double[] testY;
  private readonly double[] scalingFactor;
  private readonly double[] scalingOffset;
  // for scaling parameters (e.g. stopping condition for LM)
  // NOTE(review): yStdDev is 0 for a constant target; the regularization penalty then divides by zero
  private readonly double yStdDev;
  private readonly int constOptIterations;
  private readonly double lambda; // weight of penalty term for regularization
  private readonly double lowerEstimationLimit, upperEstimationLimit;
  private readonly bool collectParetoOptimalModels;
  private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
  // qualities { R², complexity } of the Pareto-optimal solutions, index-aligned with paretoBestModels
  private readonly List<double[]> paretoFront = new List<double[]>();

  private readonly ExpressionEvaluator evaluator, testEvaluator;

  // values for best solution
  private double bestRSq;
  private byte[] bestCode;
  private int bestNParams;
  private double[] bestConsts;

  // stats
  private int funcEvaluations;
  private int gradEvaluations;

  // buffers (reused for every evaluation to avoid allocations)
  private readonly double[] ones; // vector of ones (as default params)
  private readonly double[] constsBuf;
  private readonly double[] predBuf, testPredBuf;
  private readonly double[][] gradBuf;

  /// <summary>
  /// Prepares training/test data, evaluators, the automaton and the root of the search graph.
  /// </summary>
  /// <param name="problemData">Regression problem providing data partitions, target and input variables.</param>
  /// <param name="randSeed">Seed of the random number generator.</param>
  /// <param name="maxVariables">Maximum number of variable references in an expression.</param>
  /// <param name="scaleVariables">Scale input variables linearly based on the training partition.</param>
  /// <param name="constOptIterations">Maximum number of LM iterations; negative values disable constant optimization.</param>
  /// <param name="lambda">Weight of the L2 penalty on the numeric parameters (must be >= 0).</param>
  /// <param name="treePolicy">Tree policy; UCB is used when null.</param>
  /// <param name="collectParetoOptimalModels">Maintain the R²/complexity Pareto front of solutions.</param>
  /// <param name="lowerEstimationLimit">Lower limit for predicted values.</param>
  /// <param name="upperEstimationLimit">Upper limit for predicted values.</param>
  /// <param name="allowProdOfVars">Allow products of expressions.</param>
  /// <param name="allowExp">Allow exponentials.</param>
  /// <param name="allowLog">Allow logarithms.</param>
  /// <param name="allowInv">Allow 1/x.</param>
  /// <param name="allowMultipleTerms">Allow sums of multiple terms.</param>
  /// <exception cref="ArgumentException">lambda is negative.</exception>
  public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
    int constOptIterations, double lambda,
    IPolicy treePolicy = null,
    bool collectParetoOptimalModels = false,
    double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
    bool allowProdOfVars = true,
    bool allowExp = true,
    bool allowLog = true,
    bool allowInv = true,
    bool allowMultipleTerms = false) {

    if (lambda < 0) throw new ArgumentException("Lambda must be larger or equal zero", "lambda");

    this.problemData = problemData;
    this.constOptIterations = constOptIterations;
    this.lambda = lambda;
    this.evalFun = this.Eval;
    this.lowerEstimationLimit = lowerEstimationLimit;
    this.upperEstimationLimit = upperEstimationLimit;
    this.collectParetoOptimalModels = collectParetoOptimalModels;

    random = new MersenneTwister(randSeed);

    // prepare data for evaluation
    double[][] x;
    double[] y;
    double[][] testX;
    double[] testY;
    double[] scalingFactor;
    double[] scalingOffset;
    // get training and test datasets (scale linearly based on training set if required)
    GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
    GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
    this.x = x;
    this.y = y;
    this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
    this.testX = testX;
    this.testY = testY;
    this.scalingFactor = scalingFactor;
    this.scalingOffset = scalingOffset;
    this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
    // we need a separate evaluator because the vector length for the test dataset might differ
    this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

    this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    this.treePolicy = treePolicy ?? new Ucb();
    this.tree = new Tree() {
      state = automaton.CurrentState,
      // must use the field: the parameter treePolicy is null by default (then UCB is used)
      actionStatistics = this.treePolicy.CreateActionStatistics(),
      expr = "",
      level = 0
    };

    // reset best solution
    this.bestRSq = 0;
    // code for default solution (constant model)
    this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
    this.bestNParams = 0;
    this.bestConsts = null;

    // init buffers
    this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
    this.constsBuf = new double[MaxParams];
    this.predBuf = new double[y.Length];
    this.testPredBuf = new double[testY.Length];

    this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
  }

  #region IState interface
  public bool Done { get { return tree != null && tree.Done; } }

  // R² of the best solution on the training partition (re-evaluated on every access)
  public double BestSolutionTrainingQuality {
    get {
      evaluator.Exec(bestCode, x, bestConsts, predBuf);
      return RSq(y, predBuf);
    }
  }

  // R² of the best solution on the test partition (re-evaluated on every access)
  public double BestSolutionTestQuality {
    get {
      testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
      return RSq(testY, testPredBuf);
    }
  }

  // takes the code of the best solution and creates an equivalent symbolic regression model
  public ISymbolicRegressionModel BestModel {
    get {
      var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

      var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
      var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
      model.Scale(problemData); // apply linear scaling
      return model;
    }
  }

  public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
    get { return paretoBestModels; }
  }

  public int TotalRollouts { get { return totalRollouts; } }
  public int EffectiveRollouts { get { return effectiveRollouts; } }
  public int FuncEvaluations { get { return funcEvaluations; } }
  public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations

  #endregion

  /// <summary>
  /// Evaluation function used by the tree search: evaluates (and optionally optimizes) the expression,
  /// updates the best solution and, if enabled, the Pareto front.
  /// </summary>
  /// <returns>R² of the expression on the training partition.</returns>
  private double Eval(byte[] code, int nParams) {
    double[] optConsts;
    double q;
    Eval(code, nParams, out q, out optConsts);

    // single objective best
    if (q > bestRSq) {
      bestRSq = q;
      bestNParams = nParams;
      // optConsts is the shared constsBuf -> copy code and constants
      this.bestCode = new byte[code.Length];
      this.bestConsts = new double[bestNParams];

      Array.Copy(code, bestCode, code.Length);
      Array.Copy(optConsts, bestConsts, bestNParams);
    }
    if (collectParetoOptimalModels) {
      // multi-objective best
      var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
        Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit); // use length of expression as surrogate for complexity
      UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
    }
    return q;
  }

  /// <summary>
  /// Evaluates the expression with all constants initialized to one (letting the evaluator adjust offsets
  /// for log/exp), then optionally optimizes the constants with Levenberg-Marquardt.
  /// </summary>
  /// <param name="rsq">R² of the (optimized) expression on the training partition.</param>
  /// <param name="optConsts">The constants used for rsq; this is the shared constsBuf, callers must copy it.</param>
  private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
    // we make a first pass to determine a valid starting configuration for all constants
    // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
    // scale and offset are set to optimal starting configuration
    // assumes scale is the first param and offset is the last param

    // reset constants
    Array.Copy(ones, constsBuf, nParams);
    evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
    funcEvaluations++;

    // without parameters (or with optimization disabled) we are done:
    // changing scale and offset does not influence r²
    if (nParams != 0 && constOptIterations >= 0) {
      // optimize constants using the starting point calculated above
      OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

      evaluator.Exec(code, x, constsBuf, predBuf);
      funcEvaluations++;
    }

    rsq = RSq(y, predBuf);
    optConsts = constsBuf;
  }

  #region helpers
  // squared Pearson correlation (R²); 0 if it cannot be calculated
  private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
    OnlineCalculatorError error;
    double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
    return error == OnlineCalculatorError.None ? r * r : 0.0;
  }

  // sum of squares of all elements (squared Euclidean length), used for the regularization penalty
  private static double SquaredNorm(double[] v) {
    double sum = 0.0;
    for (int i = 0; i < v.Length; i++) {
      sum += v[i] * v[i];
    }
    return sum;
  }

  /// <summary>
  /// Optimizes the first nParams entries of consts with Levenberg-Marquardt (alglib minlm) on the
  /// training data plus an additional residual for the L2 penalty, and writes the result back into consts.
  /// </summary>
  /// <exception cref="ArgumentException">The optimizer reports a negative termination type.</exception>
  private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
    double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
    Array.Copy(consts, optConsts, nParams);

    // direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
    alglib.minlmstate state;
    alglib.minlmreport rep = null;
    alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
    // Using the change of the gradient as stopping criterion is recommended in alglib manual.
    // However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
    alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
    // alglib.minlmsetgradientcheck(state, 1E-5);
    alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
    alglib.minlmresults(state, out optConsts, out rep);
    funcEvaluations += rep.nfunc;
    gradEvaluations += rep.njac * nParams;

    // negative termination types signal a failure of the optimizer
    if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

    // optimization successful -> use the optimized constants
    Array.Copy(optConsts, consts, optConsts.Length);
  }

  // residual vector for LM: fi[r] = prediction - target for every training row, last entry = penalty term
  private void Func(double[] arg, double[] fi, object obj) {
    var code = (byte[])obj;
    int n = predBuf.Length;
    evaluator.Exec(code, x, arg, predBuf);
    for (int r = 0; r < n; r++) {
      fi[r] = predBuf[r] - y[r];
    }

    var penaltyIdx = fi.Length - 1;
    fi[penaltyIdx] = 0.0;
    // calc length of parameter vector for regularization
    var aa = SquaredNorm(arg);
    if (lambda > 0 && aa > 0) {
      // scale lambda using stdDev(y) to make the parameter independent of the scale of y
      // scale lambda using n to make parameter independent of the number of training points
      // take the root because LM squares the result
      fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
    }
  }

  // residual vector (see Func) and its Jacobian with respect to the parameters
  private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
    int n = predBuf.Length;
    int nParams = arg.Length;
    var code = (byte[])obj;
    evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
    for (int r = 0; r < n; r++) {
      fi[r] = predBuf[r] - y[r];

      for (int k = 0; k < nParams; k++) {
        jac[r, k] = gradBuf[k][r];
      }
    }
    // calc length of parameter vector for regularization
    double aa = SquaredNorm(arg);

    var penaltyIdx = fi.Length - 1;
    if (lambda > 0 && aa > 0) {
      // scale lambda using stdDev(y) to make the parameter independent of the scale of y
      // scale lambda using n to make parameter independent of the number of training points
      // take the root because alglib LM squares the result
      fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);

      // d/da_i sqrt(c * |a|²) = 0.5 / sqrt(c * |a|²) * c * 2 * a_i
      for (int i = 0; i < arg.Length; i++) {
        jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
      }
    } else {
      fi[penaltyIdx] = 0.0;
      for (int i = 0; i < arg.Length; i++) {
        jac[penaltyIdx, i] = 0.0;
      }
    }
  }

  /// <summary>
  /// Adds the solution to the Pareto front (maximize R², minimize complexity) if it is not dominated,
  /// creates the corresponding regression solution and removes all front members it dominates.
  /// </summary>
  private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
    double[] scalingFactor, double[] scalingOffset) {
    double[] cur = new double[2] { q, complexity };
    bool[] max = new[] { true, false };
    foreach (var e in paretoFront) {
      var domRes = DominationCalculator<int>.Dominates(cur, e, max, true);
      if (domRes == DominationResult.IsDominated) {
        return; // dominated -> the front does not change
      }
    }

    paretoFront.Add(cur);

    // create model
    var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
    var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
    var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
    model.Scale(problemData); // apply linear scaling

    var sol = model.CreateRegressionSolution(this.problemData);
    sol.Name = string.Format("{0:N5} {1}", q, complexity);

    paretoBestModels.Add(sol);

    // remove all previous members that are dominated by the new solution
    // (the new solution is the last element, therefore start at Count - 2)
    for (int i = paretoFront.Count - 2; i >= 0; i--) {
      var domRes = DominationCalculator<int>.Dominates(cur, paretoFront[i], max, true);
      if (domRes == DominationResult.Dominates) {
        paretoFront.RemoveAt(i);
        paretoBestModels.RemoveAt(i);
      }
    }
  }
  #endregion
}
|
---|
434 |
|
---|
435 |
|
---|
/// <summary>
/// Creates and initializes the search state for a new MCTS symbolic regression run.
/// </summary>
/// <param name="problemData">The regression problem (data partitions, target variable and allowed input variables).</param>
/// <param name="randSeed">Seed of the random number generator used by the search.</param>
/// <param name="maxVariables">Maximum number of variable references that are allowed in an expression.</param>
/// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended).</param>
/// <param name="constOptIterations">Maximum number of Levenberg-Marquardt iterations for optimizing numeric constants; a negative value (the default) disables constant optimization.</param>
/// <param name="lambda">Weight of the L2 regularization penalty on the numeric constants (0..inf.); zero disables regularization.</param>
/// <param name="policy">Tree search policy (random, ucb, eps-greedy, ...); UCB is used when null.</param>
/// <param name="collectParameterOptimalModels">Optionally collect all Pareto-optimal solutions with respect to R² and expression length.</param>
/// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
/// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
/// <param name="allowProdOfVars">Allow products of expressions.</param>
/// <param name="allowExp">Allow expressions with exponentials.</param>
/// <param name="allowLog">Allow expressions with logarithms.</param>
/// <param name="allowInv">Allow expressions with 1/x.</param>
/// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
/// <returns>A new search state; advance it with MakeStep.</returns>
/// <exception cref="ArgumentException"><paramref name="lambda"/> is negative.</exception>
public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
  bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
  IPolicy policy = null,
  bool collectParameterOptimalModels = false,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool allowProdOfVars = true,
  bool allowExp = true,
  bool allowLog = true,
  bool allowInv = true,
  bool allowMultipleTerms = false
  ) {
  var state = new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
    treePolicy: policy,
    collectParetoOptimalModels: collectParameterOptimalModels,
    lowerEstimationLimit: lowerEstimationLimit,
    upperEstimationLimit: upperEstimationLimit,
    allowProdOfVars: allowProdOfVars,
    allowExp: allowExp,
    allowLog: allowLog,
    allowInv: allowInv,
    allowMultipleTerms: allowMultipleTerms);
  return state;
}
|
---|
472 |
|
---|
/// <summary>
/// Performs one step of the tree search. Rollouts are repeated until one of them evaluates a complete
/// expression or the search tree is exhausted.
/// </summary>
/// <param name="state">A state created by CreateState.</param>
/// <returns>The quality of the evaluated solution, or 0 if no expression could be evaluated.</returns>
/// <exception cref="ArgumentNullException"><paramref name="state"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="state"/> was not created by CreateState.</exception>
/// <exception cref="NotSupportedException">The search has already enumerated all possible solutions.</exception>
public static double MakeStep(IState state) {
  if (state == null) throw new ArgumentNullException("state");
  var mctsState = state as State;
  // the second argument is the parameter name; the first one is the message
  if (mctsState == null) throw new ArgumentException("The state must be created by CreateState.", "state");
  if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

  return TreeSearch(mctsState);
}
|
---|
481 | #endregion
|
---|
482 |
|
---|
/// <summary>
/// Repeats rollouts from the root until one of them evaluates a complete expression or the whole
/// search tree is marked as done. Updates the rollout statistics of the state.
/// </summary>
/// <returns>The (transformed) quality of the evaluated expression, or 0 if the last rollout got stuck.</returns>
private static double TreeSearch(State mctsState) {
  var automaton = mctsState.automaton;
  var tree = mctsState.tree;
  var eval = mctsState.evalFun;
  var rand = mctsState.random;
  var treePolicy = mctsState.treePolicy;
  double q;
  bool success;
  do {
    automaton.Reset();
    success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, out q);
    mctsState.totalRollouts++;
  } while (!success && !tree.Done);

  // The loop also terminates when the tree is exhausted without a successful rollout.
  // Such a rollout did not evaluate an expression and therefore must not count as effective.
  if (success) mctsState.effectiveRollouts++;

  return q;
}
|
---|
504 |
|
---|
// Graph structure of the search (equivalent states are unified, so this is a graph rather than a tree):
//   children: expanded node -> its child nodes (a node counts as expanded once it has an entry here)
//   parents:  node -> all nodes that link to it (possibly several because of state unification)
//   nodes:    Hashcode(automaton) of an eval state -> the unified node for that state
// NOTE(review): these dictionaries are static, i.e. shared by all State instances and all runs in the
// same process, and they are not thread-safe. No clearing is visible in this part of the file (verify).
// If they are never cleared, nodes from earlier runs can be re-used via `nodes` and memory grows without
// bound. Consider moving them into State.
private static Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
private static Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
private static Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
|
---|
508 |
|
---|
509 |
|
---|
510 |
|
---|
// search forward
// Performs a single rollout starting at `tree` (TreeSearch passes the root) with the automaton reset to
// its initial state. Descends through expanded nodes using treePolicy, expands unexpanded nodes, evaluates
// the expression when a final state is reached and backpropagates the (transformed) quality.
// Returns true if a complete expression was evaluated (q = its quality); false if the rollout got stuck
// in a dead end or in a sub-graph that is completely done (q = 0.0).
private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  out double q) {
  // ROLLOUT AND EXPANSION
  // We are navigating a graph (states might be reached via different paths) instead of a tree.
  // State equivalence is checked through Hashcode(automaton) (based on the generated code through the path).

  // We switch between rollout-mode and expansion mode
  // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
  // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
  // In expansion mode we might re-enter the graph and switch back to rollout-mode
  // We do this until we reach a complete expression (final state)

  // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
  // Sub-graphs which have been completely searched are marked as done.
  // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

  while (!automaton.IsFinalState(automaton.CurrentState)) {
    if (children.ContainsKey(tree)) {
      // node is already expanded: if every child is done, this node is done as well and the rollout stops here
      if (children[tree].All(ch => ch.Done)) {
        tree.Done = true;
        break;
      }
      // ROLLOUT INSIDE TREE
      // UCT selection within tree
      int selectedIdx = 0;
      if (children[tree].Count > 1) {
        selectedIdx = treePolicy.Select(children[tree].Select(ch => ch.actionStatistics), rand);
      }
      tree = children[tree][selectedIdx];

      // move the automaton forward until reaching the state
      // all steps where no alternatives are possible are immediately taken
      // TODO: simplification of the automaton
      int[] possibleFollowStates;
      int nFs;
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
        automaton.Goto(possibleFollowStates[0]);
        automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      }
      Debug.Assert(possibleFollowStates.Contains(tree.state));
      automaton.Goto(tree.state);
    } else {
      // EXPAND
      int[] possibleFollowStates;
      int nFs;
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
        // no alternatives -> just go to the next state
        automaton.Goto(possibleFollowStates[0]);
        automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      }
      if (nFs == 0) {
        // stuck in a dead end (no final state and no allowed follow states)
        tree.Done = true;
        break;
      }
      // register the (possibly still empty) child list first; the node counts as expanded from now on
      var newChildren = new List<Tree>(nFs);
      children.Add(tree, newChildren);
      for (int i = 0; i < nFs; i++) {
        Tree child = null;
        // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
        if (automaton.IsEvalState(possibleFollowStates[i])) {
          // NOTE(review): the hash is computed from the automaton's current state, i.e. before moving to
          // possibleFollowStates[i]. If several follow states are eval states they all get the same hash
          // and are unified into one node. Confirm that this is intended.
          var hc = Hashcode(automaton);
          if (!nodes.TryGetValue(hc, out child)) {
            child = new Tree() {
              children = null,
              state = possibleFollowStates[i],
              actionStatistics = treePolicy.CreateActionStatistics(),
              expr = string.Empty, // ExprStr(automaton),
              level = tree.level + 1
            };
            nodes.Add(hc, child);
          }
          // only allow forward edges (don't add the child if we would go back in the graph)
          else if (child.level > tree.level) {
            // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
            // to all parents
            BackpropagateStatistics(child.actionStatistics, tree);
          } else {
            // prevent cycles
            Debug.Assert(child.level <= tree.level);
            child = null;
          }
        } else {
          child = new Tree() {
            children = null,
            state = possibleFollowStates[i],
            actionStatistics = treePolicy.CreateActionStatistics(),
            expr = string.Empty, // ExprStr(automaton),
            level = tree.level + 1
          };
        }
        if (child != null)
          newChildren.Add(child);
      }

      if (!newChildren.Any()) {
        // stuck in a dead end (no final state and no allowed follow states)
        tree.Done = true;
        break;
      }

      // register the back links used for backpropagation
      foreach (var ch in newChildren) {
        if (!parents.ContainsKey(ch)) {
          parents.Add(ch, new List<Tree>());
        }
        parents[ch].Add(tree);
      }


      // follow one of the children
      tree = SelectFinalOrRandom2(automaton, tree, rand);
      automaton.Goto(tree.state);
    }
  }

  bool success;

  // EVALUATE TREE
  if (automaton.IsFinalState(automaton.CurrentState)) {
    tree.Done = true;
    tree.expr = ExprStr(automaton);
    byte[] code; int nParams;
    automaton.GetCode(out code, out nParams);
    q = eval(code, nParams);
    q = TransformQuality(q);
    success = true;
  } else {
    // we got stuck in roll-out (no evaluation necessary)
    q = 0.0;
    success = false;
  }

  // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
  // Update statistics
  // Set branch to done if all children are done.
  BackpropagateQuality(tree, q, treePolicy);

  return success;
}
|
---|
653 |
|
---|
654 |
|
---|
// Hook that maps the raw evaluation result before it enters the search statistics.
// At the moment this is the identity: the value passed in is returned unchanged.
// An experimental alternative that has been tried: return 1E16 for a perfect fit (q >= 1)
// and otherwise the "number of nines" of R², i.e. -Math.Log10(1 - q).
private static double TransformQuality(double q) {
  double transformed = q;
  return transformed;
}
|
---|
665 |
|
---|
// Adds `stats` to the statistics of `tree` and then, recursively, to every ancestor
// reachable through the `parents` map. It is called when a newly created edge joins an
// already existing node, so that the parents on the new path also account for that
// node's earlier results.
// NOTE(review): an ancestor reachable over several paths receives `stats` once per path.
private static void BackpropagateStatistics(IActionStatistics stats, Tree tree) {
  tree.actionStatistics.Add(stats);

  if (!parents.ContainsKey(tree)) {
    return; // root (or node without registered parents)
  }
  foreach (var predecessor in parents[tree]) {
    BackpropagateStatistics(stats, predecessor);
  }
}
|
---|
675 |
|
---|
// Hash of the expression currently encoded by the automaton, computed by ExprHash from the
// automaton's byte code and parameter count. Used to detect equivalent states.
private static ulong Hashcode(Automaton automaton) {
  byte[] program;
  int paramCount;
  automaton.GetCode(out program, out paramCount);
  var hash = ExprHash.GetHash(program, paramCount);
  return hash;
}
|
---|
682 |
|
---|
// Propagates the roll-out quality `q` from `tree` to all of its ancestors (via `parents`).
// The policy statistics are updated only for positive q. A node whose children (as recorded
// in the `children` map) are all done is marked done itself; nodes are never removed.
// NOTE(review): an ancestor reachable over several paths is updated once per path.
private static void BackpropagateQuality(Tree tree, double q, IPolicy policy) {
  if (q > 0) {
    policy.Update(tree.actionStatistics, q);
  }

  bool allChildrenDone = children.ContainsKey(tree) && children[tree].All(ch => ch.Done);
  if (allChildrenDone) {
    tree.Done = true;
  }

  if (!parents.ContainsKey(tree)) {
    return;
  }
  foreach (var predecessor in parents[tree]) {
    BackpropagateQuality(predecessor, q, policy);
  }
}
|
---|
696 |
|
---|
// Picks the child of `tree` (from the `children` map) to follow during the roll-out:
// the first child whose state is final, or otherwise simply the first child.
// Despite the name no random choice is made; `rand` is currently not used.
private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) {
  var candidates = children[tree];
  foreach (var candidate in candidates) {
    if (automaton.IsFinalState(candidate.state)) {
      return candidate;
    }
  }
  return candidates[0];
}
|
---|
715 |
|
---|
// Recursive tree search (selection, expansion, evaluation, backpropagation).
// The search can fail because of constraints on the expressions: a path may end in a
// non-final state without allowed follow states. In that case false is returned (q = 0)
// and the caller simply restarts. See ConstraintHandler.cs for more info.
// On success the evaluated quality is returned in q and recorded along the visited path.
private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
  out double q) {
  Contract.Assert(tree.state == automaton.CurrentState);
  Contract.Assert(!tree.Done);

  Tree next;
  if (tree.children != null) {
    // already expanded: let the tree policy choose among several children
    next = tree.children.Length > 1
      ? tree.children[treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand)]
      : tree.children[0];
  } else if (automaton.IsFinalState(tree.state)) {
    // leaf with a complete expression: evaluate it
    tree.Done = true;
    byte[] code;
    int nParams;
    automaton.GetCode(out code, out nParams);
    q = eval(code, nParams);
    treePolicy.Update(tree.actionStatistics, q);
    return true;
  } else {
    // expand the node with all allowed follow states
    int[] followStates;
    int nFollow;
    automaton.FollowStates(automaton.CurrentState, out followStates, out nFollow);
    if (nFollow == 0) {
      // dead end: not final and no allowed follow states
      q = 0;
      tree.Done = true;
      tree.children = null;
      return false;
    }

    tree.children = new Tree[nFollow];
    for (int k = 0; k < nFollow; k++) {
      tree.children[k] = new Tree() {
        children = null,
        state = followStates[k],
        actionStatistics = treePolicy.CreateActionStatistics()
      };
    }
    next = nFollow > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
  }

  // take the chosen step and continue below it
  automaton.Goto(next.state);
  bool success = TryTreeSearchRec(rand, next, automaton, eval, treePolicy, out q);
  if (success) {
    treePolicy.Update(tree.actionStatistics, q); // statistics only count successful roll-outs
  }

  tree.Done = tree.children.All(ch => ch.Done);
  if (tree.Done) {
    tree.children = null; // fully explored: drop the sub-branch
  }
  return success;
}
|
---|
781 |
|
---|
// Picks the child of a freshly expanded `tree` to descend into: the first child whose state
// is final, or otherwise simply the first child.
// Despite the name no random choice is made; `rand` is currently not used.
private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
  foreach (var candidate in tree.children) {
    if (automaton.IsFinalState(candidate.state)) {
      return candidate;
    }
  }
  return tree.children[0];
}
|
---|
799 |
|
---|
// Extracts the input and target values for `rows` from the dataset into arrays and, if
// scaleVariables is set, computes the linear scaling (x * factor + offset) that is applied:
//  - each input variable is mapped onto [0, 1] via (x - min) / (max - min);
//    a constant input variable (max == min) is only shifted by -min with factor 1, because
//    dividing by its zero range would yield infinite/NaN factors and an all-NaN column,
//  - the target variable is shifted to zero mean (factor 1).
// Factor/offset of the i-th allowed input variable are stored at index i, those of the
// target at the last index. Without scaling both arrays are null and values are copied as is.
private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
  out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
  if (scaleVariables) {
    int nInputs = problemData.AllowedInputVariables.Count();
    scalingFactor = new double[nInputs + 1];
    scalingOffset = new double[nInputs + 1];

    int i = 0;
    foreach (var variableName in problemData.AllowedInputVariables) {
      // materialize once instead of enumerating the column twice for min and max
      var values = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
      var minX = values.Min();
      var maxX = values.Max();
      var range = maxX - minX;

      if (range > 0) {
        // scaledX = (x - min) / range
        scalingFactor[i] = 1.0 / range;
        scalingOffset[i] = -minX / range;
      } else {
        // constant variable: avoid division by zero, just shift it to 0
        scalingFactor[i] = 1.0;
        scalingOffset[i] = -minX;
      }
      i++;
    }

    // transform target variable to zero-mean
    scalingFactor[i] = 1.0;
    scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
  } else {
    scalingFactor = null;
    scalingOffset = null;
  }

  GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
}
|
---|
836 |
|
---|
// Extracts the values of all allowed input variables and of the target variable for `rows`
// from the dataset, transforming each value as x * factor + offset.
// Index i of scalingFactor/scalingOffset belongs to the i-th allowed input variable, the
// index after the last input belongs to the target. Each array is checked independently:
// a null scalingFactor means factor 1, a null scalingOffset means offset 0.
private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
  out double[][] xs, out double[] y) {
  xs = new double[problemData.AllowedInputVariables.Count()][];

  int i = 0;
  foreach (var variableName in problemData.AllowedInputVariables) {
    var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
    var offset = scalingOffset == null ? 0.0 : scalingOffset[i];
    xs[i++] =
      problemData.Dataset.GetDoubleValues(variableName, rows).Select(xi => xi * sf + offset).ToArray();
  }

  // the target uses the entry directly after the inputs
  var targetFactor = scalingFactor == null ? 1.0 : scalingFactor[i];
  var targetOffset = scalingOffset == null ? 0.0 : scalingOffset[i];
  y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * targetFactor + targetOffset).ToArray();
}
|
---|
856 |
|
---|
857 | // for debugging only
|
---|
858 |
|
---|
859 |
|
---|
// Human-readable disassembly of the expression currently encoded by the automaton.
private static string ExprStr(Automaton automaton) {
  byte[] program;
  int paramCount; // required by GetCode, not needed for the disassembly
  automaton.GetCode(out program, out paramCount);
  var text = Disassembler.CodeToString(program);
  return text;
}
|
---|
866 |
|
---|
// Debug helper: one line "<tries> <average quality>" for `tree`, followed by one such line
// for each of its children recorded in the `children` map (current-culture formatting).
private static string WriteStatistics(Tree tree) {
  var text = new StringBuilder();
  text.AppendFormat("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality).AppendLine();
  if (!children.ContainsKey(tree)) {
    return text.ToString();
  }
  foreach (var child in children[tree]) {
    text.AppendFormat("{0} {1:N5}", child.actionStatistics.Tries, child.actionStatistics.AverageQuality).AppendLine();
  }
  return text.ToString();
}
|
---|
877 |
|
---|
// Debug helper: renders a Graphviz DOT digraph starting at `tree`; at each level all children
// are drawn, but only the child with the most tries is expanded further (see TraceTreeRec).
private static string TraceTree(Tree tree) {
  var dot = new StringBuilder();
  dot.Append(
@"digraph {
ratio = fill;
node [style=filled];
");
  int lastId = 0; // root gets id 0, TraceTreeRec hands out the following ids
  TraceTreeRec(tree, 0, dot, ref lastId);
  return dot.Append("}").ToString();
}
|
---|
891 |
|
---|
// Debug helper for TraceTree. Appends a DOT node for `tree` (id parentId), then a node plus
// an edge for each of its children in the `children` map (ids taken by incrementing nextId),
// and recurses only into the child with the most tries.
// Nodes are coloured by average quality: hue 0 (red) for quality 1, 240° (blue) for quality 0.
// All numbers are formatted with the invariant culture because Graphviz requires '.' as the
// decimal separator in the HSV color value; culture-specific formatting (e.g. "0,667" under a
// German locale) produces an invalid color.
private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId) {
  var invariant = System.Globalization.CultureInfo.InvariantCulture;

  var avgNodeQ = tree.actionStatistics.AverageQuality;
  var tries = tree.actionStatistics.Tries;
  if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
  var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue

  sb.AppendFormat(invariant, "{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();

  if (!children.ContainsKey(tree)) {
    return;
  }

  var list = new List<Tuple<int, int, Tree>>();
  foreach (var ch in children[tree]) {
    nextId++;
    avgNodeQ = ch.actionStatistics.AverageQuality;
    tries = ch.actionStatistics.Tries;
    if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    sb.AppendFormat(invariant, "{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
    sb.AppendFormat(invariant, "{0} -> {1}", parentId, nextId).AppendLine();
    list.Add(Tuple.Create(tries, nextId, ch));
  }

  // follow only the most visited child
  foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
    TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId);
  }
}
|
---|
917 |
|
---|
// Debug helper: renders the complete search graph recorded in the `children` map as a
// Graphviz DOT digraph (invariant culture). Nodes are labelled "<avg quality> <tries>" and
// coloured by average quality (hue 0 = red for quality 1, 240 = blue for quality 0); edges
// are labelled with the child's expr string. Children without tries are skipped, and when
// there are more than 500 nodes only nodes with more than 10 tries are drawn.
// NOTE: the `tree` argument is not used; the output always covers the whole `children` map.
private static string WriteTree(Tree tree) {
  var writer = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
  var ids = new Dictionary<Tree, int>();
  writer.Write(
@"digraph {
ratio = fill;
node [style=filled];
");
  int minTries = nodes.Count > 500 ? 10 : 0;

  foreach (var entry in children) {
    var source = entry.Key;
    int sourceId;
    if (!ids.TryGetValue(source, out sourceId)) {
      sourceId = ids.Count + 1;
      var sourceQuality = source.actionStatistics.AverageQuality;
      var sourceTries = source.actionStatistics.Tries;
      if (double.IsNaN(sourceQuality)) {
        sourceQuality = 0.0;
      }
      if (sourceTries > minTries) {
        var sourceHue = (1 - sourceQuality) / 360.0 * 240.0;
        writer.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", sourceId, sourceQuality, sourceTries, sourceHue);
      }
      ids.Add(source, sourceId);
    }

    foreach (var target in entry.Value) {
      int targetId;
      if (!ids.TryGetValue(target, out targetId)) {
        targetId = ids.Count + 1;
        ids.Add(target, targetId);
      }

      var targetQuality = target.actionStatistics.AverageQuality;
      var targetTries = target.actionStatistics.Tries;
      // never-visited children and (for large graphs) rarely visited ones are not drawn
      if (targetTries < 1 || targetTries <= minTries) {
        continue;
      }
      if (double.IsNaN(targetQuality)) {
        targetQuality = 0.0;
      }
      var targetHue = (1 - targetQuality) / 360.0 * 240.0;
      writer.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", targetId, targetQuality, targetTries, targetHue);
      writer.Write("{0} -> {1} [label=\"{3}\"]", sourceId, targetId, targetQuality, target.expr);
    }
  }

  writer.Write("}");
  return writer.ToString();
}
|
---|
963 | }
|
---|
964 | }
|
---|