#region License Information
/* HeuristicLab
* Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
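/// <summary>
/// Monte Carlo tree search (MCTS) for symbolic regression. The search explores model structures built from
/// the allowed factor types (variable products, exp, log, inverse factors, and sums of multiple terms) and
/// fits the numeric constants of each candidate model (see the "Iterations (constant optimization)" parameter).
/// Useful mainly as a base learner in gradient boosting.
/// </summary>
/// <remarks>
/// Minimal usage sketch (assumes the generic HeuristicLab executable API; adjust to the hosting application):
/// <code>
/// var alg = new MctsSymbolicRegressionAlgorithm();
/// alg.Problem = new RegressionProblem();  // any IRegressionProblem with loaded data
/// alg.Iterations = 10000;
/// alg.Start();                            // IExecutable.Start(); execution may run asynchronously
/// // inspect alg.Results (e.g. "Best quality", "Solution") after the run has stopped
/// </code>
/// </remarks>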
[Item("MCTS Symbolic Regression", "Monte carlo tree search for symbolic regression. Useful mainly as a base learner in gradient boosting.")]
[StorableClass]
[Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
public class MctsSymbolicRegressionAlgorithm : BasicAlgorithm {
public override Type ProblemType {
get { return typeof(IRegressionProblem); }
}
public new IRegressionProblem Problem {
get { return (IRegressionProblem)base.Problem; }
set { base.Problem = value; }
}
#region ParameterNames
private const string IterationsParameterName = "Iterations";
private const string MaxVariablesParameterName = "Maximum variables";
private const string ScaleVariablesParameterName = "Scale variables";
private const string AllowedFactorsParameterName = "Allowed factors";
private const string ConstantOptimizationIterationsParameterName = "Iterations (constant optimization)";
private const string CParameterName = "C";
private const string SeedParameterName = "Seed";
private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
private const string UpdateIntervalParameterName = "UpdateInterval";
private const string CreateSolutionParameterName = "CreateSolution";
private const string PunishmentFactorParameterName = "PunishmentFactor";
private const string VariableProductFactorName = "product(xi)";
private const string ExpFactorName = "exp(c * product(xi))";
private const string LogFactorName = "log(c + sum(c*product(xi)))";
private const string InvFactorName = "1 / (1 + sum(c*product(xi)))";
private const string FactorSumsName = "sum of multiple terms";
#endregion
#region ParameterProperties
public IFixedValueParameter<IntValue> IterationsParameter {
get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
}
public IFixedValueParameter<IntValue> MaxVariableReferencesParameter {
get { return (IFixedValueParameter<IntValue>)Parameters[MaxVariablesParameterName]; }
}
public IFixedValueParameter<BoolValue> ScaleVariablesParameter {
get { return (IFixedValueParameter<BoolValue>)Parameters[ScaleVariablesParameterName]; }
}
public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
}
public IFixedValueParameter<DoubleValue> CParameter {
get { return (IFixedValueParameter<DoubleValue>)Parameters[CParameterName]; }
}
public IFixedValueParameter<DoubleValue> PunishmentFactorParameter {
get { return (IFixedValueParameter<DoubleValue>)Parameters[PunishmentFactorParameterName]; }
}
public IValueParameter<ICheckedItemList<StringValue>> AllowedFactorsParameter {
get { return (IValueParameter<ICheckedItemList<StringValue>>)Parameters[AllowedFactorsParameterName]; }
}
public IFixedValueParameter<IntValue> SeedParameter {
get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
}
public FixedValueParameter<BoolValue> SetSeedRandomlyParameter {
get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
}
public IFixedValueParameter<IntValue> UpdateIntervalParameter {
get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
}
public IFixedValueParameter<BoolValue> CreateSolutionParameter {
get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
}
#endregion
#region Properties
public int Iterations {
get { return IterationsParameter.Value.Value; }
set { IterationsParameter.Value.Value = value; }
}
public int Seed {
get { return SeedParameter.Value.Value; }
set { SeedParameter.Value.Value = value; }
}
public bool SetSeedRandomly {
get { return SetSeedRandomlyParameter.Value.Value; }
set { SetSeedRandomlyParameter.Value.Value = value; }
}
public int MaxVariableReferences {
get { return MaxVariableReferencesParameter.Value.Value; }
set { MaxVariableReferencesParameter.Value.Value = value; }
}
public double C {
get { return CParameter.Value.Value; }
set { CParameter.Value.Value = value; }
}
public double PunishmentFactor {
get { return PunishmentFactorParameter.Value.Value; }
set { PunishmentFactorParameter.Value.Value = value; }
}
public ICheckedItemList<StringValue> AllowedFactors {
get { return AllowedFactorsParameter.Value; }
}
public int ConstantOptimizationIterations {
get { return ConstantOptimizationIterationsParameter.Value.Value; }
set { ConstantOptimizationIterationsParameter.Value.Value = value; }
}
public bool ScaleVariables {
get { return ScaleVariablesParameter.Value.Value; }
set { ScaleVariablesParameter.Value.Value = value; }
}
public bool CreateSolution {
get { return CreateSolutionParameter.Value.Value; }
set { CreateSolutionParameter.Value.Value = value; }
}
#endregion
[StorableConstructor]
protected MctsSymbolicRegressionAlgorithm(bool deserializing) : base(deserializing) { }
protected MctsSymbolicRegressionAlgorithm(MctsSymbolicRegressionAlgorithm original, Cloner cloner)
: base(original, cloner) {
}
public override IDeepCloneable Clone(Cloner cloner) {
return new MctsSymbolicRegressionAlgorithm(this, cloner);
}
public MctsSymbolicRegressionAlgorithm() {
Problem = new RegressionProblem(); // default problem
var defaultFactorsList = new CheckedItemList<StringValue>(
new string[] { VariableProductFactorName, ExpFactorName, LogFactorName, InvFactorName, FactorSumsName }
.Select(s => new StringValue(s).AsReadOnly())
).AsReadOnly();
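// uncheck "sum of multiple terms" by default; the remaining factor types stay enabled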
defaultFactorsList.SetItemCheckedState(defaultFactorsList.First(s => s.Value == FactorSumsName), false);
Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName,
"Number of iterations", new IntValue(100000)));
Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName,
"The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName,
"True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
Parameters.Add(new FixedValueParameter<IntValue>(MaxVariablesParameterName,
"Maximal number of variable references in the symbolic regression model (multiple usages of the same variable are counted).", new IntValue(5)));
Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,
"Balancing parameter in the UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));
Parameters.Add(new ValueParameter<ICheckedItemList<StringValue>>(AllowedFactorsParameterName,
"Choose which expressions are allowed as factors in the model.", defaultFactorsList));
Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName,
"Number of iterations for constant optimization. A small number of iterations should be sufficient for most models. " +
"Set to 0 to disable constant optimization.", new IntValue(10)));
Parameters.Add(new FixedValueParameter<BoolValue>(ScaleVariablesParameterName,
"Set to true to scale all input variables to the range [0..1].", new BoolValue(false)));
Parameters[ScaleVariablesParameterName].Hidden = true;
Parameters.Add(new FixedValueParameter<DoubleValue>(PunishmentFactorParameterName,
"Estimations of models can be bounded. The estimation limits are calculated as lb = mean(y) - PunishmentFactor * range(y) and ub = mean(y) + PunishmentFactor * range(y).", new DoubleValue(10)));
Parameters[PunishmentFactorParameterName].Hidden = true;
Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName,
"Number of iterations until the results are updated.", new IntValue(100)));
Parameters[UpdateIntervalParameterName].Hidden = true;
Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
"Flag that indicates if a solution should be produced at the end of the run.", new BoolValue(true)));
Parameters[CreateSolutionParameterName].Hidden = true;
}
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
}
protected override void Run(CancellationToken cancellationToken) {
// Set up the algorithm
if (SetSeedRandomly) Seed = new System.Random().Next();
// Set up the results display
var iterations = new IntValue(0);
Results.Add(new Result("Iterations", iterations));
var table = new DataTable("Qualities");
table.Rows.Add(new DataRow("Best quality"));
table.Rows.Add(new DataRow("Current best quality"));
table.Rows.Add(new DataRow("Average quality"));
Results.Add(new Result("Qualities", table));
var bestQuality = new DoubleValue();
Results.Add(new Result("Best quality", bestQuality));
var curQuality = new DoubleValue();
Results.Add(new Result("Current best quality", curQuality));
var avgQuality = new DoubleValue();
Results.Add(new Result("Average quality", avgQuality));
var totalRollouts = new IntValue();
Results.Add(new Result("Total rollouts", totalRollouts));
var effRollouts = new IntValue();
Results.Add(new Result("Effective rollouts", effRollouts));
var funcEvals = new IntValue();
Results.Add(new Result("Function evaluations", funcEvals));
var gradEvals = new IntValue();
Results.Add(new Result("Gradient evaluations", gradEvals));
// same as in SymbolicRegressionSingleObjectiveProblem
var y = Problem.ProblemData.Dataset.GetDoubleValues(Problem.ProblemData.TargetVariable,
Problem.ProblemData.TrainingIndices);
var avgY = y.Average();
var minY = y.Min();
var maxY = y.Max();
var range = maxY - minY;
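// bound model estimations to mean(y) +/- PunishmentFactor * range(y)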
var lowerLimit = avgY - PunishmentFactor * range;
var upperLimit = avgY + PunishmentFactor * range;
// init
var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
if (!AllowedFactors.CheckedItems.Any()) throw new ArgumentException("At least one type of factor must be allowed.");
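// create the static MCTS search state; only the factor types checked in AllowedFactors are enabled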
var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, C, ScaleVariables, ConstantOptimizationIterations,
lowerLimit, upperLimit,
allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName),
allowExp: AllowedFactors.CheckedItems.Any(s => s.Value.Value == ExpFactorName),
allowLog: AllowedFactors.CheckedItems.Any(s => s.Value.Value == LogFactorName),
allowInv: AllowedFactors.CheckedItems.Any(s => s.Value.Value == InvFactorName),
allowMultipleTerms: AllowedFactors.CheckedItems.Any(s => s.Value.Value == FactorSumsName)
);
var updateInterval = UpdateIntervalParameter.Value.Value;
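// running statistics; sumQ, curBestQ, and n are reset after every updateInterval iterations, bestQ tracks the overall best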
double sumQ = 0.0;
double bestQ = 0.0;
double curBestQ = 0.0;
double q = 0.0;
int n = 0;
// Loop until iteration limit reached or canceled.
for (int i = 0; i < Iterations && !state.Done; i++) {
cancellationToken.ThrowIfCancellationRequested();
q = MctsSymbolicRegressionStatic.MakeStep(state);
sumQ += q; // sum of the q values in the last updateInterval iterations
curBestQ = Math.Max(q, curBestQ); // the best q in the last updateInterval iterations
bestQ = Math.Max(q, bestQ); // the best q overall
n++;
// iteration results
if (n == updateInterval) {
bestQuality.Value = bestQ;
curQuality.Value = curBestQ;
avgQuality.Value = sumQ / n;
sumQ = 0.0;
curBestQ = 0.0;
funcEvals.Value = state.FuncEvaluations;
gradEvals.Value = state.GradEvaluations;
effRollouts.Value = state.EffectiveRollouts;
totalRollouts.Value = state.TotalRollouts;
table.Rows["Best quality"].Values.Add(bestQuality.Value);
table.Rows["Current best quality"].Values.Add(curQuality.Value);
table.Rows["Average quality"].Values.Add(avgQuality.Value);
iterations.Value += n;
n = 0;
}
}
// final results
if (n > 0) {
bestQuality.Value = bestQ;
curQuality.Value = curBestQ;
avgQuality.Value = sumQ / n;
funcEvals.Value = state.FuncEvaluations;
gradEvals.Value = state.GradEvaluations;
effRollouts.Value = state.EffectiveRollouts;
totalRollouts.Value = state.TotalRollouts;
table.Rows["Best quality"].Values.Add(bestQuality.Value);
table.Rows["Current best quality"].Values.Add(curQuality.Value);
table.Rows["Average quality"].Values.Add(avgQuality.Value);
iterations.Value += n;
}
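// report training and test quality of the best model found during the search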
Results.Add(new Result("Best solution quality (train)", new DoubleValue(state.BestSolutionTrainingQuality)));
Results.Add(new Result("Best solution quality (test)", new DoubleValue(state.BestSolutionTestQuality)));
// produce solution
if (CreateSolution) {
var model = state.BestModel;
// create a regression solution from the best model found
Results.Add(new Result("Solution", model.CreateRegressionSolution(problemData)));
}
}
}
}