1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using System.Threading;
|
---|
5 | using HEAL.Attic;
|
---|
6 | using HeuristicLab.Analysis;
|
---|
7 | using HeuristicLab.Common;
|
---|
8 | using HeuristicLab.Core;
|
---|
9 | using HeuristicLab.Data;
|
---|
10 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
11 | using HeuristicLab.Parameters;
|
---|
12 | using HeuristicLab.Problems.DataAnalysis;
|
---|
13 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
14 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
15 | using HeuristicLab.Random;
|
---|
16 |
|
---|
17 | namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
|
---|
18 | /// <summary>
|
---|
19 | /// Implementation of Continuous Fraction Regression (CFR) as described in
|
---|
20 | /// Pablo Moscato, Haoyuan Sun, Mohammad Nazmul Haque,
|
---|
21 | /// Analytic Continued Fractions for Regression: A Memetic Algorithm Approach,
|
---|
22 | /// Expert Systems with Applications, Volume 179, 2021, 115018, ISSN 0957-4174,
|
---|
23 | /// https://doi.org/10.1016/j.eswa.2021.115018.
|
---|
24 | /// </summary>
|
---|
25 | [Item("Continuous Fraction Regression (CFR)", "TODO")]
|
---|
26 | [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
|
---|
27 | [StorableType("7A375270-EAAF-4AD1-82FF-132318D20E09")]
|
---|
28 | public class Algorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
|
---|
// Keys used to look up the algorithm's parameters in the Parameters collection.
private const string MutationRateParameterName = "MutationRate";
private const string DepthParameterName = "Depth";
private const string NumGenerationsParameterName = "NumGenerations";
private const string StagnationGenerationsParameterName = "StagnationGenerations";
private const string LocalSearchIterationsParameterName = "LocalSearchIterations";
private const string LocalSearchRestartsParameterName = "LocalSearchRestarts";
private const string LocalSearchToleranceParameterName = "LocalSearchTolerance";
private const string DeltaParameterName = "Delta";
private const string ScaleDataParameterName = "ScaleData";

#region parameters
// Typed accessors for the parameter objects registered in the default constructor,
// each paired with a convenience property that exposes the raw value.
public IFixedValueParameter<PercentValue> MutationRateParameter => (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName];
public double MutationRate {
  get => MutationRateParameter.Value.Value;
  set => MutationRateParameter.Value.Value = value;
}
public IFixedValueParameter<IntValue> DepthParameter => (IFixedValueParameter<IntValue>)Parameters[DepthParameterName];
public int Depth {
  get => DepthParameter.Value.Value;
  set => DepthParameter.Value.Value = value;
}
public IFixedValueParameter<IntValue> NumGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[NumGenerationsParameterName];
public int NumGenerations {
  get => NumGenerationsParameter.Value.Value;
  set => NumGenerationsParameter.Value.Value = value;
}
public IFixedValueParameter<IntValue> StagnationGenerationsParameter => (IFixedValueParameter<IntValue>)Parameters[StagnationGenerationsParameterName];
public int StagnationGenerations {
  get => StagnationGenerationsParameter.Value.Value;
  set => StagnationGenerationsParameter.Value.Value = value;
}
public IFixedValueParameter<IntValue> LocalSearchIterationsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchIterationsParameterName];
public int LocalSearchIterations {
  get => LocalSearchIterationsParameter.Value.Value;
  set => LocalSearchIterationsParameter.Value.Value = value;
}
public IFixedValueParameter<IntValue> LocalSearchRestartsParameter => (IFixedValueParameter<IntValue>)Parameters[LocalSearchRestartsParameterName];
public int LocalSearchRestarts {
  get => LocalSearchRestartsParameter.Value.Value;
  set => LocalSearchRestartsParameter.Value.Value = value;
}
public IFixedValueParameter<DoubleValue> LocalSearchToleranceParameter => (IFixedValueParameter<DoubleValue>)Parameters[LocalSearchToleranceParameterName];
public double LocalSearchTolerance {
  get => LocalSearchToleranceParameter.Value.Value;
  set => LocalSearchToleranceParameter.Value.Value = value;
}
public IFixedValueParameter<PercentValue> DeltaParameter => (IFixedValueParameter<PercentValue>)Parameters[DeltaParameterName];
public double Delta {
  get => DeltaParameter.Value.Value;
  set => DeltaParameter.Value.Value = value;
}
public IFixedValueParameter<BoolValue> ScaleDataParameter => (IFixedValueParameter<BoolValue>)Parameters[ScaleDataParameterName];
public bool ScaleData {
  get => ScaleDataParameter.Value.Value;
  set => ScaleDataParameter.Value.Value = value;
}
#endregion
87 |
|
---|
// storable ctor (required by HEAL.Attic persistence)
[StorableConstructor]
public Algorithm(StorableConstructorFlag _) : base(_) { }

