#region License Information
/* HeuristicLab
* Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
public static class MctsSymbolicRegressionStatic {
// OBJECTIVES:
// 1) solve toy problems without numeric constants (to show that structure search is effective / efficient)
// - e.g. Keijzer, Nguyen ... where no numeric constants are involved
// - assumptions:
// - we don't know the necessary operations or functions -> all available functions could be necessary
// - but we do not need to tune numeric constants -> no scaling of input variables x!
// 2) Solve toy problems with numeric constants to make the algorithm invariant concerning variable scale.
// This is important for real world applications.
// - e.g. Korns or Vladislavleva problems where numeric constants are involved
// - assumptions:
// - any numeric constant is possible (a-priori we might assume that small abs. constants are more likely)
// - standardization of variables is possible (or might be necessary) as we adjust numeric parameters of the expression anyway
// - to simplify the problem we can restrict the set of functions e.g. we assume which functions are necessary for the problem instance
// -> several steps: (a) polynomials, (b) rational polynomials, (c) exponential or logarithmic functions, rational functions with exponential and logarithmic parts
// 3) efficiency and effectiveness for real-world problems
// - e.g. Tower problem
// - (1) and (2) combined, structure search must be effective in combination with numeric optimization of constants
//
// TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general!
// --> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show
// --> Therefore, looking at roll-out statistics for arm selection (MCTS-style) is useless in the general case!
// --> It is necessary to rely on other features for the arm selection (see the illustrative sketch after this TODO list).
// --> TODO: Which heuristics can we apply?
// TODO: Solve Poly-10
// TODO: rename everything as this is not MCTS anymore
// TODO: when a path to an expression is explored first (e.g. x1 + x2)
// and later we find a longer form x1 + x1 + x2 where the number of variable references
// exceeds the maximum in the automaton this leads to an error (see unit tests)
// TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
// TODO: check if transformation of y is correct and works (Obj 2)
// TODO: The algorithm is not invariant to location and scale of variables.
// Include offset for variables as parameter (for Objective 2)
// TODO: why does LM optimization converge so slowly with exp(x), log(x), and 1/x allowed (Obj 2)?
// TODO: support exp(-x) and possibly (1/-x) (Obj 1)
// TODO: is it OK to initialize all constants to 1 (Obj 2)?
// TODO: improve memory usage
// TODO: analyze / improve perf of ExprHashing (canonical form for expressions)
// TODO: support empty test partition
// TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax ) ) for invertible transformations A --> see unit tests
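// Illustrative sketch for the remark on interaction terms above (not used by the algorithm):
// for x1, x2 ~ U(-1, 1) the correlation between x1 and the interaction term x1*x2 is ~0,
// so rollout statistics collected for partial expressions containing only x1 carry no signal
// about the relevance of x1*x2. Uses only classes already referenced in this file.
private static double InteractionTermCorrelationSketch(uint seed, int n) {
var rand = new MersenneTwister(seed);
var x1 = new double[n];
var x1x2 = new double[n];
for (int i = 0; i < n; i++) {
x1[i] = rand.NextDouble() * 2 - 1; // x1 ~ U(-1, 1)
var x2 = rand.NextDouble() * 2 - 1; // x2 ~ U(-1, 1)
x1x2[i] = x1[i] * x2; // interaction term
}
OnlineCalculatorError error;
var r = OnlinePearsonsRCalculator.Calculate(x1, x1x2, out error); // expected: r ~ 0
return error == OnlineCalculatorError.None ? r : 0.0;
}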
#region static API
public interface IState {
bool Done { get; }
ISymbolicRegressionModel BestModel { get; }
double BestSolutionTrainingQuality { get; }
double BestSolutionTestQuality { get; }
IEnumerable<ISymbolicRegressionSolution> ParetoBestModels { get; }
int TotalRollouts { get; }
int EffectiveRollouts { get; }
int FuncEvaluations { get; }
int GradEvaluations { get; } // number of gradient evaluations (multiplied by the number of parameters) so that the value is comparable to the number of function evaluations
// TODO other stats on LM optimizer might be interesting here
}
// created through factory method
private class State : IState {
private const int MaxParams = 100;
// state variables used by MCTS
internal readonly Automaton automaton;
internal IRandom random { get; private set; }
internal readonly Tree tree;
internal readonly Func<byte[], int, double> evalFun;
// MCTS might get stuck. Track statistics on the number of effective roll-outs
internal int totalRollouts;
internal int effectiveRollouts;
// state variables used only internally (for eval function)
private readonly IRegressionProblemData problemData;
private readonly double[][] x;
private readonly double[] y;
private readonly double[][] testX;
private readonly double[] testY;
private readonly double[] scalingFactor;
private readonly double[] scalingOffset;
private readonly double yStdDev; // for scaling parameters (e.g. stopping condition for LM)
private readonly int constOptIterations;
private readonly double lambda; // weight of penalty term for regularization
private readonly double lowerEstimationLimit, upperEstimationLimit;
private readonly bool collectParetoOptimalModels;
private readonly List<ISymbolicRegressionSolution> paretoBestModels = new List<ISymbolicRegressionSolution>();
private readonly List<double[]> paretoFront = new List<double[]>(); // matching the models
private readonly ExpressionEvaluator evaluator, testEvaluator;
internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
// values for best solution
private double bestR;
private byte[] bestCode;
private int bestNParams;
private double[] bestConsts;
// stats
private int funcEvaluations;
private int gradEvaluations;
// buffers
private readonly double[] ones; // vector of ones (as default params)
private readonly double[] constsBuf;
private readonly double[] predBuf, testPredBuf;
private readonly double[][] gradBuf;
public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
int constOptIterations, double lambda,
bool collectParetoOptimalModels = false,
double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
bool allowProdOfVars = true,
bool allowExp = true,
bool allowLog = true,
bool allowInv = true,
bool allowMultipleTerms = false) {
if (lambda < 0) throw new ArgumentException("Lambda must be larger than or equal to zero", "lambda");
this.problemData = problemData;
this.constOptIterations = constOptIterations;
this.lambda = lambda;
this.evalFun = this.Eval;
this.lowerEstimationLimit = lowerEstimationLimit;
this.upperEstimationLimit = upperEstimationLimit;
this.collectParetoOptimalModels = collectParetoOptimalModels;
random = new MersenneTwister(randSeed);
// prepare data for evaluation
double[][] x;
double[] y;
double[][] testX;
double[] testY;
double[] scalingFactor;
double[] scalingOffset;
// get training and test datasets (scale linearly based on training set if required)
GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
this.x = x;
this.y = y;
this.yStdDev = HeuristicLab.Common.EnumerableStatisticExtensions.StandardDeviation(y);
this.testX = testX;
this.testY = testY;
this.scalingFactor = scalingFactor;
this.scalingOffset = scalingOffset;
this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
// we need a separate evaluator because the vector length for the test dataset might differ
this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables);
this.tree = new Tree() {
state = automaton.CurrentState,
expr = "",
level = 0
};
// reset best solution
this.bestR = 0;
// code for default solution (constant model)
this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
this.bestNParams = 0;
this.bestConsts = null;
// init buffers
this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
constsBuf = new double[MaxParams];
this.predBuf = new double[y.Length];
this.testPredBuf = new double[testY.Length];
this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
}
#region IState interface
public bool Done { get { return tree != null && tree.Done; } }
public double BestSolutionTrainingQuality {
get {
evaluator.Exec(bestCode, x, bestConsts, predBuf);
return Rho(y, predBuf);
}
}
public double BestSolutionTestQuality {
get {
testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
return Rho(testY, testPredBuf);
}
}
// takes the code of the best solution and creates an equivalent symbolic regression model
public ISymbolicRegressionModel BestModel {
get {
var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
model.Scale(problemData); // apply linear scaling
return model;
}
}
public IEnumerable<ISymbolicRegressionSolution> ParetoBestModels {
get { return paretoBestModels; }
}
public int TotalRollouts { get { return totalRollouts; } }
public int EffectiveRollouts { get { return effectiveRollouts; } }
public int FuncEvaluations { get { return funcEvaluations; } }
public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (multiplied by the number of parameters) so that the value is comparable to the number of function evaluations
#endregion
#if DEBUG
public string ExprStr(Automaton automaton) {
byte[] code;
int nParams;
automaton.GetCode(out code, out nParams);
var generator = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
var @params = Enumerable.Repeat(1.0, nParams).ToArray();
var root = generator.Exec(code, @params, nParams, null, null);
var formatter = new InfixExpressionFormatter();
return formatter.Format(new SymbolicExpressionTree(root));
}
#endif
private double Eval(byte[] code, int nParams) {
double[] optConsts;
double q;
Eval(code, nParams, out q, out optConsts);
// single objective best
if (q > bestR) {
bestR = q;
bestNParams = nParams;
this.bestCode = new byte[code.Length];
this.bestConsts = new double[bestNParams];
Array.Copy(code, bestCode, code.Length);
Array.Copy(optConsts, bestConsts, bestNParams);
}
if (collectParetoOptimalModels) {
// multi-objective best
var complexity = // SymbolicDataAnalysisModelComplexityCalculator.CalculateComplexity() TODO: implement Kommenda's tree complexity directly in the evaluator
Array.FindIndex(code, (opc) => opc == (byte)OpCodes.Exit); // use length of expression as surrogate for complexity
UpdateParetoFront(q, complexity, code, optConsts, nParams, scalingFactor, scalingOffset);
}
return q;
}
private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
// we make a first pass to determine a valid starting configuration for all constants
// constant c in log(c + f(x)) is adjusted to guarantee that the argument c + f(x) is positive (see expression evaluator)
// scale and offset are set to optimal starting configuration
// assumes scale is the first param and offset is the last param
// reset constants
Array.Copy(ones, constsBuf, nParams);
evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
funcEvaluations++;
if (nParams == 0 || constOptIterations < 0) {
// if we don't need to optimize parameters then we are done
// changing scale and offset does not influence r²
rho = Rho(y, predBuf);
optConsts = constsBuf;
} else {
// optimize constants using the starting point calculated above
OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
evaluator.Exec(code, x, constsBuf, predBuf);
funcEvaluations++;
rho = Rho(y, predBuf);
optConsts = constsBuf;
}
}
#region helpers
private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
OnlineCalculatorError error;
double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
return error == OnlineCalculatorError.None ? r : 0.0;
}
private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
Array.Copy(consts, optConsts, nParams);
// direct usage of LM is recommended in alglib manual for better performance than the lsfit interface (which uses lm internally).
alglib.minlmstate state;
alglib.minlmreport rep = null;
alglib.minlmcreatevj(y.Length + 1, optConsts, out state); // +1 for penalty term
// Using the change of the gradient as stopping criterion is recommended in alglib manual.
// However, the most recent version of alglib (as of Oct 2017) only supports epsX as stopping criterion
alglib.minlmsetcond(state, epsg: 1E-6 * yStdDev, epsf: epsF, epsx: 0.0, maxits: nIters);
// alglib.minlmsetgradientcheck(state, 1E-5);
alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
alglib.minlmresults(state, out optConsts, out rep);
funcEvaluations += rep.nfunc;
gradEvaluations += rep.njac * nParams;
if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);
// only use optimized constants if successful
if (rep.terminationtype >= 0) {
Array.Copy(optConsts, consts, optConsts.Length);
}
}
private void Func(double[] arg, double[] fi, object obj) {
var code = (byte[])obj;
int n = predBuf.Length;
evaluator.Exec(code, x, arg, predBuf); // predictions only (no gradients needed here)
for (int r = 0; r < n; r++) {
var res = predBuf[r] - y[r];
fi[r] = res;
}
var penaltyIdx = fi.Length - 1;
fi[penaltyIdx] = 0.0;
// calc length of parameter vector for regularization
var aa = 0.0;
for (int i = 0; i < arg.Length; i++) {
aa += arg[i] * arg[i];
}
if (lambda > 0 && aa > 0) {
// scale lambda using stdDev(y) to make the parameter independent of the scale of y
// scale lambda using n to make parameter independent of the number of training points
// take the root because LM squares the result
fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
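// squared by LM this contributes n * lambda * ||a||^2 / stdDev(y) to the sum of squared residuals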
}
}
private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
int n = predBuf.Length;
int nParams = arg.Length;
var code = (byte[])obj;
evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
for (int r = 0; r < n; r++) {
var res = predBuf[r] - y[r];
fi[r] = res;
for (int k = 0; k < nParams; k++) {
jac[r, k] = gradBuf[k][r];
}
}
// calc length of parameter vector for regularization
double aa = 0.0;
for (int i = 0; i < arg.Length; i++) {
aa += arg[i] * arg[i];
}
var penaltyIdx = fi.Length - 1;
if (lambda > 0 && aa > 0) {
fi[penaltyIdx] = 0.0;
// scale lambda using stdDev(y) to make the parameter independent of the scale of y
// scale lambda using n to make parameter independent of the number of training points
// take the root because alglib LM squares the result
fi[penaltyIdx] = Math.Sqrt(n * lambda / yStdDev * aa);
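// derivative of the penalty term: d/da_i sqrt(c * sum_j a_j^2) = c * a_i / sqrt(c * sum_j a_j^2)
// with c = n * lambda / yStdDev, written below as 0.5 / fi[penaltyIdx] * 2 * c * arg[i]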
for (int i = 0; i < arg.Length; i++) {
jac[penaltyIdx, i] = 0.5 / fi[penaltyIdx] * 2 * n * lambda / yStdDev * arg[i];
}
} else {
fi[penaltyIdx] = 0.0;
for (int i = 0; i < arg.Length; i++) {
jac[penaltyIdx, i] = 0.0;
}
}
}
private void UpdateParetoFront(double q, int complexity, byte[] code, double[] param, int nParam,
double[] scalingFactor, double[] scalingOffset) {
double[] best = new double[2];
double[] cur = new double[2] { q, complexity };
bool[] max = new[] { true, false };
var isNonDominated = true;
foreach (var e in paretoFront) {
var domRes = DominationCalculator<double[]>.Dominates(cur, e, max, true);
if (domRes == DominationResult.IsDominated) {
isNonDominated = false;
break;
}
}
if (isNonDominated) {
paretoFront.Add(cur);
// create model
var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
var t = new SymbolicExpressionTree(treeGen.Exec(code, param, nParam, scalingFactor, scalingOffset));
var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);
model.Scale(problemData); // apply linear scaling
var sol = model.CreateRegressionSolution(this.problemData);
sol.Name = string.Format("{0:N5} {1}", q, complexity);
paretoBestModels.Add(sol);
}
for (int i = paretoFront.Count - 2; i >= 0; i--) {
var @ref = paretoFront[i];
var domRes = DominationCalculator<double[]>.Dominates(cur, @ref, max, true);
if (domRes == DominationResult.Dominates) {
paretoFront.RemoveAt(i);
paretoBestModels.RemoveAt(i);
}
}
}
#endregion
}
/// <summary>
/// Static method to initialize a state for the algorithm.
/// </summary>
/// <param name="problemData">The problem data.</param>
/// <param name="randSeed">Random seed.</param>
/// <param name="maxVariables">Maximum number of variable references that are allowed in the expression.</param>
/// <param name="scaleVariables">Optionally scale input variables to the interval [0..1] (recommended).</param>
/// <param name="constOptIterations">Maximum number of iterations for constants optimization (Levenberg-Marquardt).</param>
/// <param name="lambda">Penalty factor for regularization (0..inf.); zero disables regularization.</param>
/// <param name="collectParetoOptimalModels">Optionally collect all Pareto-optimal solutions with respect to error and expression length.</param>
/// <param name="lowerEstimationLimit">Optionally limit the result of the expression to this lower value.</param>
/// <param name="upperEstimationLimit">Optionally limit the result of the expression to this upper value.</param>
/// <param name="allowProdOfVars">Allow products of expressions.</param>
/// <param name="allowExp">Allow expressions with exponentials.</param>
/// <param name="allowLog">Allow expressions with logarithms.</param>
/// <param name="allowInv">Allow expressions with 1/x.</param>
/// <param name="allowMultipleTerms">Allow expressions which are sums of multiple terms.</param>
/// <returns>The initialized state, to be passed to MakeStep.</returns>
public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
bool collectParetoOptimalModels = false,
double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
bool allowProdOfVars = true,
bool allowExp = true,
bool allowLog = true,
bool allowInv = true,
bool allowMultipleTerms = false
) {
return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
collectParetoOptimalModels,
lowerEstimationLimit, upperEstimationLimit,
allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
}
// returns the quality of the evaluated solution
public static double MakeStep(IState state) {
var mctsState = state as State;
if (mctsState == null) throw new ArgumentException("Argument must be a state created by CreateState", "state");
if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");
return TreeSearch(mctsState);
}
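// Minimal usage sketch for the static API (hypothetical driver code; names such as problemData
// are assumed to be provided by the caller, e.g. by an algorithm wrapper):
//   var state = MctsSymbolicRegressionStatic.CreateState(problemData, randSeed: 31415,
//     maxVariables: 3, scaleVariables: true, constOptIterations: 10, lambda: 0.0);
//   int steps = 0;
//   while (!state.Done && steps++ < 100000) {
//     MctsSymbolicRegressionStatic.MakeStep(state); // returns Pearson's R of the evaluated expression
//   }
//   var bestModel = state.BestModel; // best expression incl. linear scaling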
#endregion
private static double TreeSearch(State mctsState) {
var automaton = mctsState.automaton;
var tree = mctsState.tree;
var eval = mctsState.evalFun;
var rand = mctsState.random;
double q = 0;
bool success = false;
do {
automaton.Reset();
success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q);
mctsState.totalRollouts++;
} while (!success && !tree.Done);
if (success) {
mctsState.effectiveRollouts++;
#if DEBUG
// Console.WriteLine(mctsState.ExprStr(automaton));
#endif
return q;
} else return 0.0;
}
// search forward
private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton,
Func<byte[], int, double> eval,
State state,
out double q) {
// ROLLOUT AND EXPANSION
// We are navigating a graph (states might be reached via different paths) instead of a tree.
// State equivalence is checked through ExprHash (based on the generated code through the path).
// We switch between rollout-mode and expansion mode.
// Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB).
// Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression).
// In expansion mode we might re-enter the graph and switch back to rollout-mode.
// We do this until we reach a complete expression (final state).
// Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent.
// Sub-graphs which have been completely searched are marked as done.
// Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.
while (!automaton.IsFinalState(automaton.CurrentState)) {
// Console.WriteLine(automaton.stateNames[automaton.CurrentState]);
if (state.children.ContainsKey(tree)) {
if (state.children[tree].All(ch => ch.Done)) {
tree.Done = true;
break;
}
// ROLLOUT INSIDE TREE
// UCT selection within tree
int selectedIdx = 0;
if (state.children[tree].Count > 1) {
selectedIdx = SelectInternal(state.children[tree], rand);
}
tree = state.children[tree][selectedIdx];
// all steps where no alternatives could be taken immediately (without expanding the tree)
// TODO: simplification of the automaton
int[] possibleFollowStates = new int[1000];
int nFs;
automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);
Debug.Assert(possibleFollowStates.Contains(tree.state));
automaton.Goto(tree.state);
} else {
// EXPAND
int[] possibleFollowStates = new int[1000];
int nFs;
string actionString = "";
automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);
if (nFs == 0) {
// stuck in a dead end (no final state and no allowed follow states)
tree.Done = true;
break;
}
var newChildren = new List<Tree>(nFs);
state.children.Add(tree, newChildren);
for (int i = 0; i < nFs; i++) {
Tree child = null;
// for selected states (EvalStates) we introduce state unification (detection of equivalent states)
if (automaton.IsEvalState(possibleFollowStates[i])) {
var hc = Hashcode(automaton);
hc = ((hc << 5) + hc) ^ (ulong)tree.state; // TODO fix unit test for structure enumeration
if (!state.nodes.TryGetValue(hc, out child)) {
// Console.WriteLine("New expression (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
child = new Tree() {
state = possibleFollowStates[i],
expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
level = tree.level + 1
};
state.nodes.Add(hc, child);
}
// only allow forward edges (don't add the child if we would go back in the graph)
else if (child.level > tree.level) {
// Console.WriteLine("Existing expression (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
// whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
// to all parents
BackpropagateStatistics(tree, state, child.visits);
} else {
// Console.WriteLine("Cycle (hash: {0}, state: {1})", Hashcode(automaton), automaton.stateNames[possibleFollowStates[i]]);
// prevent cycles
Debug.Assert(child.level <= tree.level);
child = null;
}
} else {
child = new Tree() {
state = possibleFollowStates[i],
expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
level = tree.level + 1
};
}
if (child != null)
newChildren.Add(child);
}
if (!newChildren.Any()) {
// stuck in a dead end (no final state and no allowed follow states)
tree.Done = true;
break;
}
foreach (var ch in newChildren) {
if (!state.parents.ContainsKey(ch)) {
state.parents.Add(ch, new List<Tree>());
}
state.parents[ch].Add(tree);
}
// follow one of the children
tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
automaton.Goto(tree.state);
}
}
bool success;
// EVALUATE TREE
if (!tree.Done && automaton.IsFinalState(automaton.CurrentState)) {
tree.Done = true;
// for debugging
// tree.expr = state.ExprStr(automaton);
byte[] code; int nParams;
automaton.GetCode(out code, out nParams);
q = eval(code, nParams);
success = true;
BackpropagateQuality(tree, q, state);
} else {
// we got stuck in roll-out (no evaluation necessary!)
q = 0.0;
success = false;
}
// RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
// Update statistics
// Set branch to done if all children are done.
BackpropagateDone(tree, state);
BackpropagateDebugStats(tree, q, state);
return success;
}
private static int SelectInternal(List<Tree> list, IRandom rand) {
Debug.Assert(list.Any(t => !t.Done));
// check if there is any node which has not been visited
for (int i = 0; i < list.Count; i++) {
if (!list[i].Done && list[i].visits == 0) return i;
}
// otherwise choose one of the nodes which are not done yet at random
var idx = rand.Next(list.Count);
while (list[idx].Done) { idx = rand.Next(list.Count); }
return idx;
}
// hashes the code generated so far (used for state unification / detection of equivalent expressions, see ExprHash)
private static ulong Hashcode(Automaton automaton) {
byte[] code;
int nParams;
automaton.GetCode(out code, out nParams);
return (ulong)ExprHash.GetHash(code, nParams);
}
private static void BackpropagateQuality(Tree tree, double q, State state) {
tree.visits++;
// TODO: the quality q is not stored in the statistics yet (only visit counts)
if (state.parents.ContainsKey(tree)) {
foreach (var parent in state.parents[tree]) {
BackpropagateQuality(parent, q, state);
}
}
}
// when a new edge to an existing node is created the statistics of the existing sub-graph
// must be propagated back to all parents along the new edge
private static void BackpropagateStatistics(Tree tree, State state, int numVisits) {
tree.visits += numVisits;
if (state.parents.ContainsKey(tree)) {
foreach (var parent in state.parents[tree]) {
BackpropagateStatistics(parent, state, numVisits);
}
}
}
private static void BackpropagateDone(Tree tree, State state) {
if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) {
tree.Done = true;
// children[tree] = null; keep all nodes
}
if (state.parents.ContainsKey(tree)) {
foreach (var parent in state.parents[tree]) {
BackpropagateDone(parent, state);
}
}
}
private static void BackpropagateDebugStats(Tree tree, double q, State state) {
if (state.parents.ContainsKey(tree)) {
foreach (var parent in state.parents[tree]) {
BackpropagateDebugStats(parent, q, state);
}
}
}
private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
// find the child with the smallest state value (smaller values are closer to the final state)
int selectedChildIdx = 0;
var children = state.children[tree];
Tree minChild = children.First();
for (int i = 1; i < children.Count; i++) {
if (children[i].state < minChild.state) {
minChild = children[i];
selectedChildIdx = i;
}
}
return children[selectedChildIdx];
}
// scales data and extracts values from dataset into arrays
private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
xs = new double[problemData.AllowedInputVariables.Count()][];
var i = 0;
if (scaleVariables) {
scalingFactor = new double[xs.Length + 1];
scalingOffset = new double[xs.Length + 1];
} else {
scalingFactor = null;
scalingOffset = null;
}
foreach (var var in problemData.AllowedInputVariables) {
if (scaleVariables) {
var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
var range = maxX - minX;
// scaledX = (x - min) / range
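// e.g. for minX = 2, maxX = 6: sf = 0.25, offset = -0.5, so x = 2 maps to 0 and x = 6 maps to 1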
var sf = 1.0 / range;
var offset = -minX / range;
scalingFactor[i] = sf;
scalingOffset[i] = offset;
i++;
}
}
if (scaleVariables) {
// transform target variable to zero-mean
scalingFactor[i] = 1.0;
scalingOffset[i] = -problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Average();
}
GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
}
// extract values from dataset into arrays
private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
out double[][] xs, out double[] y) {
xs = new double[problemData.AllowedInputVariables.Count()][];
int i = 0;
foreach (var var in problemData.AllowedInputVariables) {
var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
xs[i++] =
problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
}
{
var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(yi => yi * sf + offset).ToArray();
}
}
// for debugging only
#region debugging
private static string TraceTree(Tree tree, State state) {
var sb = new StringBuilder();
sb.Append(
@"digraph {
ratio = fill;
node [style=filled];
");
int nodeId = 0;
TraceTreeRec(tree, 0, sb, ref nodeId, state);
sb.Append("}");
return sb.ToString();
}
private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
var tries = tree.visits;
sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tries).AppendLine();
var list = new List<Tuple<int, int, Tree>>();
if (state.children.ContainsKey(tree)) {
foreach (var ch in state.children[tree]) {
nextId++;
tries = ch.visits;
sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine();
list.Add(Tuple.Create(tries, nextId, ch));
}
foreach (var tup in list) {
var ch = tup.Item3;
var chId = tup.Item2;
if (state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
var chch = state.children[ch].First();
nextId++;
tries = chch.visits;
sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", chId, nextId, chch.expr).AppendLine();
}
}
foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state);
}
}
}
private static string WriteTree(Tree tree, State state) {
var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
var nodeIds = new Dictionary<Tree, int>();
sb.Write(
@"digraph {
ratio = fill;
node [style=filled];
");
int threshold = /* state.nodes.Count > 500 ? 10 : */ 0;
foreach (var kvp in state.children) {
var parent = kvp.Key;
int parentId;
if (!nodeIds.TryGetValue(parent, out parentId)) {
parentId = nodeIds.Count + 1;
var tries = parent.visits;
if (tries > threshold)
sb.Write("{0} [label=\"{1}\"]; ", parentId, tries);
nodeIds.Add(parent, parentId);
}
foreach (var child in kvp.Value) {
int childId;
if (!nodeIds.TryGetValue(child, out childId)) {
childId = nodeIds.Count + 1;
nodeIds.Add(child, childId);
}
var tries = child.visits;
if (tries < 1) continue;
if (tries > threshold) {
sb.Write("{0} [label=\"{1}\"]; ", childId, tries);
var edgeLabel = child.expr;
// if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
}
}
}
sb.Write("}");
return sb.ToString();
}
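// Example (hypothetical debugger usage): dump the search graph in Graphviz format and render it:
//   System.IO.File.WriteAllText("searchGraph.gv", WriteTree(state.tree, state));
//   // then render with: dot -Tpdf searchGraph.gv -o searchGraph.pdf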
#endregion
}
}