[13645]  1  #region License Information


 2  /* HeuristicLab


[14185]  3  * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


[13645]  4  *


 5  * This file is part of HeuristicLab.


 6  *


 7  * HeuristicLab is free software: you can redistribute it and/or modify


 8  * it under the terms of the GNU General Public License as published by


 9  * the Free Software Foundation, either version 3 of the License, or


 10  * (at your option) any later version.


 11  *


 12  * HeuristicLab is distributed in the hope that it will be useful,


 13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


 14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


 15  * GNU General Public License for more details.


 16  *


 17  * You should have received a copy of the GNU General Public License


 18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


 19  */


 20  #endregion


 21 


 22  using System;


 23  using System.Linq;


 24  using System.Runtime.CompilerServices;


 25  using System.Threading;


[13658]  26  using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;


[13645]  27  using HeuristicLab.Analysis;


 28  using HeuristicLab.Common;


 29  using HeuristicLab.Core;


 30  using HeuristicLab.Data;


 31  using HeuristicLab.Optimization;


 32  using HeuristicLab.Parameters;


 33  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


 34  using HeuristicLab.Problems.DataAnalysis;


 35 


 36  namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {


 37  [Item("MCTS Symbolic Regression", "Monte carlo tree search for symbolic regression. Useful mainly as a base learner in gradient boosting.")]


 38  [StorableClass]


 39  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]


[14869]  40  public class MctsSymbolicRegressionAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {


[13645]  41 


 42  #region ParameterNames


 43  private const string IterationsParameterName = "Iterations";


 44  private const string MaxVariablesParameterName = "Maximum variables";


 45  private const string ScaleVariablesParameterName = "Scale variables";


 46  private const string AllowedFactorsParameterName = "Allowed factors";


 47  private const string ConstantOptimizationIterationsParameterName = "Iterations (constant optimization)";


[13658]  48  private const string PolicyParameterName = "Policy";


[13645]  49  private const string SeedParameterName = "Seed";


 50  private const string SetSeedRandomlyParameterName = "SetSeedRandomly";


 51  private const string UpdateIntervalParameterName = "UpdateInterval";


 52  private const string CreateSolutionParameterName = "CreateSolution";


 53  private const string PunishmentFactorParameterName = "PunishmentFactor";


 54 


 55  private const string VariableProductFactorName = "product(xi)";


 56  private const string ExpFactorName = "exp(c * product(xi))";


 57  private const string LogFactorName = "log(c + sum(c*product(xi))";


 58  private const string InvFactorName = "1 / (1 + sum(c*product(xi))";


 59  private const string FactorSumsName = "sum of multiple terms";


 60  #endregion


 61 


 62  #region ParameterProperties


 63  public IFixedValueParameter<IntValue> IterationsParameter {


 64  get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }


 65  }


[13652]  66  public IFixedValueParameter<IntValue> MaxVariableReferencesParameter {


[13645]  67  get { return (IFixedValueParameter<IntValue>)Parameters[MaxVariablesParameterName]; }


 68  }


 69  public IFixedValueParameter<BoolValue> ScaleVariablesParameter {


 70  get { return (IFixedValueParameter<BoolValue>)Parameters[ScaleVariablesParameterName]; }


 71  }


 72  public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {


 73  get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }


 74  }


