Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/04/17 22:00:52 (7 years ago)
Author:
gkronber
Message:

#2796 worked on MCTS symb reg

Location:
branches/MCTS-SymbReg-2796
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/MCTS-SymbReg-2796

    • Property svn:ignore set to
      TestResults
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs

    r15360 r15403  
    2222using System;
    2323using System.Linq;
    24 using System.Runtime.CompilerServices;
    2524using System.Threading;
    2625using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
     
    3635
    3736namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    38   [Item("MCTS Symbolic Regression", "Monte carlo tree search for symbolic regression. Useful mainly as a base learner in gradient boosting.")]
     37  // TODO: support pause (persisting/cloning the state)
     38  [Item("MCTS Symbolic Regression", "Monte carlo tree search for symbolic regression.")]
    3939  [StorableClass]
    4040  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
     
    5353    private const string CreateSolutionParameterName = "CreateSolution";
    5454    private const string PunishmentFactorParameterName = "PunishmentFactor";
    55 
    56     private const string VariableProductFactorName = "product(xi)";
    57     private const string ExpFactorName = "exp(c * product(xi))";
    58     private const string LogFactorName = "log(c + sum(c*product(xi))";
    59     private const string InvFactorName = "1 / (1 + sum(c*product(xi))";
    60     private const string FactorSumsName = "sum of multiple terms";
     55    private const string CollectParetoOptimalSolutionsParameterName = "CollectParetoOptimalSolutions";
     56    private const string LambdaParameterName = "Lambda";
     57
     58    private const string VariableProductFactorName = "x * y * ...";
     59    private const string ExpFactorName = "exp(c * x * y ...)";
     60    private const string LogFactorName = "log(c + c1 x + c2 x + ...)";
     61    private const string InvFactorName = "1 / (1 + c1 x + c2 x + ...)";
     62    private const string FactorSumsName = "t1(x) + t2(x) + ... ";
    6163    #endregion
    6264
     
    9496    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    9597      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     98    }
     99    public IFixedValueParameter<BoolValue> CollectParetoOptimalSolutionsParameter {
     100      get { return (IFixedValueParameter<BoolValue>)Parameters[CollectParetoOptimalSolutionsParameterName]; }
     101    }
     102    public IFixedValueParameter<DoubleValue> LambdaParameter {
     103      get { return (IFixedValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    96104    }
    97105    #endregion
     
    136144      get { return CreateSolutionParameter.Value.Value; }
    137145      set { CreateSolutionParameter.Value.Value = value; }
     146    }
     147    public bool CollectParetoOptimalSolutions {
     148      get { return CollectParetoOptimalSolutionsParameter.Value.Value; }
     149      set { CollectParetoOptimalSolutionsParameter.Value.Value = value; }
     150    }
     151    public double Lambda {
     152      get { return LambdaParameter.Value.Value; }
     153      set { LambdaParameter.Value.Value = value; }
    138154    }
    139155    #endregion
     
    177193      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName,
    178194        "Number of iterations for constant optimization. A small number of iterations should be sufficient for most models. " +
    179         "Set to 0 to disable constants optimization.", new IntValue(10)));
     195        "Set to 0 to let the algorithm stop automatically when it converges. Set to -1 to disable constants optimization.", new IntValue(10)));
    180196      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleVariablesParameterName,
    181         "Set to true to scale all input variables to the range [0..1]", new BoolValue(false)));
     197        "Set to true to all input variables to the range [0..1]", new BoolValue(true)));
    182198      Parameters[ScaleVariablesParameterName].Hidden = true;
    183199      Parameters.Add(new FixedValueParameter<DoubleValue>(PunishmentFactorParameterName, "Estimations of models can be bounded. The estimation limits are calculated in the following way (lb = mean(y) - punishmentFactor*range(y), ub = mean(y) + punishmentFactor*range(y))", new DoubleValue(10)));
     
    187203      Parameters[UpdateIntervalParameterName].Hidden = true;
    188204      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
    189         "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     205        "Optionally produce a solution at the end of the run", new BoolValue(true)));
    190206      Parameters[CreateSolutionParameterName].Hidden = true;
     207
     208      Parameters.Add(new FixedValueParameter<BoolValue>(CollectParetoOptimalSolutionsParameterName,
     209        "Optionally collect a set of Pareto-optimal solutions minimizing error and complexity.", new BoolValue(false)));
     210      Parameters[CollectParetoOptimalSolutionsParameterName].Hidden = true;
     211
     212      Parameters.Add(new FixedValueParameter<DoubleValue>(LambdaParameterName,
     213        "Lambda is the factor for the regularization term in the objective function (Obj = (y - f(x,p))² + lambda * |p|²)", new DoubleValue(0.0)));
    191214    }
    192215
     
    195218    }
    196219
     220    // TODO: support pause and restart
    197221    protected override void Run(CancellationToken cancellationToken) {
    198222      // Set up the algorithm
    199223      if (SetSeedRandomly) Seed = new System.Random().Next();
     224      var collectPareto = CollectParetoOptimalSolutions;
    200225
    201226      // Set up the results display
     
    229254      var gradEvals = new IntValue();
    230255      Results.Add(new Result("Gradient evaluations", gradEvals));
    231       var paretoBestModelsResult = new Result("ParetoBestModels", typeof(ItemList<ISymbolicRegressionSolution>));
    232       Results.Add(paretoBestModelsResult);
    233 
     256
     257      Result paretoBestModelsResult = new Result("ParetoBestModels", typeof(ItemList<ISymbolicRegressionSolution>));
     258      if (collectPareto) {
     259        Results.Add(paretoBestModelsResult);
     260      }
    234261
    235262      // same as in SymbolicRegressionSingleObjectiveProblem
     
    246273      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    247274      if (!AllowedFactors.CheckedItems.Any()) throw new ArgumentException("At least on type of factor must be allowed");
    248       var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables, ConstantOptimizationIterations,
    249         Policy,
     275      var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables,
     276        ConstantOptimizationIterations, Lambda,
     277        Policy, collectPareto,
    250278        lowerLimit, upperLimit,
    251279        allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName),
     
    261289      double curBestQ = 0.0;
    262290      int n = 0;
     291
     292      // cancelled before we acutally started
     293      cancellationToken.ThrowIfCancellationRequested();
     294
    263295      // Loop until iteration limit reached or canceled.
    264       for (int i = 0; i < Iterations && !state.Done; i++) {
    265         cancellationToken.ThrowIfCancellationRequested();
    266 
     296      for (int i = 0; i < Iterations && !state.Done && !cancellationToken.IsCancellationRequested; i++) {
    267297        var q = MctsSymbolicRegressionStatic.MakeStep(state);
    268298        sumQ += q; // sum of qs in the last updateinterval iterations
     
    286316          totalRollouts.Value = state.TotalRollouts;
    287317
    288           paretoBestModelsResult.Value = new ItemList<ISymbolicRegressionSolution>(state.ParetoBestModels);
     318          if (collectPareto) {
     319            paretoBestModelsResult.Value = new ItemList<ISymbolicRegressionSolution>(state.ParetoBestModels);
     320          }
    289321
    290322          table.Rows["Best quality"].Values.Add(bestQuality.Value);
     
    296328      }
    297329
    298       // final results
     330      // final results (assumes that at least one iteration was calculated)
    299331      if (n > 0) {
    300332        if (bestQ > bestQuality.Value) {
Note: See TracChangeset for help on using the changeset viewer.