[17971] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using System.Threading;
|
---|
| 5 | using HEAL.Attic;
|
---|
| 6 | using HeuristicLab.Common;
|
---|
| 7 | using HeuristicLab.Core;
|
---|
| 8 | using HeuristicLab.Data;
|
---|
[17983] | 9 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 10 | using HeuristicLab.Parameters;
|
---|
[17971] | 11 | using HeuristicLab.Problems.DataAnalysis;
|
---|
[17983] | 12 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 13 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
[17971] | 14 | using HeuristicLab.Random;
|
---|
| 15 |
|
---|
| 16 | namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
|
---|
| 17 | [Item("Continuous Fraction Regression (CFR)", "TODO")]
|
---|
| 18 | [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
|
---|
| 19 | [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
|
---|
| 20 | public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
|
---|
// Keys under which the parameters are registered in (and looked up from) the Parameters collection.
// Each key must be unique; the typed accessor properties resolve the parameters by these names.
private const string MutationRateParameterName = "MutationRate";
private const string DepthParameterName = "Depth";
// Previously "Depth", which collided with DepthParameterName, so both accessors resolved the same key.
private const string NumGenerationsParameterName = "NumGenerations";
| 24 |
|
---|
#region parameters
/// <summary>Parameter object registered under <c>MutationRateParameterName</c>.</summary>
public IFixedValueParameter<PercentValue> MutationRateParameter {
  get { return (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName]; }
}

/// <summary>Convenience accessor for the value stored in <see cref="MutationRateParameter"/>.</summary>
public double MutationRate {
  get => MutationRateParameter.Value.Value;
  set => MutationRateParameter.Value.Value = value;
}

/// <summary>Parameter object registered under <c>DepthParameterName</c>.</summary>
public IFixedValueParameter<IntValue> DepthParameter {
  get { return (IFixedValueParameter<IntValue>)Parameters[DepthParameterName]; }
}

/// <summary>Convenience accessor for the value stored in <see cref="DepthParameter"/>.</summary>
public int Depth {
  get => DepthParameter.Value.Value;
  set => DepthParameter.Value.Value = value;
}

/// <summary>Parameter object registered under <c>NumGenerationsParameterName</c>.</summary>
public IFixedValueParameter<IntValue> NumGenerationsParameter {
  get { return (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName]; }
}

/// <summary>Convenience accessor for the value stored in <see cref="NumGenerationsParameter"/>.</summary>
public int NumGenerations {
  get => NumGenerationsParameter.Value.Value;
  set => NumGenerationsParameter.Value.Value = value;
}
#endregion
| 42 |
|
---|
/// <summary>Constructor used by the persistence layer; forwards the flag to the base class.</summary>
[StorableConstructor]
public Algorithm(StorableConstructorFlag _)
  : base(_) {
}

/// <summary>Cloning constructor; forwards the original and the cloner to the base class.</summary>
public Algorithm(Algorithm original, Cloner cloner)
  : base(original, cloner) {
}

/// <summary>
/// Default constructor. Registers the mutation rate (10%), depth (6) and
/// number of generations (200) parameters, in that order.
/// </summary>
public Algorithm() {
  var mutationRateParam = new FixedValueParameter<PercentValue>(
    MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1));
  Parameters.Add(mutationRateParam);

  var depthParam = new FixedValueParameter<IntValue>(
    DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6));
  Parameters.Add(depthParam);

  var numGenerationsParam = new FixedValueParameter<IntValue>(
    NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200));
  Parameters.Add(numGenerationsParam);
}
| 57 |
|
---|
/// <summary>
/// Creates a copy of this algorithm by delegating to the cloning constructor
/// <c>Algorithm(Algorithm, Cloner)</c>.
/// </summary>
/// <param name="cloner">Cloner that is passed on to the cloning constructor.</param>
/// <returns>The newly created copy.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  // Previously threw NotImplementedException, which made every clone attempt fail.
  return new Algorithm(this, cloner);
}
| 61 |
|
---|
/// <summary>
/// Extracts the training rows of the problem as a matrix (allowed inputs first, target in the
/// last column), seeds a Mersenne Twister from <c>System.Random</c> and runs the CFR algorithm.
/// </summary>
/// <param name="cancellationToken">Token handed through to the generational loop.</param>
protected override void Run(CancellationToken cancellationToken) {
  var data = Problem.ProblemData;

  // the target variable is appended after the inputs, so it ends up in the last column
  var columnNames = data.AllowedInputVariables.Concat(new[] { data.TargetVariable });
  var trainingMatrix = data.Dataset.ToArray(columnNames, data.TrainingIndices);
  var numInputs = trainingMatrix.GetLength(1) - 1;

  var randomSeed = new System.Random().Next();
  var twister = new MersenneTwister((uint)randomSeed);

  // NOTE(review): best / bestObj are currently not published as results.
  CFRAlgorithm(numInputs, Depth, MutationRate, trainingMatrix, out var best, out var bestObj, twister, NumGenerations, stagnatingGens: 5, cancellationToken);
}
| 72 |
|
---|
/// <summary>
/// Generational loop (Algorithm 1): mutate, recombine, locally optimize every current solution,
/// restore the pocket/current invariants and track the best root pocket seen so far.
/// </summary>
/// <param name="nVars">Number of input variables.</param>
/// <param name="depth">Depth passed to the continued fraction constructor.</param>
/// <param name="mutationRate">Per-agent probability of applying a mutation.</param>
/// <param name="trainingData">Training matrix; the last column is the target.</param>
/// <param name="best">Receives the best continued fraction found.</param>
/// <param name="bestObj">Receives the objective value of <paramref name="best"/>.</param>
/// <param name="rand">Random number generator.</param>
/// <param name="numGen">Maximum number of generations.</param>
/// <param name="stagnatingGens">Generations without improvement after which the root is reset.</param>
/// <param name="cancellationToken">Checked before each generation.</param>
private void CFRAlgorithm(int nVars, int depth, double mutationRate, double[,] trainingData,
  out ContinuedFraction best, out double bestObj,
  IRandom rand, int numGen, int stagnatingGens,
  CancellationToken cancellationToken) {
  // generate the initial population by a randomized algorithm
  var population = InitialPopulation(nVars, depth, rand, trainingData);
  best = population.pocket;
  bestObj = population.pocketObjValue;
  var lastImprovementGen = 0;

  for (int gen = 1; gen <= numGen; gen++) {
    if (cancellationToken.IsCancellationRequested) {
      break;
    }

    // mutate each current solution, then build the new population by recombination
    var mutated = Mutate(population, mutationRate, rand);
    var recombined = RecombinePopulation(mutated, rand, nVars);

    // local search on current solutions only
    // CHECK: paper states that the pocket might also be optimized; unclear how / when invariants are maintained
    foreach (var member in recombined.IterateLevels()) {
      LocalSearchSimplex(member.current, ref member.currentObjValue, trainingData, rand);
    }

    // CHECK: deviates from Alg1 in the paper
    foreach (var member in recombined.IteratePostOrder()) {
      member.MaintainInvariant();
    }

    // the evolved population replaces the old one
    population = recombined;

    // smaller objective is better (CHECK: comparison is obviously wrong in the paper)
    if (bestObj > population.pocketObjValue) {
      best = population.pocket;
      bestObj = population.pocketObjValue;
      lastImprovementGen = gen;
      // Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
      // Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(best, Problem.ProblemData));
    }

    if (gen > lastImprovementGen + stagnatingGens) {
      // CHECK: unspecified in the paper: wait at least stagnatingGens until resetting again
      lastImprovementGen = gen;
      // CHECK: reset is not specified in the paper
      Reset(population, nVars, depth, rand, trainingData);
    }
  }
}
| 116 |
|
---|
[17983] | 117 |
|
---|
| 118 |
|
---|
/// <summary>
/// Builds the 13-agent tree (one root, three children, three grandchildren each; see Figure 2)
/// and gives every agent a freshly constructed, evaluated current and pocket solution.
/// </summary>
/// <returns>The root agent of the population tree.</returns>
private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
  const int branching = 3;

  var root = new Agent();
  for (int i = 0; i < branching; i++) {
    var subLeader = new Agent();
    for (int j = 0; j < branching; j++) {
      subLeader.children.Add(new Agent());
    }
    root.children.Add(subLeader);
  }

  foreach (var member in root.IteratePostOrder()) {
    // current is constructed before pocket to keep the order of random draws
    member.current = new ContinuedFraction(nVars, depth, rand);
    member.pocket = new ContinuedFraction(nVars, depth, rand);

    member.currentObjValue = Evaluate(member.current, trainingData);
    member.pocketObjValue = Evaluate(member.pocket, trainingData);

    // the pocket must hold the better (smaller) objective value of the two
    member.MaintainInvariant();
  }
  return root;
}
| 144 |
|
---|
// TODO: reset is not described in the paper
/// <summary>
/// Replaces pocket and current of <paramref name="root"/> with newly constructed continued
/// fractions, evaluates both and re-establishes the agent's invariants. Children are not replaced.
/// </summary>
private void Reset(Agent root, int nVars, int depth, IRandom rand, double[,] trainingData) {
  // pocket is constructed before current here; the order determines the random draws
  var freshPocket = new ContinuedFraction(nVars, depth, rand);
  var freshCurrent = new ContinuedFraction(nVars, depth, rand);
  root.pocket = freshPocket;
  root.current = freshCurrent;

  root.currentObjValue = Evaluate(freshCurrent, trainingData);
  root.pocketObjValue = Evaluate(freshPocket, trainingData);

  // the pocket must hold the better (smaller) objective value of the two
  root.MaintainInvariant();
}
| 158 |
|
---|
| 159 |
|
---|
| 160 |
|
---|
/// <summary>
/// Recombines the solutions of an agent and its first three children and then recurses into
/// every child. Agents without children are returned untouched.
/// </summary>
/// <returns>The same agent instance that was passed in.</returns>
private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars) {
  if (!(pop.children.Count > 0)) {
    return pop;
  }

  var leader = pop;
  var sub1 = pop.children[0];
  var sub2 = pop.children[1];
  var sub3 = pop.children[2];

  // CHECK: deviates from paper (recombine all models in the current pop before updating the population)
  // The order of these four calls fixes the sequence of random draws.
  var leaderOffspring = Recombine(leader.pocket, sub1.current, SelectRandomOp(rand), rand, nVars);
  var sub3Offspring = Recombine(sub3.pocket, leader.current, SelectRandomOp(rand), rand, nVars);
  var sub1Offspring = Recombine(sub1.pocket, sub2.current, SelectRandomOp(rand), rand, nVars);
  var sub2Offspring = Recombine(sub2.pocket, sub3.current, SelectRandomOp(rand), rand, nVars);

  // recombination works from top to bottom
  // CHECK: do we use the new current solutions already in the next levels?
  foreach (var child in pop.children) {
    RecombinePopulation(child, rand, nVars);
  }

  // assigned after the recursion, so these overwrite whatever the recursive calls stored
  leader.current = leaderOffspring;
  sub3.current = sub3Offspring;
  sub1.current = sub1Offspring;
  sub2.current = sub2Offspring;

  return pop;
}
| 188 |
|
---|
/// <summary>
/// Picks one of three element-wise set operations on boolean masks (union, intersection,
/// symmetric difference) using <c>rand.Next(3)</c>.
/// </summary>
/// <returns>A function producing a new mask with the length of its first argument.</returns>
private Func<bool[], bool[], bool[]> SelectRandomOp(IRandom rand) {
  Func<bool[], bool[], bool[]> union = (a, b) => {
    var mask = new bool[a.Length];
    for (int i = 0; i < mask.Length; i++) {
      mask[i] = a[i] || b[i];
    }
    return mask;
  };
  Func<bool[], bool[], bool[]> intersect = (a, b) => {
    var mask = new bool[a.Length];
    for (int i = 0; i < mask.Length; i++) {
      mask[i] = a[i] && b[i];
    }
    return mask;
  };
  Func<bool[], bool[], bool[]> symmetricDifference = (a, b) => {
    var mask = new bool[a.Length];
    for (int i = 0; i < mask.Length; i++) {
      mask[i] = a[i] ^ b[i];
    }
    return mask;
  };

  int choice = rand.Next(3);
  if (choice == 0) {
    return union;
  }
  if (choice == 1) {
    return intersect;
  }
  if (choice == 2) {
    return symmetricDifference;
  }
  throw new ArgumentException();
}
| 212 |
|
---|
/// <summary>
/// Creates an offspring from two parents. The offspring's variable mask is <paramref name="op"/>
/// applied to the parents' masks; for every term the coefficients and the constant (beta) are
/// taken from or blended between the corresponding parent terms.
/// </summary>
/// <param name="p1">First parent; determines the number of terms of the offspring.</param>
/// <param name="p2">Second parent; indexed with the same term indices as <paramref name="p1"/>.</param>
/// <param name="op">Set operation applied to the parents' variable masks.</param>
/// <param name="rand">Random number generator used for the blending factors.</param>
/// <param name="nVars">Number of variables; length of the offspring's per-term arrays.</param>
/// <returns>The offspring; no local search is applied here.</returns>
private static ContinuedFraction Recombine(ContinuedFraction p1, ContinuedFraction p2, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars) {
  ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
  /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */
  ch.vars = op(p1.vars, p2.vars);

  /* recombine the coefficients for each term (h) of the continued fraction */
  for (int i = 0; i < p1.h.Length; i++) {
    var coefa = p1.h[i].coef; var varsa = p1.h[i].vars;
    var coefb = p2.h[i].coef; var varsb = p2.h[i].vars;

    /* recombine coefficient values for variables */
    var coefx = new double[nVars];
    var varsx = new bool[nVars]; // CHECK: deviates from paper, probably forgotten in the pseudo-code
    // Arrays are 0-based: the loop previously started at 1, so variable 0 was never
    // inherited by any offspring term.
    for (int vi = 0; vi < nVars; vi++) {
      if (ch.vars[vi]) { // CHECK: paper uses featAt()
        if (varsa[vi] && varsb[vi]) {
          // blend with a factor drawn from [-1/3, 4/3)
          coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
          varsx[vi] = true;
        } else if (varsa[vi]) {
          coefx[vi] = coefa[vi];
          varsx[vi] = true;
        } else if (varsb[vi]) {
          coefx[vi] = coefb[vi];
          varsx[vi] = true;
        }
      }
    }
    /* update new coefficients of the term in offspring */
    ch.h[i] = new Term() { coef = coefx, vars = varsx };
    /* compute new value of constant (beta) for term hi in the offspring solution ch using
     * beta of p1.hi and p2.hi */
    ch.h[i].beta = p1.h[i].beta + (rand.NextDouble() * 5 - 1) * (p2.h[i].beta - p1.h[i].beta) / 3.0;
  }
  // CHECK: deviates from paper because Alg1 also has LocalSearch after Recombination;
  // here the local search is left to the caller.
  return ch;
}
| 250 |
|
---|
/// <summary>
/// Visits every agent and, with probability <paramref name="mutationRate"/>, mutates its current
/// solution in place: a major mutation when the current objective is below 1.2 times or above
/// 2 times the pocket objective, a soft mutation otherwise.
/// </summary>
/// <returns>The same agent instance that was passed in.</returns>
private Agent Mutate(Agent pop, double mutationRate, IRandom rand) {
  foreach (var member in pop.IterateLevels()) {
    if (!(rand.NextDouble() < mutationRate)) {
      continue;
    }

    bool closeToPocket = member.currentObjValue < 1.2 * member.pocketObjValue;
    bool farFromPocket = member.currentObjValue > 2 * member.pocketObjValue;
    if (closeToPocket || farFromPocket) {
      ToggleVariables(member.current, rand); // major mutation
    } else {
      ModifyVariable(member.current, rand); // soft mutation
    }
  }
  return pop;
}
| 263 |
|
---|
/// <summary>
/// Major mutation: picks one variable index at random, adjusts that variable in every term of
/// the continued fraction and finally flips the variable in the fraction-level mask.
/// </summary>
private void ToggleVariables(ContinuedFraction cfrac, IRandom rand) {
  /* select a variable index uniformly at random */
  int numVars = cfrac.vars.Length;
  int picked = rand.Next(numVars);

  /* for each depth of the continued fraction, toggle the selection of the variable in the term */
  foreach (var term in cfrac.h) {
    if (cfrac.vars[picked]) { // CHECK: paper uses varAt()
      /* variable is ON in the fraction: switch it OFF in the term and either
       * remove (0) or remember the coefficient, decided by a coin toss */
      term.vars[picked] = false; // CHECK: paper uses varAt()
      double remembered = term.coef[picked];
      term.coef[picked] = rand.NextDouble() < 0.5 ? 0 : remembered;
    } else if (!term.vars[picked]) {
      /* variable is OFF in the term: switch it ON and either remove (0) or replace the
       * coefficient with a random value between -3 and 3, decided by a coin toss */
      term.vars[picked] = true; // CHECK: paper uses varAt()
      // the replacement value is drawn before the coin toss
      double replacement = rand.NextDouble() * 6 - 3;
      term.coef[picked] = rand.NextDouble() < 0.5 ? 0 : replacement;
    }
  }

  /* toggle the randomly selected variable */
  cfrac.vars[picked] = !cfrac.vars[picked]; // CHECK: paper uses varAt()
}
| 292 |
|
---|
/// <summary>
/// Soft mutation: picks one variable that is ON in the fraction-level mask and one term at
/// random, then flips that variable in the term. Does nothing when no variable is ON.
/// </summary>
private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
  /* collect the variables which are turned ON */ // CHECK: paper uses varAt()
  var activeVars = Enumerable.Range(0, cfrac.vars.Length)
                             .Where(i => cfrac.vars[i])
                             .ToList();
  if (activeVars.Count == 0) {
    return; // no variable active
  }
  int picked = activeVars[rand.Next(activeVars.Count)];

  /* randomly select a term (h) of the continued fraction */
  var term = cfrac.h[rand.Next(cfrac.h.Length)];

  /* modify the coefficient value and toggle the variable in the term */
  bool wasOn = term.vars[picked]; // CHECK: paper uses varAt()
  if (wasOn) {
    term.coef[picked] = 0.0;
  } else {
    term.coef[picked] = rand.NextDouble() * 6 - 3;
  }
  term.vars[picked] = !wasOn;
}
| 312 |
|
---|
/// <summary>
/// Objective function: mean squared residual over all rows of <paramref name="trainingData"/>,
/// multiplied by (1 + 0.1 * number of variables that are ON in the fraction-level mask).
/// </summary>
/// <param name="cfrac">Continued fraction to evaluate.</param>
/// <param name="trainingData">Matrix whose last column is the target value.</param>
private static double Evaluate(ContinuedFraction cfrac, double[,] trainingData) {
  int rowCount = trainingData.GetLength(0);
  int targetCol = trainingData.GetLength(1) - 1;
  var inputs = new double[targetCol]; // reused buffer for the input values of one row

  double squaredErrorSum = 0.0;
  for (int row = 0; row < rowCount; row++) {
    for (int col = 0; col < inputs.Length; col++) {
      inputs[col] = trainingData[row, col];
    }
    var target = trainingData[row, targetCol];
    var residual = target - Evaluate(cfrac, inputs);
    squaredErrorSum += residual * residual;
  }

  var delta = 0.1;
  var meanSquaredError = squaredErrorSum / rowCount;
  var complexityFactor = 1 + delta * cfrac.vars.Count(isOn => isOn);
  return meanSquaredError * complexityFactor;
}
| 329 |
|
---|
/// <summary>
/// Evaluates the continued fraction for a single data point. Terms are consumed in pairs from
/// the last index down to index 2: <c>h[i-1]</c> is the numerator and <c>h[i]</c> plus the
/// value accumulated so far is the denominator. The value of <c>h[0]</c> is added at the end.
/// </summary>
private static double Evaluate(ContinuedFraction cfrac, double[] dataPoint) {
  double fraction = 0.0;
  int i = cfrac.h.Length - 1;
  while (i > 1) {
    var denomTerm = cfrac.h[i];
    var numerTerm = cfrac.h[i - 1];
    var denom = denomTerm.beta + dot(denomTerm.vars, denomTerm.coef, dataPoint) + fraction;
    var numerator = numerTerm.beta + dot(numerTerm.vars, numerTerm.coef, dataPoint);
    fraction = numerator / denom;
    i -= 2;
  }

  var first = cfrac.h[0];
  return fraction + (first.beta + dot(first.vars, first.coef, dataPoint));
}
| 343 |
|
---|
/// <summary>
/// Sum of <c>x[i] * y[i]</c> over all indices of <paramref name="x"/> for which
/// <paramref name="filter"/> is true.
/// </summary>
private static double dot(bool[] filter, double[] x, double[] y) {
  double acc = 0.0;
  int idx = 0;
  while (idx < x.Length) {
    if (filter[idx]) {
      acc += x[idx] * y[idx];
    }
    idx++;
  }
  return acc;
}
| 351 |
|
---|
| 352 |
|
---|
/// <summary>
/// Optimizes the active coefficients and constants of <paramref name="ch"/> with Nelder-Mead
/// simplex searches on a random subset of the training rows. The coefficients with the best
/// objective on the full training data are written back to <paramref name="ch"/>.
/// </summary>
/// <param name="ch">Continued fraction whose coefficients are modified in place.</param>
/// <param name="quality">Receives the objective value of <paramref name="ch"/> on the full training data.</param>
/// <param name="trainingData">Training matrix; the last column is the target.</param>
/// <param name="rand">Random number generator used for selecting the fitting rows.</param>
private static void LocalSearchSimplex(ContinuedFraction ch, ref double quality, double[,] trainingData, IRandom rand) {
  double uniformPerturbation = 1.0;
  double tolerance = 1e-3;
  int maxEvals = 250;
  int numSearches = 4;
  var numRows = trainingData.GetLength(0);
  // 20% of the training samples, but at least one row when any row exists. With fewer than
  // five rows the integer division used to give an empty subset and a 0/0 objective.
  int numSelectedRows = Math.Min(numRows, Math.Max(1, numRows / 5));

  quality = Evaluate(ch, trainingData); // get quality with original coefficients

  double[] origCoeff = ExtractCoeff(ch);
  if (origCoeff.Length == 0) return; // no parameters to optimize

  var bestQuality = quality;
  var bestCoeff = origCoeff;

  var fittingData = SelectRandomRows(trainingData, numSelectedRows, rand);

  // objective for the simplex search: writes the candidate into ch and scores it on the subset
  double objFunc(double[] curCoeff) {
    SetCoeff(ch, curCoeff);
    return Evaluate(ch, fittingData);
  }

  // NOTE(review): every search starts from origCoeff with the same perturbation; whether the
  // repetitions can differ depends on NelderMeadSimplex.Regress - confirm.
  for (int count = 0; count < numSearches; count++) {

    SimplexConstant[] constants = new SimplexConstant[origCoeff.Length];
    for (int i = 0; i < origCoeff.Length; i++) {
      constants[i] = new SimplexConstant(origCoeff[i], uniformPerturbation);
    }

    RegressionResult result = NelderMeadSimplex.Regress(constants, tolerance, maxEvals, objFunc);

    var optimizedCoeff = result.Constants;
    SetCoeff(ch, optimizedCoeff);

    // candidates are compared on the full training data, not on the fitting subset
    var newQuality = Evaluate(ch, trainingData);

    if (newQuality < bestQuality) {
      bestCoeff = optimizedCoeff;
      bestQuality = newQuality;
    }
  } // reps

  SetCoeff(ch, bestCoeff);
  quality = bestQuality;
}
| 399 |
|
---|
/// <summary>
/// Copies up to <paramref name="numSelectedRows"/> randomly chosen rows of
/// <paramref name="trainingData"/> into a new matrix with <paramref name="numSelectedRows"/> rows.
/// </summary>
private static double[,] SelectRandomRows(double[,] trainingData, int numSelectedRows, IRandom rand) {
  int rowCount = trainingData.GetLength(0);
  int colCount = trainingData.GetLength(1);

  int[] pickedRows = Enumerable.Range(0, rowCount).Shuffle(rand).Take(numSelectedRows).ToArray();

  var result = new double[numSelectedRows, colCount];
  for (int target = 0; target < pickedRows.Length; target++) {
    int source = pickedRows[target];
    for (int col = 0; col < colCount; col++) {
      result[target, col] = trainingData[source, col];
    }
  }
  return result;
}
| 414 |
|
---|
/// <summary>
/// Flattens the tunable values of a continued fraction into one array: for each term its beta
/// followed by the coefficients of the variables that are ON in that term.
/// </summary>
private static double[] ExtractCoeff(ContinuedFraction ch) {
  var flat = new List<double>();
  foreach (var term in ch.h) {
    flat.Add(term.beta);
    int v = 0;
    while (v < term.vars.Length) {
      if (term.vars[v]) {
        flat.Add(term.coef[v]);
      }
      v++;
    }
  }
  return flat.ToArray();
}
| 425 |
|
---|
/// <summary>
/// Writes a flat coefficient array back into the continued fraction: for each term its beta
/// followed by the coefficients of the variables that are ON in that term.
/// </summary>
private static void SetCoeff(ContinuedFraction ch, double[] curCoeff) {
  int next = 0;
  foreach (var term in ch.h) {
    term.beta = curCoeff[next];
    next++;
    int v = 0;
    while (v < term.vars.Length) {
      if (term.vars[v]) {
        term.coef[v] = curCoeff[next];
        next++;
      }
      v++;
    }
  }
}
[17983] | 435 |
|
---|
// Symbol instances used as factories for the nodes of the generated expression trees.
private readonly Symbol addSy = new Addition();
private readonly Symbol mulSy = new Multiplication();
private readonly Symbol divSy = new Division();
private readonly Symbol startSy = new StartSymbol();
private readonly Symbol progSy = new ProgramRootSymbol();
private readonly Symbol constSy = new Constant();
private readonly Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable();
| 443 |
|
---|
/// <summary>
/// Converts a continued fraction into a symbolic expression tree and wraps it in a
/// symbolic regression solution for a clone of <paramref name="problemData"/>.
/// </summary>
/// <param name="cfrac">Continued fraction to convert.</param>
/// <param name="problemData">Supplies the input variable names and the target variable.</param>
/// <returns>The created solution.</returns>
private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, IRegressionProblemData problemData) {
  var variables = problemData.AllowedInputVariables.ToArray();
  ISymbolicExpressionTreeNode res = null;
  // same traversal as the numeric evaluation: pairs (h[i-1], h[i]) from the last index down to 2
  for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
    var hi = cfrac.h[i];
    var hi1 = cfrac.h[i - 1];
    var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta);
    if (res != null) {
      // the fraction built so far becomes an additional summand of the denominator
      denom.AddSubtree(res);
    }

    var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta);

    res = divSy.CreateTreeNode();
    res.AddSubtree(numerator);
    res.AddSubtree(denom);
  }

  var h0 = cfrac.h[0];
  var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta);
  // With fewer than three terms the loop never runs and there is no fraction to attach;
  // previously a null subtree was added in that case.
  if (res != null) {
    h0Term.AddSubtree(res);
  }

  var progRoot = progSy.CreateTreeNode();
  var start = startSy.CreateTreeNode();
  progRoot.AddSubtree(start);
  start.AddSubtree(h0Term);

  var model = new SymbolicRegressionModel(problemData.TargetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter());
  var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
  return sol;
}
| 475 |
|
---|
/// <summary>
/// Builds an addition node with one weighted variable node per active variable, followed by a
/// constant node holding <paramref name="beta"/>.
/// </summary>
/// <param name="vars">Mask of active variables.</param>
/// <param name="coef">Weights, indexed like <paramref name="vars"/>.</param>
/// <param name="variables">Variable names, indexed like <paramref name="vars"/>.</param>
/// <param name="beta">Value of the constant summand.</param>
private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) {
  var additionNode = addSy.CreateTreeNode();

  int idx = 0;
  while (idx < vars.Length) {
    if (vars[idx]) {
      var weighted = (VariableTreeNode)varSy.CreateTreeNode();
      weighted.Weight = coef[idx];
      weighted.VariableName = variables[idx];
      additionNode.AddSubtree(weighted);
    }
    idx++;
  }

  var constantNode = CreateConstant(beta);
  additionNode.AddSubtree(constantNode);
  return additionNode;
}
| 489 |
|
---|
/// <summary>Creates a constant tree node holding <paramref name="value"/>.</summary>
private ISymbolicExpressionTreeNode CreateConstant(double value) {
  var node = (ConstantTreeNode)constSy.CreateTreeNode();
  node.Value = value;
  return node;
}
[17971] | 495 | }
|
---|
| 496 |
|
---|
/// <summary>
/// Node of the population tree. Each agent holds a pocket and a current solution with their
/// objective values (smaller is better) and a list of child agents.
/// </summary>
public class Agent {
  public ContinuedFraction pocket;
  public double pocketObjValue;
  public ContinuedFraction current;
  public double currentObjValue;