[13658]  75  public IValueParameter<IPolicy> PolicyParameter {


 76  get { return (IValueParameter<IPolicy>)Parameters[PolicyParameterName]; }


[13645]  77  }


 78  public IFixedValueParameter<DoubleValue> PunishmentFactorParameter {


 79  get { return (IFixedValueParameter<DoubleValue>)Parameters[PunishmentFactorParameterName]; }


 80  }


 81  public IValueParameter<ICheckedItemList<StringValue>> AllowedFactorsParameter {


 82  get { return (IValueParameter<ICheckedItemList<StringValue>>)Parameters[AllowedFactorsParameterName]; }


 83  }


 84  public IFixedValueParameter<IntValue> SeedParameter {


 85  get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }


 86  }


 87  public FixedValueParameter<BoolValue> SetSeedRandomlyParameter {


 88  get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }


 89  }


 90  public IFixedValueParameter<IntValue> UpdateIntervalParameter {


 91  get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }


 92  }


 93  public IFixedValueParameter<BoolValue> CreateSolutionParameter {


 94  get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }


 95  }


 96  #endregion


 97 


 98  #region Properties


 99  public int Iterations {


 100  get { return IterationsParameter.Value.Value; }


 101  set { IterationsParameter.Value.Value = value; }


 102  }


 103  public int Seed {


 104  get { return SeedParameter.Value.Value; }


 105  set { SeedParameter.Value.Value = value; }


 106  }


 107  public bool SetSeedRandomly {


 108  get { return SetSeedRandomlyParameter.Value.Value; }


 109  set { SetSeedRandomlyParameter.Value.Value = value; }


 110  }


[13652]  111  public int MaxVariableReferences {


 112  get { return MaxVariableReferencesParameter.Value.Value; }


 113  set { MaxVariableReferencesParameter.Value.Value = value; }


[13645]  114  }


[13658]  115  public IPolicy Policy {


 116  get { return PolicyParameter.Value; }


 117  set { PolicyParameter.Value = value; }


[13645]  118  }


 119  public double PunishmentFactor {


 120  get { return PunishmentFactorParameter.Value.Value; }


 121  set { PunishmentFactorParameter.Value.Value = value; }


 122  }


 123  public ICheckedItemList<StringValue> AllowedFactors {


 124  get { return AllowedFactorsParameter.Value; }


 125  }


 126  public int ConstantOptimizationIterations {


 127  get { return ConstantOptimizationIterationsParameter.Value.Value; }


 128  set { ConstantOptimizationIterationsParameter.Value.Value = value; }


 129  }


 130  public bool ScaleVariables {


 131  get { return ScaleVariablesParameter.Value.Value; }


 132  set { ScaleVariablesParameter.Value.Value = value; }


 133  }


 134  public bool CreateSolution {


 135  get { return CreateSolutionParameter.Value.Value; }


 136  set { CreateSolutionParameter.Value.Value = value; }


 137  }


 138  #endregion


 139 


 140  [StorableConstructor]


 141  protected MctsSymbolicRegressionAlgorithm(bool deserializing) : base(deserializing) { }


 142 


 143  protected MctsSymbolicRegressionAlgorithm(MctsSymbolicRegressionAlgorithm original, Cloner cloner)


 144  : base(original, cloner) {


 145  }


 146 


 147  public override IDeepCloneable Clone(Cloner cloner) {


 148  return new MctsSymbolicRegressionAlgorithm(this, cloner);


 149  }


 150 


 151  public MctsSymbolicRegressionAlgorithm() {


 152  Problem = new RegressionProblem(); // default problem


 153 


 154  var defaultFactorsList = new CheckedItemList<StringValue>(


 155  new string[] { VariableProductFactorName, ExpFactorName, LogFactorName, InvFactorName, FactorSumsName }


 156  .Select(s => new StringValue(s).AsReadOnly())


 157  ).AsReadOnly();


 158  defaultFactorsList.SetItemCheckedState(defaultFactorsList.First(s => s.Value == FactorSumsName), false);


 159 


 160  Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName,


 161  "Number of iterations", new IntValue(100000)));


 162  Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName,


 163  "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));


 164  Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName,


 165  "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));


 166  Parameters.Add(new FixedValueParameter<IntValue>(MaxVariablesParameterName,


 167  "Maximal number of variables references in the symbolic regression models (multiple usages of the same variable are counted)", new IntValue(5)));


[13658]  168  // Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,


 169  // "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));


 170  Parameters.Add(new ValueParameter<IPolicy>(PolicyParameterName,


 171  "The policy to use for selecting nodes in MCTS (e.g. Ucb)", new Ucb()));


 172  PolicyParameter.Hidden = true;


