- Timestamp:
- 05/16/21 19:13:10 (4 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/Algorithm.cs
r17972 r17983 7 7 using HeuristicLab.Core; 8 8 using HeuristicLab.Data; 9 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 10 using HeuristicLab.Parameters; 9 11 using HeuristicLab.Problems.DataAnalysis; 12 using HeuristicLab.Problems.DataAnalysis.Symbolic; 13 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression; 10 14 using HeuristicLab.Random; 11 15 … … 15 19 [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")] 16 20 public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> { 21 private const string MutationRateParameterName = "MutationRate"; 22 private const string DepthParameterName = "Depth"; 23 private const string NumGenerationsParameterName = "Depth"; 24 25 #region parameters 26 public IFixedValueParameter<PercentValue> MutationRateParameter => (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName]; 27 public double MutationRate { 28 get { return MutationRateParameter.Value.Value; } 29 set { MutationRateParameter.Value.Value = value; } 30 } 31 public IFixedValueParameter<IntValue> DepthParameter => (IFixedValueParameter<IntValue>)Parameters[DepthParameterName]; 32 public int Depth { 33 get { return DepthParameter.Value.Value; } 34 set { DepthParameter.Value.Value = value; } 35 } 36 public IFixedValueParameter<IntValue> NumGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName]; 37 public int NumGenerations { 38 get { return NumGenerationsParameter.Value.Value; } 39 set { NumGenerationsParameter.Value.Value = value; } 40 } 41 #endregion 42 43 // storable ctor 44 [StorableConstructor] 45 public Algorithm(StorableConstructorFlag _) : base(_) { } 46 47 // cloning ctor 48 public Algorithm(Algorithm original, Cloner cloner) : base(original, cloner) { } 49 50 51 // default ctor 52 public Algorithm() : base() { 53 Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1))); 54 Parameters.Add(new FixedValueParameter<IntValue>(DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6))); 55 Parameters.Add(new FixedValueParameter<IntValue>(NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200))); 56 } 57 17 58 public override IDeepCloneable Clone(Cloner cloner) { 18 59 throw new NotImplementedException(); … … 25 66 problemData.TrainingIndices); 26 67 var nVars = x.GetLength(1) - 1; 27 var rand = new MersenneTwister(31415); 28 CFRAlgorithm(nVars, depth: 6, 0.10, x, out var best, out var bestObj, rand, numGen: 200, stagnatingGens: 5, cancellationToken); 68 var seed = new System.Random().Next(); 69 var rand = new MersenneTwister((uint)seed); 70 CFRAlgorithm(nVars, Depth, MutationRate, x, out var best, out var bestObj, rand, NumGenerations, stagnatingGens: 5, cancellationToken); 29 71 } 30 72 … … 47 89 /* local search optimization of current solutions */ 48 90 foreach (var agent in pop_r.IterateLevels()) { 49 LocalSearchSimplex(agent.current, ref agent.currentObjValue, trainingData, rand); 50 } 51 52 foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // Deviates from Alg1 in paper91 LocalSearchSimplex(agent.current, ref agent.currentObjValue, trainingData, rand); // CHECK paper states that pocket might also be optimized. Unclear how / when invariants are maintained. 92 } 93 94 foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // CHECK deviates from Alg1 in paper 53 95 54 96 /* replace old population with evolved population */ … … 56 98 57 99 /* keep track of the best solution */ 58 if (bestObj > pop.pocketObjValue) { 100 if (bestObj > pop.pocketObjValue) { // CHECK: comparison obviously wrong in the paper 59 101 best = pop.pocket; 60 102 bestObj = pop.pocketObjValue; 61 103 bestObjGen = gen; 62 Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj)); 104 // Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj)); 105 // Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(best, Problem.ProblemData)); 63 106 } 64 107 65 108 66 109 if (gen > bestObjGen + stagnatingGens) { 67 bestObjGen = gen; // wait at least stagnatingGens until resetting again 68 // Reset(pop, nVars, depth, rand, trainingData); 69 InitialPopulation(nVars, depth, rand, trainingData); 70 } 71 } 72 } 110 bestObjGen = gen; // CHECK: unspecified in the paper: wait at least stagnatingGens until resetting again 111 Reset(pop, nVars, depth, rand, trainingData); 112 // InitialPopulation(nVars, depth, rand, trainingData); CHECK reset is not specified in the paper 113 } 114 } 115 } 116 117 73 118 74 119 private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) { … … 116 161 private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars) { 117 162 var l = pop; 163 118 164 if (pop.children.Count > 0) { 119 165 var s1 = pop.children[0]; 120 166 var s2 = pop.children[1]; 121 167 var s3 = pop.children[2]; 122 l.current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars); 123 s3.current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars); 124 s1.current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars); 125 s2.current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars); 126 } 127 128 foreach (var child in pop.children) { 129 RecombinePopulation(child, rand, nVars); 168 169 // CHECK Deviates from paper (recombine all models in the current pop before updating the population) 170 var l_current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars); 171 var s3_current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars); 172 var s1_current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars); 173 var s2_current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars); 174 175 // recombination works from top to bottom 176 // CHECK do we use the new current solutions (s1_current .. s3_current) already in the next levels? 177 foreach (var child in pop.children) { 178 RecombinePopulation(child, rand, nVars); 179 } 180 181 l.current = l_current; 182 s3.current = s3_current; 183 s1.current = s1_current; 184 s2.current = s2_current; 130 185 } 131 186 return pop; … … 158 213 private static ContinuedFraction Recombine(ContinuedFraction p1, ContinuedFraction p2, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars) { 159 214 ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] }; 160 /* apply a recombination operator chosen uniformly at random on variable sof two parents into offspring */215 /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */ 161 216 ch.vars = op(p1.vars, p2.vars); 162 217 … … 168 223 /* recombine coefficient values for variables */ 169 224 var coefx = new double[nVars]; 170 var varsx = new bool[nVars]; // TODO: deviates from paper -> check225 var varsx = new bool[nVars]; // CHECK: deviates from paper, probably forgotten in the pseudo-code 171 226 for (int vi = 1; vi < nVars; vi++) { 172 if (ch.vars[vi]) { 227 if (ch.vars[vi]) { // CHECK: paper uses featAt() 173 228 if (varsa[vi] && varsb[vi]) { 174 229 coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0; … … 190 245 } 191 246 /* update current solution and apply local search */ 192 // return LocalSearchSimplex(ch, trainingData); // Deviates from paper because Alg1 also has LocalSearch after Recombination247 // return LocalSearchSimplex(ch, trainingData); // CHECK: Deviates from paper because Alg1 also has LocalSearch after Recombination 193 248 return ch; 194 249 } … … 220 275 /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or 221 276 * 'Remember' the coefficient value at random */ 222 if (cfrac.vars[vIdx]) { 223 h.vars[vIdx] = false; 277 if (cfrac.vars[vIdx]) { // CHECK: paper uses varAt() 278 h.vars[vIdx] = false; // CHECK: paper uses varAt() 224 279 h.coef[vIdx] = coinToss(0, h.coef[vIdx]); 225 280 } else { … … 227 282 * or 'Replace' the coefficient with a random value between -3 and 3 at random */ 228 283 if (!h.vars[vIdx]) { 229 h.vars[vIdx] = true; 284 h.vars[vIdx] = true; // CHECK: paper uses varAt() 230 285 h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3); 231 286 } … … 233 288 } 234 289 /* toggle the randomly selected variable */ 235 cfrac.vars[vIdx] = !cfrac.vars[vIdx]; 290 cfrac.vars[vIdx] = !cfrac.vars[vIdx]; // CHECK: paper uses varAt() 236 291 } 237 292 238 293 private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) { 239 294 /* randomly select a variable which is turned ON */ 240 var candVars = cfrac.vars.Count(vi => vi); 241 if (candVars == 0) return; // no variable active 242 var vIdx = rand.Next(candVars); 295 var candVars = new List<int>(); 296 for (int i = 0; i < cfrac.vars.Length; i++) if (cfrac.vars[i]) candVars.Add(i); // CHECK: paper uses varAt() 297 if (candVars.Count == 0) return; // no variable active 298 var vIdx = candVars[rand.Next(candVars.Count)]; 243 299 244 300 /* randomly select a term (h) of continued fraction */ … … 246 302 247 303 /* modify the coefficient value*/ 248 if (h.vars[vIdx]) { 304 if (h.vars[vIdx]) { // CHECK: paper uses varAt() 249 305 h.coef[vIdx] = 0.0; 250 306 } else { … … 252 308 } 253 309 /* Toggle the randomly selected variable */ 254 h.vars[vIdx] = !h.vars[vIdx]; 310 h.vars[vIdx] = !h.vars[vIdx]; // CHECK: paper uses varAt() 255 311 } 256 312 … … 268 324 sum += res * res; 269 325 } 270 var delta = 0.1; // TODO326 var delta = 0.1; 271 327 return sum / trainingData.GetLength(0) * (1 + delta * cfrac.vars.Count(vi => vi)); 272 328 } … … 281 337 res = numerator / denom; 282 338 } 339 var h0 = cfrac.h[0]; 340 res += h0.beta + dot(h0.vars, h0.coef, dataPoint); 283 341 return res; 284 342 } … … 329 387 330 388 var newQuality = Evaluate(ch, trainingData); 331 332 // TODO: optionally use regularization (ridge / LASSO)333 389 334 390 if (newQuality < bestQuality) { … … 377 433 } 378 434 } 435 436 Symbol addSy = new Addition(); 437 Symbol mulSy = new Multiplication(); 438 Symbol divSy = new Division(); 439 Symbol startSy = new StartSymbol(); 440 Symbol progSy = new ProgramRootSymbol(); 441 Symbol constSy = new Constant(); 442 Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable(); 443 444 private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, IRegressionProblemData problemData) { 445 var variables = problemData.AllowedInputVariables.ToArray(); 446 ISymbolicExpressionTreeNode res = null; 447 for (int i = cfrac.h.Length - 1; i > 1; i -= 2) { 448 var hi = cfrac.h[i]; 449 var hi1 = cfrac.h[i - 1]; 450 var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta); 451 if (res != null) { 452 denom.AddSubtree(res); 453 } 454 455 var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta); 456 457 res = divSy.CreateTreeNode(); 458 res.AddSubtree(numerator); 459 res.AddSubtree(denom); 460 } 461 462 var h0 = cfrac.h[0]; 463 var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta); 464 h0Term.AddSubtree(res); 465 466 var progRoot = progSy.CreateTreeNode(); 467 var start = startSy.CreateTreeNode(); 468 progRoot.AddSubtree(start); 469 start.AddSubtree(h0Term); 470 471 var model = new SymbolicRegressionModel(problemData.TargetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter()); 472 var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone()); 473 return sol; 474 } 475 476 private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) { 477 var sum = addSy.CreateTreeNode(); 478 for (int i = 0; i < vars.Length; i++) { 479 if (vars[i]) { 480 var varNode = (VariableTreeNode)varSy.CreateTreeNode(); 481 varNode.Weight = coef[i]; 482 varNode.VariableName = variables[i]; 483 sum.AddSubtree(varNode); 484 } 485 } 486 sum.AddSubtree(CreateConstant(beta)); 487 return sum; 488 } 489 490 private ISymbolicExpressionTreeNode CreateConstant(double value) { 491 var constNode = (ConstantTreeNode)constSy.CreateTreeNode(); 492 constNode.Value = value; 493 return constNode; 494 } 379 495 } 380 496
Note: See TracChangeset
for help on using the changeset viewer.