  public IList<Agent> children = new List<Agent>();

  /// <summary>
  /// Lists this agent first; then, recursively, all children of an agent are listed before
  /// descending into each child.
  /// </summary>
  public IEnumerable<Agent> IterateLevels() {
    var visited = new List<Agent>() { this };
    CollectByLevel(this, visited);
    return visited;
  }

  /// <summary>Lists all agents of the subtree so that children come before their parent.</summary>
  public IEnumerable<Agent> IteratePostOrder() {
    var visited = new List<Agent>();
    CollectPostOrder(this, visited);
    return visited;
  }

  /// <summary>
  /// First exchanges the pocket with any child whose pocket has a smaller objective value, then
  /// exchanges pocket and current if the current solution has the smaller objective value.
  /// </summary>
  internal void MaintainInvariant() {
    foreach (var child in children) {
      if (child.pocketObjValue < pocketObjValue) {
        Swap(ref child.pocket, ref pocket);
        Swap(ref child.pocketObjValue, ref pocketObjValue);
      }
    }

    if (currentObjValue < pocketObjValue) {
      Swap(ref pocket, ref current);
      Swap(ref pocketObjValue, ref currentObjValue);
    }
  }

  // appends the children of the node, then recurses into each child
  private static void CollectByLevel(Agent node, List<Agent> visited) {
    foreach (var child in node.children) {
      visited.Add(child);
    }
    foreach (var child in node.children) {
      CollectByLevel(child, visited);
    }
  }

  // recurses into each child first, then appends the node itself
  private static void CollectPostOrder(Agent node, List<Agent> visited) {
    foreach (var child in node.children) {
      CollectPostOrder(child, visited);
    }
    visited.Add(node);
  }

  // exchanges the contents of two variables
  private static void Swap<T>(ref T a, ref T b) {
    T held = a;
    a = b;
    b = held;
  }
}
| 557 | }
|
---|