// cloning ctor; delegates member copying to the base class / cloner
public Algorithm(Algorithm original, Cloner cloner) : base(original, cloner) { }

// default ctor: registers a default regression problem and all algorithm parameters
// with their default values (the string keys are the *ParameterName constants above).
public Algorithm() : base() {
  Problem = new RegressionProblem();
  Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, "Mutation rate (default 10%)", new PercentValue(0.1)));
  Parameters.Add(new FixedValueParameter<IntValue>(DepthParameterName, "Depth of the continued fraction representation (default 6)", new IntValue(6)));
  Parameters.Add(new FixedValueParameter<IntValue>(NumGenerationsParameterName, "The maximum number of generations (default 200)", new IntValue(200)));
  Parameters.Add(new FixedValueParameter<IntValue>(StagnationGenerationsParameterName, "Number of generations after which the population is re-initialized (default value 5)", new IntValue(5)));
  Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchIterationsParameterName, "Number of iterations for local search (simplex) (default value 250)", new IntValue(250)));
  Parameters.Add(new FixedValueParameter<IntValue>(LocalSearchRestartsParameterName, "Number of restarts for local search (default value 4)", new IntValue(4)));
  Parameters.Add(new FixedValueParameter<DoubleValue>(LocalSearchToleranceParameterName, "The tolerance value for local search (simplex) (default value: 1e-3)", new DoubleValue(1e-3)));
  Parameters.Add(new FixedValueParameter<PercentValue>(DeltaParameterName, "The relative weight for the number of variables term in the fitness function (default value: 10%)", new PercentValue(0.1)));
  Parameters.Add(new FixedValueParameter<BoolValue>(ScaleDataParameterName, "Turns on/off scaling of input variable values to the range [0 .. 1] (default: false)", new BoolValue(false)));
}
109 |
|
---|
/// <summary>Creates a deep clone of this algorithm via the cloning constructor.</summary>
public override IDeepCloneable Clone(Cloner cloner) => new Algorithm(this, cloner);
113 |
|
---|
/// <summary>
/// Entry point of the algorithm: prepares the training matrix (inputs followed by the
/// target column, optionally scaled to [0..1]) and starts the CFR main loop.
/// </summary>
/// <param name="cancellationToken">Propagated to the main loop so a user can stop the run.</param>
protected override void Run(CancellationToken cancellationToken) {
  var problemData = Problem.ProblemData;
  double[,] xy;
  if (ScaleData) {
    // Scale data to range 0 .. 1
    //
    // Scaling was not used for the experiments in the paper. Statement by the authors: "We did not pre-process the data."
    var transformations = new List<Transformation<double>>();
    foreach (var input in problemData.AllowedInputVariables) {
      var values = problemData.Dataset.GetDoubleValues(input, problemData.TrainingIndices);
      var linTransformation = new LinearTransformation(problemData.AllowedInputVariables);
      var min = values.Min();
      var max = values.Max();
      var range = max - min;
      if (range == 0.0) {
        // BUGFIX: a constant column would yield range == 0 and therefore
        // Infinity/NaN for -min/range and 1/range. Map such a column to 0 instead.
        linTransformation.Addend = -min;
        linTransformation.Multiplier = 1.0;
      } else {
        linTransformation.Addend = -min / range;
        linTransformation.Multiplier = 1.0 / range;
      }
      transformations.Add(linTransformation);
    }
    // do not scale the target (identity transformation)
    transformations.Add(new LinearTransformation(problemData.AllowedInputVariables) { Addend = 0.0, Multiplier = 1.0 });
    xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
      transformations,
      problemData.TrainingIndices);
  } else {
    // no transformation
    xy = problemData.Dataset.ToArray(problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }),
      problemData.TrainingIndices);
  }
  // last column is the target, the rest are inputs
  var nVars = xy.GetLength(1) - 1;
  var seed = new System.Random().Next();
  var rand = new MersenneTwister((uint)seed);
  CFRAlgorithm(nVars, Depth, MutationRate, xy, out var bestObj, rand, NumGenerations, StagnationGenerations,
    Delta,
    LocalSearchIterations, LocalSearchRestarts, LocalSearchTolerance, cancellationToken);
}
149 |
|
---|
/// <summary>
/// Main memetic loop (Algorithm 1 of the paper): mutation, recombination, local search
/// on the current solutions, invariant maintenance, stagnation-triggered reset, and
/// result reporting. Writes "MSE (best)", "Solution" and "Qualities" into Results.
/// </summary>
/// <param name="bestObj">Receives the best pocket objective value found over the whole run.</param>
private void CFRAlgorithm(int nVars, int depth, double mutationRate, double[,] trainingData,
  out double bestObj,
  IRandom rand, int numGen, int stagnatingGens, double evalDelta,
  int localSearchIterations, int localSearchRestarts, double localSearchTolerance,
  CancellationToken cancellationToken) {
  /* Algorithm 1 */
  /* Generate initial population by a randomized algorithm */
  var pop = InitialPopulation(nVars, depth, rand, trainingData);
  bestObj = pop.pocketObjValue;
  // the best value since the last reset
  var episodeBestObj = pop.pocketObjValue;
  var episodeBestObjGen = 0;
  for (int gen = 1; gen <= numGen && !cancellationToken.IsCancellationRequested; gen++) {
    /* mutate each current solution in the population */
    var pop_mu = Mutate(pop, mutationRate, rand);
    /* generate new population by recombination mechanism */
    var pop_r = RecombinePopulation(pop_mu, rand, nVars);

    // "We ran the Local Search after Mutation and recombination operations. We executed the local-search only on the Current solutions."
    // "We executed the MaintainInvariant() in the following steps:
    // - After generating the initial population
    // - after resetting the root
    // - after executing the local-search on the whole population.
    // We updated the pocket/ current automatically after mutation and recombination operation."

    /* local search optimization of current solutions */
    foreach (var agent in pop_r.IterateLevels()) {
      LocalSearchSimplex(localSearchIterations, localSearchRestarts, localSearchTolerance, evalDelta, agent.current, ref agent.currentObjValue, trainingData, rand);
    }
    foreach (var agent in pop_r.IteratePostOrder()) agent.MaintainInvariant(); // post-order to make sure that the root contains the best model

    // for detecting stagnation we track the best objective value since the last reset
    // and reset if this does not change for stagnatingGens
    if (gen > episodeBestObjGen + stagnatingGens) {
      Reset(pop_r, nVars, depth, rand, trainingData);
      episodeBestObj = double.MaxValue; // forces the update below, which also resets episodeBestObjGen
    }
    if (episodeBestObj > pop_r.pocketObjValue) {
      episodeBestObjGen = gen; // wait at least stagnatingGens until resetting again
      episodeBestObj = pop_r.pocketObjValue;
    }

    /* replace old population with evolved population */
    pop = pop_r;

    /* keep track of the best solution */
    if (bestObj > pop.pocketObjValue) {
      bestObj = pop.pocketObjValue;
      Results.AddOrUpdateResult("MSE (best)", new DoubleValue(bestObj));
      Results.AddOrUpdateResult("Solution", CreateSymbolicRegressionSolution(pop.pocket, trainingData, Problem.ProblemData.AllowedInputVariables.ToArray(), Problem.ProblemData.TargetVariable));
    }

    #region visualization and debugging
    // One pair of data rows (pocket/current quality) per agent, appended every generation.
    DataTable qualities;
    int i = 0;
    if (Results.TryGetValue("Qualities", out var qualitiesResult)) {
      qualities = (DataTable)qualitiesResult.Value;
    } else {
      // first generation: create the table and one row pair per agent
      qualities = new DataTable("Qualities", "Qualities");
      i = 0;
      foreach (var node in pop.IterateLevels()) {
        qualities.Rows.Add(new DataRow($"Quality {i} pocket", "Quality of pocket"));
        qualities.Rows.Add(new DataRow($"Quality {i} current", "Quality of current"));
        i++;
      }
      Results.AddOrUpdateResult("Qualities", qualities);
    }
    i = 0;
    foreach (var node in pop.IterateLevels()) {
      qualities.Rows[$"Quality {i} pocket"].Values.Add(node.pocketObjValue);
      qualities.Rows[$"Quality {i} current"].Values.Add(node.currentObjValue);
      i++;
    }
    #endregion
  }
}
227 |
|
---|
228 |
|
---|
229 |
|
---|
/// <summary>
/// Builds the 13-agent population hierarchy (one root, three leaders, nine supporters;
/// see Figure 2 of the paper) and initializes every agent with two random continued
/// fractions, establishing the pocket/current invariant.
/// </summary>
private Agent InitialPopulation(int nVars, int depth, IRandom rand, double[,] trainingData) {
  var root = new Agent();
  for (int leaderIdx = 0; leaderIdx < 3; leaderIdx++) {
    var leader = new Agent();
    root.children.Add(leader);
    for (int supporterIdx = 0; supporterIdx < 3; supporterIdx++) {
      leader.children.Add(new Agent());
    }
  }

  // Statement by the authors: "Yes, we use post-order traversal here"
  foreach (var agent in root.IteratePostOrder()) {
    agent.current = new ContinuedFraction(nVars, depth, rand);
    agent.pocket = new ContinuedFraction(nVars, depth, rand);

    agent.currentObjValue = Evaluate(agent.current, trainingData, Delta);
    agent.pocketObjValue = Evaluate(agent.pocket, trainingData, Delta);

    // within each agent, the pocket solution always holds the better value of the
    // guiding function than its current solution
    agent.MaintainInvariant();
  }
  return root;
}
256 |
|
---|
// Reset is not described in detail in the paper.
// Statement by the authors: "We only replaced the pocket solution of the root with
// a randomly generated solution. Then we execute the maintain-invariant process.
// It does not initialize the solutions in the entire population."
private void Reset(Agent root, int nVars, int depth, IRandom rand, double[,] trainingData) {
  // replace both root solutions with fresh random continued fractions and re-evaluate
  root.pocket = new ContinuedFraction(nVars, depth, rand);
  root.current = new ContinuedFraction(nVars, depth, rand);

  root.currentObjValue = Evaluate(root.current, trainingData, Delta);
  root.pocketObjValue = Evaluate(root.pocket, trainingData, Delta);

  foreach (var agent in root.IteratePreOrder()) { agent.MaintainInvariant(); } // Here we push the newly created model down the hierarchy.
}
270 |
|
---|
271 |
|
---|
272 |
|
---|
/// <summary>
/// Recursively recombines each sub-population (a leader l and its three supporters
/// s1..s3) top-down. The four recombination steps deliberately reuse solutions
/// produced in earlier steps, as confirmed by the authors (comment below).
/// NOTE(review): when children exist, exactly three are assumed (indices 0..2);
/// this holds for populations built by InitialPopulation.
/// </summary>
private Agent RecombinePopulation(Agent pop, IRandom rand, int nVars) {
  var l = pop;

  if (pop.children.Count > 0) {
    var s1 = pop.children[0];
    var s2 = pop.children[1];
    var s3 = pop.children[2];

    // Statement by the authors: "we are using recently generated solutions.
    // For an example, in step 1 we got the new current(l), which is being used in
    // Step 2 to generate current(s3). The current(s3) from Step 2 is being used at
    // Step 4. These steps are executed sequentially from 1 to 4. Similarly, in the
    // recombination of lower-level subpopulations, we will have the new current
    // (the supporters generated at the previous level) as the leader of the subpopulation.
    l.current = Recombine(l.pocket, s1.current, SelectRandomOp(rand), rand, nVars);
    s3.current = Recombine(s3.pocket, l.current, SelectRandomOp(rand), rand, nVars);
    s1.current = Recombine(s1.pocket, s2.current, SelectRandomOp(rand), rand, nVars);
    s2.current = Recombine(s2.pocket, s3.current, SelectRandomOp(rand), rand, nVars);

    // recombination works from top to bottom
    foreach (var child in pop.children) {
      RecombinePopulation(child, rand, nVars);
    }

  }
  return pop;
}
300 |
|
---|
/// <summary>
/// Creates an offspring continued fraction from two parents: the active-variable
/// sets are combined with the given set operator (union/intersection/symmetric
/// difference), and coefficients/betas are blended with the paper's randomized
/// interpolation factor (rand*5 - 1)/3, which can extrapolate beyond both parents.
/// </summary>
private static ContinuedFraction Recombine(ContinuedFraction p1, ContinuedFraction p2, Func<bool[], bool[], bool[]> op, IRandom rand, int nVars) {
  ContinuedFraction ch = new ContinuedFraction() { h = new Term[p1.h.Length] };
  /* apply a recombination operator chosen uniformly at random on variables of two parents into offspring */
  ch.vars = op(p1.vars, p2.vars);

  /* recombine the coefficients for each term (h) of the continued fraction */
  for (int i = 0; i < p1.h.Length; i++) {
    var coefa = p1.h[i].coef; var varsa = p1.h[i].vars;
    var coefb = p2.h[i].coef; var varsb = p2.h[i].vars;

    /* recombine coefficient values for variables */
    var coefx = new double[nVars];
    var varsx = new bool[nVars]; // deviates from paper, probably forgotten in the pseudo-code
    for (int vi = 0; vi < nVars; vi++) {
      if (ch.vars[vi]) { // CHECK: paper uses featAt()
        if (varsa[vi] && varsb[vi]) {
          // both parents use the variable: blend the coefficients
          coefx[vi] = coefa[vi] + (rand.NextDouble() * 5 - 1) * (coefb[vi] - coefa[vi]) / 3.0;
          varsx[vi] = true;
        } else if (varsa[vi]) {
          // only parent 1 uses the variable: inherit its coefficient
          coefx[vi] = coefa[vi];
          varsx[vi] = true;
        } else if (varsb[vi]) {
          // only parent 2 uses the variable: inherit its coefficient
          coefx[vi] = coefb[vi];
          varsx[vi] = true;
        }
      }
    }
    /* update new coefficients of the term in offspring */
    ch.h[i] = new Term() { coef = coefx, vars = varsx };
    /* compute new value of constant (beta) for term hi in the offspring solution ch using
     * beta of p1.hi and p2.hi */
    ch.h[i].beta = p1.h[i].beta + (rand.NextDouble() * 5 - 1) * (p2.h[i].beta - p1.h[i].beta) / 3.0;
  }
  // return LocalSearchSimplex(ch, trainingData); // The paper has a local search step here.
  // The authors have stated that local search is executed after mutation and recombination
  // for the current solutions.
  // Local search and MaintainInvariant is called in the main loop (Alg 1)
  return ch;
}
340 |
|
---|
/// <summary>
/// Mutates the current solution of each agent with probability <paramref name="mutationRate"/>.
/// </summary>
private Agent Mutate(Agent pop, double mutationRate, IRandom rand) {
  foreach (var agent in pop.IterateLevels()) {
    if (rand.NextDouble() < mutationRate) {
      // NOTE(review): the major mutation fires both when current is at most 20% worse
      // than pocket AND when it is more than twice as bad; confirm these thresholds
      // against the mutation operator description in the paper.
      if (agent.currentObjValue < 1.2 * agent.pocketObjValue ||
        agent.currentObjValue > 2 * agent.pocketObjValue)
        ToggleVariables(agent.current, rand); // major mutation
      else
        ModifyVariable(agent.current, rand); // soft mutation
    }
  }
  return pop;
}
353 |
|
---|
/// <summary>
/// Major mutation: picks one variable index uniformly at random and toggles it at the
/// fraction level, adjusting every term's coefficient for that variable ('Remove' vs.
/// 'Remember'/'Replace' decided by a coin toss per term).
/// </summary>
private void ToggleVariables(ContinuedFraction cfrac, IRandom rand) {
  // returns a or b with equal probability (consumes one random number)
  double coinToss(double a, double b) {
    return rand.NextDouble() < 0.5 ? a : b;
  }

  /* select a variable index uniformly at random */
  int N = cfrac.vars.Length;
  var vIdx = rand.Next(N);

  /* for each depth of continued fraction, toggle the selection of variables of the term (h) */
  foreach (var h in cfrac.h) {
    /* Case 1: cfrac variable is turned ON: Turn OFF the variable, and either 'Remove' or
     * 'Remember' the coefficient value at random */
    // NOTE(review): Case 1 tests the fraction-level flag while Case 2 tests the
    // term-level flag — looks intentional (mirrors the paper's pseudo-code) but verify.
    if (cfrac.vars[vIdx]) { // CHECK: paper uses varAt()
      h.vars[vIdx] = false; // CHECK: paper uses varAt()
      h.coef[vIdx] = coinToss(0, h.coef[vIdx]);
    } else {
      /* Case 2: term variable is turned OFF: Turn ON the variable, and either 'Remove'
       * or 'Replace' the coefficient with a random value between -3 and 3 at random */
      if (!h.vars[vIdx]) {
        h.vars[vIdx] = true; // CHECK: paper uses varAt()
        h.coef[vIdx] = coinToss(0, rand.NextDouble() * 6 - 3);
      }
    }
  }
  /* toggle the randomly selected variable */
  cfrac.vars[vIdx] = !cfrac.vars[vIdx]; // CHECK: paper uses varAt()
}
382 |
|
---|
/// <summary>
/// Soft mutation: selects one variable that is active at the fraction level and flips
/// its state in a single randomly chosen term, adjusting the coefficient accordingly.
/// </summary>
private void ModifyVariable(ContinuedFraction cfrac, IRandom rand) {
  // collect indices of variables that are switched on at the fraction level
  var activeIdx = new List<int>();
  for (int i = 0; i < cfrac.vars.Length; i++) {
    if (cfrac.vars[i]) activeIdx.Add(i); // CHECK: paper uses varAt()
  }
  if (activeIdx.Count == 0) return; // no variable active

  var vIdx = activeIdx[rand.Next(activeIdx.Count)];

  // pick a random term (h) of the continued fraction
  var term = cfrac.h[rand.Next(cfrac.h.Length)];

  // 'remove' the coefficient when the variable was on in this term,
  // otherwise replace it with a random value in [-3, 3)
  if (term.vars[vIdx]) { // CHECK: paper uses varAt()
    term.coef[vIdx] = 0.0;
  } else {
    term.coef[vIdx] = rand.NextDouble() * 6 - 3;
  }
  // toggle the variable in the selected term
  term.vars[vIdx] = !term.vars[vIdx]; // CHECK: paper uses varAt()
}
402 |
|
---|
/// <summary>
/// Guiding function: mean squared error of the continued fraction over all rows of
/// <paramref name="trainingData"/> (last column is the target), inflated by a
/// parsimony factor (1 + delta * number of active variables).
/// </summary>
private static double Evaluate(ContinuedFraction cfrac, double[,] trainingData, double delta) {
  int nRows = trainingData.GetLength(0);
  int yIdx = trainingData.GetLength(1) - 1; // target is the last column
  var row = new double[yIdx];
  var sse = 0.0;
  for (int r = 0; r < nRows; r++) {
    for (int c = 0; c < row.Length; c++) {
      row[c] = trainingData[r, c];
    }
    var residual = trainingData[r, yIdx] - Evaluate(cfrac, row);
    sse += residual * residual;
  }
  return sse / nRows * (1 + delta * cfrac.vars.Count(vi => vi));
}
418 |
|
---|
/// <summary>
/// Evaluates the continued fraction for one data point, bottom-up: terms are consumed
/// in pairs (h[i] = denominator part, h[i-1] = numerator) from the innermost level
/// outward, and the h[0] linear term is added last.
/// </summary>
private static double Evaluate(ContinuedFraction cfrac, double[] dataPoint) {
  var res = 0.0;
  for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
    var hi = cfrac.h[i];
    var hi1 = cfrac.h[i - 1];
    // previous level's result feeds into the denominator of the next-outer level
    var denom = hi.beta + dot(hi.vars, hi.coef, dataPoint) + res;
    var numerator = hi1.beta + dot(hi1.vars, hi1.coef, dataPoint);
    res = numerator / denom;
  }
  var h0 = cfrac.h[0];
  res += h0.beta + dot(h0.vars, h0.coef, dataPoint);
  return res;
}
432 |
|
---|
433 |
|
---|
/// <summary>
/// Draws one of the three set-recombination operators (union, intersection,
/// symmetric difference) with equal probability.
/// </summary>
private Func<bool[], bool[], bool[]> SelectRandomOp(IRandom rand) {
  bool[] union(bool[] a, bool[] b) {
    var r = new bool[a.Length];
    for (int i = 0; i < r.Length; i++) r[i] = a[i] || b[i];
    return r;
  }
  bool[] intersect(bool[] a, bool[] b) {
    var r = new bool[a.Length];
    for (int i = 0; i < r.Length; i++) r[i] = a[i] && b[i];
    return r;
  }
  bool[] symmetricDifference(bool[] a, bool[] b) {
    var r = new bool[a.Length];
    for (int i = 0; i < r.Length; i++) r[i] = a[i] ^ b[i];
    return r;
  }
  var roll = rand.Next(3);
  if (roll == 0) return union;
  if (roll == 1) return intersect;
  if (roll == 2) return symmetricDifference;
  throw new ArgumentException();
}
457 |
|
---|
/// <summary>
/// Filtered dot product: sums x[i]*y[i] only for positions enabled in <paramref name="filter"/>.
/// </summary>
private static double dot(bool[] filter, double[] x, double[] y) {
  var sum = 0.0;
  for (int i = 0; i < x.Length; i++) {
    sum += filter[i] ? x[i] * y[i] : 0.0;
  }
  return sum;
}
465 |
|
---|
466 |
|
---|
/// <summary>
/// Optimizes the numeric parameters (betas and active coefficients) of <paramref name="ch"/>
/// with restarted Nelder-Mead simplex searches. The simplex objective is evaluated on a
/// random 20% subset of the training rows; candidate results are re-scored on the full
/// training data, and the best coefficients found are written back into ch.
/// </summary>
/// <param name="quality">In/out: overwritten with the full-data objective of the final coefficients.</param>
private static void LocalSearchSimplex(int iterations, int restarts, double tolerance, double delta, ContinuedFraction ch, ref double quality, double[,] trainingData, IRandom rand) {
  double uniformPeturbation = 1.0; // initial simplex perturbation for every coefficient
  int maxEvals = iterations;
  int numSearches = restarts + 1;
  var numRows = trainingData.GetLength(0);
  int numSelectedRows = numRows / 5; // 20% of the training samples

  quality = Evaluate(ch, trainingData, delta); // get quality with origial coefficients

  double[] origCoeff = ExtractCoeff(ch);
  if (origCoeff.Length == 0) return; // no parameters to optimize

  var bestQuality = quality;
  var bestCoeff = origCoeff;

  var fittingData = SelectRandomRows(trainingData, numSelectedRows, rand);

  // objective for the simplex: writes the candidate coefficients into ch
  // and evaluates on the row subset only
  double objFunc(double[] curCoeff) {
    SetCoeff(ch, curCoeff);
    return Evaluate(ch, fittingData, delta);
  }

  for (int count = 0; count < numSearches; count++) {

    // every restart starts again from the original coefficients
    SimplexConstant[] constants = new SimplexConstant[origCoeff.Length];
    for (int i = 0; i < origCoeff.Length; i++) {
      constants[i] = new SimplexConstant(origCoeff[i], uniformPeturbation);
    }

    RegressionResult result = NelderMeadSimplex.Regress(constants, tolerance, maxEvals, objFunc);

    var optimizedCoeff = result.Constants;
    SetCoeff(ch, optimizedCoeff);

    // candidate is accepted based on the FULL training data, not the fitting subset
    var newQuality = Evaluate(ch, trainingData, delta);

    if (newQuality < bestQuality) {
      bestCoeff = optimizedCoeff;
      bestQuality = newQuality;
    }
  } // reps

  SetCoeff(ch, bestCoeff);
  quality = bestQuality;
}
512 |
|
---|
/// <summary>
/// Copies a uniformly random subset of rows (without replacement) of
/// <paramref name="trainingData"/> into a new matrix with the same columns.
/// </summary>
private static double[,] SelectRandomRows(double[,] trainingData, int numSelectedRows, IRandom rand) {
  var rows = trainingData.GetLength(0);
  var cols = trainingData.GetLength(1);
  var chosen = Enumerable.Range(0, rows).Shuffle(rand).Take(numSelectedRows).ToArray();
  var subset = new double[numSelectedRows, cols];
  for (int i = 0; i < chosen.Length; i++) {
    for (int c = 0; c < cols; c++) {
      subset[i, c] = trainingData[chosen[i], c];
    }
  }
  return subset;
}
527 |
|
---|
/// <summary>
/// Flattens all tunable parameters of the fraction (each term's beta followed by the
/// coefficients of its active variables) into one vector. Must stay the exact mirror
/// of SetCoeff.
/// </summary>
private static double[] ExtractCoeff(ContinuedFraction ch) {
  var flat = new List<double>();
  foreach (var term in ch.h) {
    flat.Add(term.beta);
    for (int i = 0; i < term.vars.Length; i++) {
      if (term.vars[i]) flat.Add(term.coef[i]);
    }
  }
  return flat.ToArray();
}
538 |
|
---|
/// <summary>
/// Writes a flat parameter vector back into the fraction, consuming values in the
/// same order ExtractCoeff produced them (beta, then active coefficients, per term).
/// </summary>
private static void SetCoeff(ContinuedFraction ch, double[] curCoeff) {
  int next = 0;
  foreach (var term in ch.h) {
    term.beta = curCoeff[next++];
    for (int i = 0; i < term.vars.Length; i++) {
      if (term.vars[i]) term.coef[i] = curCoeff[next++];
    }
  }
}
548 |
|
---|
#region build a symbolic expression tree
// Shared symbol instances used to construct HeuristicLab symbolic expression trees
// when exporting a continued fraction as a regression solution.
Symbol addSy = new Addition();
Symbol divSy = new Division();
Symbol startSy = new StartSymbol();
Symbol progSy = new ProgramRootSymbol();
Symbol constSy = new Constant();
Symbol varSy = new Problems.DataAnalysis.Symbolic.Variable();
556 |
|
---|
/// <summary>
/// Converts a continued fraction into a HeuristicLab symbolic regression solution:
/// nested divisions of linear combinations, mirroring the pairing used by Evaluate().
/// </summary>
private ISymbolicRegressionSolution CreateSymbolicRegressionSolution(ContinuedFraction cfrac, double[,] trainingData, string[] variables, string targetVariable) {
  // Build the nested fraction bottom-up; h[i] supplies the denominator and h[i-1]
  // the numerator of each level, exactly as in Evaluate(cfrac, dataPoint).
  ISymbolicExpressionTreeNode res = null;
  for (int i = cfrac.h.Length - 1; i > 1; i -= 2) {
    var hi = cfrac.h[i];
    var hi1 = cfrac.h[i - 1];
    var denom = CreateLinearCombination(hi.vars, hi.coef, variables, hi.beta);
    if (res != null) {
      denom.AddSubtree(res);
    }

    var numerator = CreateLinearCombination(hi1.vars, hi1.coef, variables, hi1.beta);

    res = divSy.CreateTreeNode();
    res.AddSubtree(numerator);
    res.AddSubtree(denom);
  }

  var h0 = cfrac.h[0];
  var h0Term = CreateLinearCombination(h0.vars, h0.coef, variables, h0.beta);
  // BUGFIX: when the fraction has fewer than three terms the loop above never runs
  // and res stays null; the original code called AddSubtree(null) unconditionally.
  if (res != null) {
    h0Term.AddSubtree(res);
  }

  var progRoot = progSy.CreateTreeNode();
  var start = startSy.CreateTreeNode();
  progRoot.AddSubtree(start);
  start.AddSubtree(h0Term);

  var model = new SymbolicRegressionModel(targetVariable, new SymbolicExpressionTree(progRoot), new SymbolicDataAnalysisExpressionTreeBatchInterpreter());
  var ds = new Dataset(variables.Concat(new[] { targetVariable }), trainingData);
  var problemData = new RegressionProblemData(ds, variables, targetVariable);
  var sol = new SymbolicRegressionSolution(model, problemData);
  return sol;
}
589 |
|
---|
/// <summary>
/// Builds an addition node representing beta plus one weighted variable node per
/// active variable.
/// </summary>
private ISymbolicExpressionTreeNode CreateLinearCombination(bool[] vars, double[] coef, string[] variables, double beta) {
  var sum = addSy.CreateTreeNode();
  for (int i = 0; i < vars.Length; i++) {
    if (!vars[i]) continue; // skip inactive variables
    var varNode = (VariableTreeNode)varSy.CreateTreeNode();
    varNode.Weight = coef[i];
    varNode.VariableName = variables[i];
    sum.AddSubtree(varNode);
  }
  sum.AddSubtree(CreateConstant(beta));
  return sum;
}
603 |
|
---|
/// <summary>Wraps a raw double in a constant tree node.</summary>
private ISymbolicExpressionTreeNode CreateConstant(double value) {
  var node = (ConstantTreeNode)constSy.CreateTreeNode();
  node.Value = value;
  return node;
}
609 | }
|
---|
610 | #endregion
|
---|
611 | }
|
---|