[13645]  173  Parameters.Add(new ValueParameter<ICheckedItemList<StringValue>>(AllowedFactorsParameterName,


 174  "Choose which expressions are allowed as factors in the model.", defaultFactorsList));


 175 


 176  Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName,


 177  "Number of iterations for constant optimization. A small number of iterations should be sufficient for most models. " +


 178  "Set to 0 to disable constants optimization.", new IntValue(10)));


 179  Parameters.Add(new FixedValueParameter<BoolValue>(ScaleVariablesParameterName,


 180  "Set to true to scale all input variables to the range [0..1]", new BoolValue(false)));


 181  Parameters[ScaleVariablesParameterName].Hidden = true;


 182  Parameters.Add(new FixedValueParameter<DoubleValue>(PunishmentFactorParameterName, "Estimations of models can be bounded. The estimation limits are calculated in the following way (lb = mean(y)  punishmentFactor*range(y), ub = mean(y) + punishmentFactor*range(y))", new DoubleValue(10)));


 183  Parameters[PunishmentFactorParameterName].Hidden = true;


 184  Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName,


 185  "Number of iterations until the results are updated", new IntValue(100)));


 186  Parameters[UpdateIntervalParameterName].Hidden = true;


 187  Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,


 188  "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));


 189  Parameters[CreateSolutionParameterName].Hidden = true;


 190  }


 191 


 192  [StorableHook(HookType.AfterDeserialization)]


 193  private void AfterDeserialization() {


 194  }


 195 


 196  protected override void Run(CancellationToken cancellationToken) {


 197  // Set up the algorithm


 198  if (SetSeedRandomly) Seed = new System.Random().Next();


 199 


 200  // Set up the results display


 201  var iterations = new IntValue(0);


 202  Results.Add(new Result("Iterations", iterations));


 203 


[13669]  204  var bestSolutionIteration = new IntValue(0);


 205  Results.Add(new Result("Best solution iteration", bestSolutionIteration));


 206 


[13645]  207  var table = new DataTable("Qualities");


 208  table.Rows.Add(new DataRow("Best quality"));


 209  table.Rows.Add(new DataRow("Current best quality"));


 210  table.Rows.Add(new DataRow("Average quality"));


 211  Results.Add(new Result("Qualities", table));


 212 


 213  var bestQuality = new DoubleValue();


 214  Results.Add(new Result("Best quality", bestQuality));


 215 


 216  var curQuality = new DoubleValue();


 217  Results.Add(new Result("Current best quality", curQuality));


 218 


 219  var avgQuality = new DoubleValue();


 220  Results.Add(new Result("Average quality", avgQuality));


 221 


[13651]  222  var totalRollouts = new IntValue();


 223  Results.Add(new Result("Total rollouts", totalRollouts));


 224  var effRollouts = new IntValue();


 225  Results.Add(new Result("Effective rollouts", effRollouts));


 226  var funcEvals = new IntValue();


 227  Results.Add(new Result("Function evaluations", funcEvals));


 228  var gradEvals = new IntValue();


 229  Results.Add(new Result("Gradient evaluations", gradEvals));


 230 


 231 


[13645]  232  // same as in SymbolicRegressionSingleObjectiveProblem


 233  var y = Problem.ProblemData.Dataset.GetDoubleValues(Problem.ProblemData.TargetVariable,


 234  Problem.ProblemData.TrainingIndices);


 235  var avgY = y.Average();


 236  var minY = y.Min();


 237  var maxY = y.Max();


 238  var range = maxY  minY;


 239  var lowerLimit = avgY  PunishmentFactor * range;


 240  var upperLimit = avgY + PunishmentFactor * range;


 241 


 242  // init


 243  var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();


 244  if (!AllowedFactors.CheckedItems.Any()) throw new ArgumentException("At least on type of factor must be allowed");


[13658]  245  var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables, ConstantOptimizationIterations,


 246  Policy,


[13645]  247  lowerLimit, upperLimit,


 248  allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName),


 249  allowExp: AllowedFactors.CheckedItems.Any(s => s.Value.Value == ExpFactorName),


 250  allowLog: AllowedFactors.CheckedItems.Any(s => s.Value.Value == LogFactorName),


 251  allowInv: AllowedFactors.CheckedItems.Any(s => s.Value.Value == InvFactorName),


 252  allowMultipleTerms: AllowedFactors.CheckedItems.Any(s => s.Value.Value == FactorSumsName)


 253  );


 254 


 255  var updateInterval = UpdateIntervalParameter.Value.Value;


 256  double sumQ = 0.0;


 257  double bestQ = 0.0;


 258  double curBestQ = 0.0;


 259  int n = 0;


 260  // Loop until iteration limit reached or canceled.


 261  for (int i = 0; i < Iterations && !state.Done; i++) {


 262  cancellationToken.ThrowIfCancellationRequested();


 263 


[13669]  264  var q = MctsSymbolicRegressionStatic.MakeStep(state);


[13645]  265  sumQ += q; // sum of qs in the last updateinterval iterations


 266  curBestQ = Math.Max(q, curBestQ); // the best q in the last updateinterval iterations


 267  bestQ = Math.Max(q, bestQ); // the best q overall


 268  n++;


 269  // iteration results


 270  if (n == updateInterval) {


[13669]  271  if (bestQ > bestQuality.Value) {


 272  bestSolutionIteration.Value = i;


 273  }


[13645]  274  bestQuality.Value = bestQ;


 275  curQuality.Value = curBestQ;


 276  avgQuality.Value = sumQ / n;


 277  sumQ = 0.0;


 278  curBestQ = 0.0;


 279 


[13651]  280  funcEvals.Value = state.FuncEvaluations;


 281  gradEvals.Value = state.GradEvaluations;


 282  effRollouts.Value = state.EffectiveRollouts;


 283  totalRollouts.Value = state.TotalRollouts;


 284 


[13645]  285  table.Rows["Best quality"].Values.Add(bestQuality.Value);


 286  table.Rows["Current best quality"].Values.Add(curQuality.Value);


 287  table.Rows["Average quality"].Values.Add(avgQuality.Value);


 288  iterations.Value += n;


 289  n = 0;


 290  }


 291  }


 292 


 293  // final results


 294  if (n > 0) {


[13669]  295  if (bestQ > bestQuality.Value) {


 296  bestSolutionIteration.Value = iterations.Value + n;


 297  }


[13645]  298  bestQuality.Value = bestQ;


 299  curQuality.Value = curBestQ;


 300  avgQuality.Value = sumQ / n;


 301 


[13651]  302  funcEvals.Value = state.FuncEvaluations;


 303  gradEvals.Value = state.GradEvaluations;


 304  effRollouts.Value = state.EffectiveRollouts;


 305  totalRollouts.Value = state.TotalRollouts;


 306 


[13645]  307  table.Rows["Best quality"].Values.Add(bestQuality.Value);


 308  table.Rows["Current best quality"].Values.Add(curQuality.Value);


 309  table.Rows["Average quality"].Values.Add(avgQuality.Value);


 310  iterations.Value = iterations.Value + n;


[13651]  311 


[13645]  312  }


 313 


 314 


 315  Results.Add(new Result("Best solution quality (train)", new DoubleValue(state.BestSolutionTrainingQuality)));


 316  Results.Add(new Result("Best solution quality (test)", new DoubleValue(state.BestSolutionTestQuality)));


 317 


[13651]  318 


[13645]  319  // produce solution


 320  if (CreateSolution) {


 321  var model = state.BestModel;


 322 


 323  // otherwise we produce a regression solution


 324  Results.Add(new Result("Solution", model.CreateRegressionSolution(problemData)));


 325  }


 326  }


 327  }


 328  }

