
Timestamp:
05/16/21 19:13:10
Author:
gkronber
Message:

#3106 add parameters

File:
1 edited

Legend:

  (no marker) unmodified
  + added
  - removed
  • branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/Algorithm.cs

    r17972 → r17983

using HeuristicLab.Core;
using HeuristicLab.Data;
+using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
+using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
+using HeuristicLab.Problems.DataAnalysis.Symbolic;
+using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;
     
  [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
  public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
+    private const string MutationRateParameterName = "MutationRate";
+    private const string DepthParameterName = "Depth";
+    private const string NumGenerationsParameterName = "NumGenerations";
+
+    #region parameters
+    public IFixedValueParameter<PercentValue> MutationRateParameter => (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName];
+    public double MutationRate {
+      get { return MutationRateParameter.Value.Value; }
+      set { MutationRateParameter.Value.Value = value; }
+    }
+    public IFixedValueParameter<IntValue> DepthParameter => (IFixedValueParameter<IntValue>)Parameters[DepthParameterName];
+    public int Depth {
+      get { return DepthParameter.Value.Value; }
+      set { DepthParameter.Value.Value = value; }
+    }
+    public IFixedValueParameter<IntValue> NumGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName];
+    public int NumGenerations {
+      get { return NumGenerationsParameter.Value.Value; }
+      set { NumGenerationsParameter.Value.Value = value; }
+    }
+    #endregion
+
+    // storable ctor
+    [StorableConstructor]
+    public Algorithm(StorableConstructorFlag _) : base(_) { }
+
+    // cloning ctor
+    public Algorithm(Algorithm original, Cloner cloner) : base(original, cloner) { }
+
+
+    // default ctor
+    public Algorithm() : base() {
+      Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1)));
+      Parameters.Add(new FixedValueParameter<IntValue>(DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6)));
+      Parameters.Add(new FixedValueParameter<IntValue>(NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200)));
+    }
+
    public override IDeepCloneable Clone(Cloner cloner) {
      throw new NotImplementedException();
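The new parameters are registered in the default constructor and exposed through typed convenience properties. A minimal configuration sketch (hypothetical client code, not part of this changeset; it assumes the usual HeuristicLab workflow of assigning a problem and starting the algorithm separately):

    // hypothetical usage of the new parameters
    var alg = new Algorithm();
    alg.MutationRate = 0.2;    // stored as PercentValue, i.e. 20 %
    alg.Depth = 8;             // depth of the continued fraction representation
    alg.NumGenerations = 500;  // generation limit for the run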
     
        problemData.TrainingIndices);
      var nVars = x.GetLength(1) - 1;
-      var rand = new MersenneTwister(31415);
-      CFRAlgorithm(nVars, depth: 6, 0.10, x, out var best, out var bestObj, rand, numGen: 200, stagnatingGens: 5, cancellationToken);
+      var seed = new System.Random().Next();
+      var rand = new MersenneTwister((uint)seed);
+      CFRAlgorithm(nVars, Depth, MutationRate, x, out var best, out var bestObj, rand, NumGenerations, stagnatingGens: 5, cancellationToken);
    }
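The fixed seed 31415 is replaced by a seed drawn from System.Random, so repeated runs now use different random sequences. If a reproducible run is needed for debugging, the seed could be fixed again along these lines (a sketch only, not part of the changeset; useFixedSeed is a hypothetical flag):

    var seed = useFixedSeed ? 31415 : new System.Random().Next();
    var rand = new MersenneTwister((uint)seed);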
     
        /* local search optimization of current solutions */
        foreach (var agent in pop_r.IterateLevels()) {
-          LocalSearchSimplex(agent.current, ref agent.currentObjValue, trainingData, rand);
-        }
-
-        foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // Deviates from Alg1 in paper
+          LocalSearchSimplex(agent.current, ref agent.currentObjValue, trainingData, rand); // CHECK paper states that pocket might also be optimized. Unclear how / when invariants are maintained.
+        }
+
+        foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // CHECK deviates from Alg1 in paper

        /* replace old population with evolved population */
     
        /* keep track of the best solution */
-        if (bestObj > pop.pocketObjValue) {
+        if (bestObj > pop.pocketObjValue) { // CHECK: comparison obviously wrong in the paper
          best = pop.pocket;
          bestObj = pop.pocketObjValue;
          bestObjGen = gen;
-          Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
+          // Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
+          // Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(best, Problem.ProblemData));
        }


        if (gen > bestObjGen + stagnatingGens) {
-          bestObjGen = gen; // wait at least stagnatingGens until resetting again
-          // Reset(pop, nVars, depth, rand, trainingData);
-          InitialPopulation(nVars, depth, rand, trainingData);
-        }
-      }
-    }
+          bestObjGen = gen; // CHECK: unspecified in the paper: wait at least stagnatingGens until resetting again
+          Reset(pop, nVars, depth, rand, trainingData);
+          // InitialPopulation(nVars, depth, rand, trainingData); CHECK reset is not specified in the paper
+        }
+      }
+    }
+
+

    private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
     
    private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars) {
      var l = pop;
+
      if (pop.children.Count > 0) {
        var s1 = pop.children[0];
        var s2 = pop.children[1];
        var s3 = pop.children[2];
-        l.current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars);
-        s3.current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars);
-        s1.current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars);
-        s2.current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars);
-      }
-
-      foreach (var child in pop.children) {
-        RecombinePopulation(child, rand, nVars);
+
+        // CHECK Deviates from paper (recombine all models in the current pop before updating the population)
+        var l_current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars);
+        var s3_current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars);
+        var s1_current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars);
+        var s2_current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars);
+
+        // recombination works from top to bottom
+        // CHECK do we use the new current solutions (s1_current .. s3_current) already in the next levels?
+        foreach (var child in pop.children) {
+          RecombinePopulation(child, rand, nVars);
+        }
+
+        l.current = l_current;
+        s3.current = s3_current;
+        s1.current = s1_current;
+        s2.current = s2_current;
      }
      return pop;
     
    private static ContinuedFraction Recombine(ContinuedFraction p1, ContinuedFraction p2, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars) {
      ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
-      /* apply a recombination operator chosen uniformly at random on variable sof two parents into offspring */
+      /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */
      ch.vars = op(p1.vars, p2.vars);
     
        /* recombine coefficient values for variables */
        var coefx = new double[nVars];
-        var varsx = new bool[nVars]; // TODO: deviates from paper -> check
+        var varsx = new bool[nVars]; // CHECK: deviates from paper, probably forgotten in the pseudo-code
        for (int vi = 1; vi < nVars; vi++) {
-          if (ch.vars[vi]) {
+          if (ch.vars[vi]) {  // CHECK: paper uses featAt()
            if (varsa[vi] && varsb[vi]) {
              coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
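For reference: rand.NextDouble() lies in [0, 1), so the factor (rand.NextDouble() * 5 - 1) / 3 lies in [-1/3, 4/3). The recombined coefficient is therefore coefa[vi] + t * (coefb[vi] - coefa[vi]) with t in [-1/3, 4/3), i.e. a blend of the two parent coefficients that may extrapolate somewhat beyond either parent.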
     
      }
      /* update current solution and apply local search */
-      // return LocalSearchSimplex(ch, trainingData); // Deviates from paper because Alg1 also has LocalSearch after Recombination
+      // return LocalSearchSimplex(ch, trainingData); // CHECK: Deviates from paper because Alg1 also has LocalSearch after Recombination
      return ch;
    }
     
        /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or
         * 'Remember' the coefficient value at random */
-        if (cfrac.vars[vIdx]) {
-          h.vars[vIdx] = false;
+        if (cfrac.vars[vIdx]) {  // CHECK: paper uses varAt()
+          h.vars[vIdx] = false;  // CHECK: paper uses varAt()
          h.coef[vIdx] = coinToss(0, h.coef[vIdx]);
        } else {
     
           * or 'Replace' the coefficient with a random value between -3 and 3 at random */
          if (!h.vars[vIdx]) {
-            h.vars[vIdx] = true;
+            h.vars[vIdx] = true;  // CHECK: paper uses varAt()
            h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3);
          }
     
      }
      /* toggle the randomly selected variable */
-      cfrac.vars[vIdx] = !cfrac.vars[vIdx];
+      cfrac.vars[vIdx] = !cfrac.vars[vIdx];  // CHECK: paper uses varAt()
    }

    private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
      /* randomly select a variable which is turned ON */
-      var candVars = cfrac.vars.Count(vi => vi);
-      if (candVars == 0) return; // no variable active
-      var vIdx = rand.Next(candVars);
+      var candVars = new List<int>();
+      for (int i = 0; i < cfrac.vars.Length; i++) if (cfrac.vars[i]) candVars.Add(i);  // CHECK: paper uses varAt()
+      if (candVars.Count == 0) return; // no variable active
+      var vIdx = candVars[rand.Next(candVars.Count)];

      /* randomly select a term (h) of continued fraction */
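The old selection could pick a disabled variable: with cfrac.vars = { true, false, true }, for example, the old code computed candVars = 2 and drew vIdx from {0, 1}, so the inactive variable at index 1 could be selected while the active variable at index 2 never was. The new code collects the indices of the active variables ({0, 2} in this example) and draws uniformly among those.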
     
      /* modify the coefficient value*/
-      if (h.vars[vIdx]) {
+      if (h.vars[vIdx]) {  // CHECK: paper uses varAt()
        h.coef[vIdx] = 0.0;
      } else {
     
      }
      /* Toggle the randomly selected variable */
-      h.vars[vIdx] = !h.vars[vIdx];
+      h.vars[vIdx] = !h.vars[vIdx]; // CHECK: paper uses varAt()
    }
     
        sum += res * res;
      }
-      var delta = 0.1; // TODO
+      var delta = 0.1;
      return sum / trainingData.GetLength(0) * (1 + delta * cfrac.vars.Count(vi => vi));
    }
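With delta = 0.1 the value returned here is the training MSE multiplied by (1 + 0.1 * number of active variables), i.e. every variable that is switched on in the continued fraction adds a ten percent multiplicative penalty to the error.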
     
        res = numerator / denom;
      }
+      var h0 = cfrac.h[0];
+      res += h0.beta + dot(h0.vars, h0.coef, dataPoint);
      return res;
    }
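The two added lines contribute the constant (h0) term, which was previously missing from the evaluation. For orientation, a minimal reconstruction of how the full bottom-up evaluation plausibly looks with this fix in place (illustration only; the surrounding loop is not shown in this hunk, and dot() is the helper already used in the added line):

    double res = 0.0;
    for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
      var hi = cfrac.h[i];       // denominator term of this level
      var hi1 = cfrac.h[i - 1];  // numerator term of this level
      var denom = hi.beta + dot(hi.vars, hi.coef, dataPoint) + res;
      res = (hi1.beta + dot(hi1.vars, hi1.coef, dataPoint)) / denom;
    }
    var h0 = cfrac.h[0];
    res += h0.beta + dot(h0.vars, h0.coef, dataPoint); // constant term added by this changeset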
     
        var newQuality = Evaluate(ch, trainingData);
-
-        // TODO: optionally use regularization (ridge / LASSO)

        if (newQuality < bestQuality) {
     
      }
    }
+
+    Symbol addSy = new Addition();
+    Symbol mulSy = new Multiplication();
+    Symbol divSy = new Division();
+    Symbol startSy = new StartSymbol();
+    Symbol progSy = new ProgramRootSymbol();
+    Symbol constSy = new Constant();
+    Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable();
+
+    private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, IRegressionProblemData problemData) {
+      var variables = problemData.AllowedInputVariables.ToArray();
+      ISymbolicExpressionTreeNode res = null;
+      for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
+        var hi = cfrac.h[i];
+        var hi1 = cfrac.h[i - 1];
+        var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta);
+        if (res != null) {
+          denom.AddSubtree(res);
+        }
+
+        var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta);
+
+        res = divSy.CreateTreeNode();
+        res.AddSubtree(numerator);
+        res.AddSubtree(denom);
+      }
+
+      var h0 = cfrac.h[0];
+      var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta);
+      h0Term.AddSubtree(res);
+
+      var progRoot = progSy.CreateTreeNode();
+      var start = startSy.CreateTreeNode();
+      progRoot.AddSubtree(start);
+      start.AddSubtree(h0Term);
+
+      var model = new SymbolicRegressionModel(problemData.TargetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter());
+      var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
+      return sol;
+    }
+
+    private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) {
+      var sum = addSy.CreateTreeNode();
+      for (int i = 0; i < vars.Length; i++) {
+        if (vars[i]) {
+          var varNode = (VariableTreeNode)varSy.CreateTreeNode();
+          varNode.Weight = coef[i];
+          varNode.VariableName = variables[i];
+          sum.AddSubtree(varNode);
+        }
+      }
+      sum.AddSubtree(CreateConstant(beta));
+      return sum;
+    }
+
+    private ISymbolicExpressionTreeNode CreateConstant(double value) {
+      var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
+      constNode.Value = value;
+      return constNode;
+    }
  }
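CreateSymbolicRegressionSolution is not called anywhere in this changeset yet; its intended use appears only as a commented-out line in the result-tracking block above:

    Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(best, Problem.ProblemData));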