using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Threading;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace PGE {
  /// <summary>
  /// HeuristicLab front end for the native (Go) implementation of Prioritized Grammar Enumeration in go-pge.dll.
  /// The algorithm uploads the training/test data, configures the native search from its parameters, advances
  /// the search iteration by iteration and converts every equation reported by the native side into an
  /// <see cref="IRegressionSolution"/>.
  /// </summary>
  [Item(Name = "Prioritized Grammar Enumeration (PGE)", Description = "Prioritized grammar enumeration algorithm. Worm, T. and Chiu K., 'Prioritized Grammar Enumeration: Symbolic Regression by Dynamic Programming'. GECCO 2013")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableClass]
  public class PGE : BasicAlgorithm {

    #region native interface (go-pge.dll)
    // Uploads the test data: space-separated input names, the target name, a row-major matrix
    // (inputs followed by the target, see GetData) and the number of rows in that matrix.
    [DllImport("go-pge.dll", EntryPoint = "addTestData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void AddTestData([MarshalAs(UnmanagedType.AnsiBStr)] string indepNames, [MarshalAs(UnmanagedType.AnsiBStr)] string depndNames, double[] matrix, int nEntries);

    // Uploads the training data; same layout as AddTestData.
    [DllImport("go-pge.dll", EntryPoint = "addTrainData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void AddTrainData([MarshalAs(UnmanagedType.AnsiBStr)] string indepNames, [MarshalAs(UnmanagedType.AnsiBStr)] string depndNames, double[] matrix, int nEntries);

    [DllImport("go-pge.dll", EntryPoint = "initSearch", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitSearch(int maxGen, int pgeRptEpoch, int pgeRptCount, int pgeArchiveCap, int peelCnt, int evalrCount, double zeroEpsilon, [MarshalAs(UnmanagedType.AnsiBStr)] string initMethod, [MarshalAs(UnmanagedType.AnsiBStr)] string growMethod, int sortType);

    [DllImport("go-pge.dll", EntryPoint = "initTreeParams", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitTreeParams([MarshalAs(UnmanagedType.AnsiBStr)] string roots, [MarshalAs(UnmanagedType.AnsiBStr)] string nodes, [MarshalAs(UnmanagedType.AnsiBStr)] string nonTrig, [MarshalAs(UnmanagedType.AnsiBStr)] string leafs, int numUsableVars, int maxSize, int minSize, int maxDepth, int minDepth);

    [DllImport("go-pge.dll", EntryPoint = "initProblem", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitProblem([MarshalAs(UnmanagedType.AnsiBStr)] string name, int maxIter, double hitRatio, int searchVar, [MarshalAs(UnmanagedType.AnsiBStr)] string ProblemTypeString, int numProcs);

    // Performs one search iteration and returns how many results can then be fetched with GetStepResult.
    [DllImport("go-pge.dll", EntryPoint = "stepW", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern int StepW();

    // Returns a pointer to the next result equation as an ANSI string and reports its score and coefficient count.
    // NOTE(review): ownership of the returned native string is not visible here; it is never freed on the .NET side.
    [DllImport("go-pge.dll", EntryPoint = "getStepResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern IntPtr GetStepResult(out int testscore, out int nCoeff);

    // Returns the next coefficient of the current result; called nCoeff times after GetStepResult.
    [DllImport("go-pge.dll", EntryPoint = "getCoeffResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern double GetCoeffResult();
    #endregion

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    #region parameter names
    private const string MaxIterationsParameterName = "MaxIterations";
    private const string MaxGenParameterName = "MaxGen";
    private const string EvalrCountParameterName = "EvalrCount";
    private const string MaxSizeParameterName = "MaxSize";
    private const string MinSizeParameterName = "MinSize";
    private const string MaxDepthParameterName = "MaxDepth";
    private const string MinDepthParameterName = "MinDepth";
    private const string PgeRptEpochParameterName = "PgeRptEpoch";
    private const string PgeRptCountParameterName = "PgeRptCount";
    private const string PgeArchiveCapParameterName = "PgeArchiveCap";
    private const string PeelCntParameterName = "PeelCnt";
    private const string ZeroEpsilonParameterName = "ZeroEpsilon";
    private const string HitRatioParameterName = "HitRatio";
    private const string InitMethodParameterName = "InitMethod";
    private const string GrowMethodParameterName = "GrowMethod";
    private const string RootsParameterName = "Roots";
    private const string NodesParameterName = "Nodes";
    private const string NonTrigParameterName = "NonTrig";
    private const string LeafsParameterName = "Leafs";
    #endregion

    #region parameters
    private IFixedValueParameter<IntValue> MaxIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    }
    public int MaxIterations {
      get { return MaxIterationsParameter.Value.Value; }
      set { MaxIterationsParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxGenParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxGenParameterName]; }
    }
    public int MaxGen {
      get { return MaxGenParameter.Value.Value; }
      set { MaxGenParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> EvalrCountParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvalrCountParameterName]; }
    }
    public int EvalrCount {
      get { return EvalrCountParameter.Value.Value; }
      set { EvalrCountParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxSizeParameterName]; }
    }
    public int MaxSize {
      get { return MaxSizeParameter.Value.Value; }
      set { MaxSizeParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MinSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinSizeParameterName]; }
    }
    public int MinSize {
      get { return MinSizeParameter.Value.Value; }
      set { MinSizeParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxDepthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxDepthParameterName]; }
    }
    public int MaxDepth {
      get { return MaxDepthParameter.Value.Value; }
      set { MaxDepthParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MinDepthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinDepthParameterName]; }
    }
    public int MinDepth {
      get { return MinDepthParameter.Value.Value; }
      set { MinDepthParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeRptEpochParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptEpochParameterName]; }
    }
    public int PgeRptEpoch {
      get { return PgeRptEpochParameter.Value.Value; }
      set { PgeRptEpochParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeRptCountParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptCountParameterName]; }
    }
    public int PgeRptCount {
      get { return PgeRptCountParameter.Value.Value; }
      set { PgeRptCountParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeArchiveCapParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeArchiveCapParameterName]; }
    }
    public int PgeArchiveCap {
      get { return PgeArchiveCapParameter.Value.Value; }
      set { PgeArchiveCapParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PeelCntParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PeelCntParameterName]; }
    }
    public int PeelCnt {
      get { return PeelCntParameter.Value.Value; }
      set { PeelCntParameter.Value.Value = value; }
    }

    private IFixedValueParameter<DoubleValue> ZeroEpsilonParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ZeroEpsilonParameterName]; }
    }
    public double ZeroEpsilon {
      get { return ZeroEpsilonParameter.Value.Value; }
      set { ZeroEpsilonParameter.Value.Value = value; }
    }

    private IFixedValueParameter<DoubleValue> HitRatioParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[HitRatioParameterName]; }
    }
    public double HitRatio {
      get { return HitRatioParameter.Value.Value; }
      set { HitRatioParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> InitMethodParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[InitMethodParameterName]; }
    }
    public string InitMethod {
      get { return InitMethodParameter.Value.Value; }
      set { InitMethodParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> GrowMethodParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[GrowMethodParameterName]; }
    }
    public string GrowMethod {
      get { return GrowMethodParameter.Value.Value; }
      set { GrowMethodParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> RootsParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[RootsParameterName]; }
    }
    public string Roots {
      get { return RootsParameter.Value.Value; }
      set { RootsParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> NodesParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[NodesParameterName]; }
    }
    public string Nodes {
      get { return NodesParameter.Value.Value; }
      set { NodesParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> NonTrigParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[NonTrigParameterName]; }
    }
    public string NonTrig {
      get { return NonTrigParameter.Value.Value; }
      set { NonTrigParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> LeafsParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[LeafsParameterName]; }
    }
    public string Leafs {
      get { return LeafsParameter.Value.Value; }
      set { LeafsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>
    /// Creates the algorithm with an empty regression problem and the default PGE parameters.
    /// </summary>
    public PGE() {
      base.Problem = new RegressionProblem();

      // algorithm parameters are shown in the GUI
      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(MinDepthParameterName, new IntValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxDepthParameterName, new IntValue(6)));
      Parameters.Add(new FixedValueParameter<IntValue>(MinSizeParameterName, new IntValue(4)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvalrCountParameterName, new IntValue(2)));
      Parameters.Add(new FixedValueParameter<IntValue>(PeelCntParameterName, new IntValue(3)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeArchiveCapParameterName, new IntValue(256)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptCountParameterName, new IntValue(20)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptEpochParameterName, new IntValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxGenParameterName, new IntValue(200)));

      Parameters.Add(new FixedValueParameter<StringValue>(InitMethodParameterName, new StringValue("method1"))); // TODO Dropdown
      Parameters.Add(new FixedValueParameter<StringValue>(GrowMethodParameterName, new StringValue("method1")));

      // Space-separated lists of symbol names handed to the native library as-is.
      Parameters.Add(new FixedValueParameter<StringValue>(RootsParameterName, new StringValue("Add"))); // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(NodesParameterName, new StringValue("Add Mul"))); // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(NonTrigParameterName, new StringValue("Add Mul"))); // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(LeafsParameterName, new StringValue("Var ConstantF")));

      Parameters.Add(new FixedValueParameter<DoubleValue>(ZeroEpsilonParameterName, new DoubleValue(0.00001)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(HitRatioParameterName, new DoubleValue(0.01)));
    }

    [StorableConstructor]
    public PGE(bool deserializing) : base(deserializing) { }

    public PGE(PGE original, Cloner cloner) : base(original, cloner) {
      // nothing to clone: all state lives in Parameters/Results, which the base class clones
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new PGE(this, cloner);
    }

    /// <summary>
    /// Runs the native PGE search for up to <see cref="MaxIterations"/> iterations and publishes the results.
    /// </summary>
    /// <exception cref="NotSupportedException">An allowed input variable name contains a space.</exception>
    /// <exception cref="FormatException">A native result references a variable or coefficient index that does not exist.</exception>
    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;
      // Snapshot the inputs once so names, count and index mapping (X_<i> -> name) stay consistent for the whole run.
      var usableVariables = problemData.AllowedInputVariables.ToArray();

      // Input names are passed to the native library as one space-separated string, so names containing
      // spaces cannot be represented. Validate before any results or process state are modified.
      // NOTE(review): the target name is passed the same way (depndNames); it is not validated here.
      if (usableVariables.Any(iv => iv.Contains(" ")))
        throw new NotSupportedException("PGE does not support variable names which contain spaces");

      #region results
      Log log = new Log();
      Results.Add(new Result("Log", log));
      var iterationsResult = new IntValue(0);
      Results.Add(new Result("Iteration", iterationsResult));
      var bestMSEResult = new DoubleValue();
      Results.Add(new Result("Best MSE", bestMSEResult));
      var testScoresTable = new DataTable("Test scores");
      var bestTestScoreRow = new DataRow("Best test score");
      var curTestScoreRow = new DataRow("Current test score");
      testScoresTable.Rows.Add(bestTestScoreRow);
      testScoresTable.Rows.Add(curTestScoreRow);
      Results.Add(new Result("Test scores", testScoresTable));
      // NOTE(review): the rows of this table are never filled below.
      var lengthsTable = new DataTable("Lengths");
      var len1Row = new DataRow("Length 1");
      var len2Row = new DataRow("Length 2");
      lengthsTable.Rows.Add(len1Row);
      lengthsTable.Rows.Add(len2Row);
      Results.Add(new Result("Lengths", lengthsTable));

      var bestSolutionResult = new Result("Best solution", typeof(IRegressionSolution));
      Results.Add(bestSolutionResult);

      var allSolutions = new ItemList<IRegressionSolution>();
      var allSolutionsResult = new Result("Solutions", allSolutions);
      Results.Add(allSolutionsResult);
      #endregion

      ConfigureGoRuntimeEnvironment();

      // Search settings that are not exposed as parameters.
      // sortType: 0 = PESORT_PARETO_TST_ERR, 1 = PESORT_PARETO_TRN_ERR
      const int sortType = 0;
      const string problemTypeString = "benchmark";
      const int numProc = 12;

      UploadData(problemData, usableVariables);

      InitSearch(MaxGen, PgeRptEpoch, PgeRptCount, PgeArchiveCap, PeelCnt, EvalrCount, ZeroEpsilon, InitMethod, GrowMethod, sortType);
      InitTreeParams(Roots, Nodes, NonTrig, Leafs, usableVariables.Length, MaxSize, MinSize, MaxDepth, MinDepth);
      InitProblem(Name, MaxIterations, HitRatio,
        searchVar: 0, // index of the dependent variable; always zero because there is exactly one target variable
        ProblemTypeString: problemTypeString, numProcs: numProc);

      var bestMSE = double.MaxValue;
      for (int iter = 1; iter <= MaxIterations; iter++) {
        iterationsResult.Value = iter;

        int nResults = StepW();
        for (int iResult = 0; iResult < nResults; iResult++) {
          int testScore;
          int nCoeff;
          IntPtr eqn = GetStepResult(out testScore, out nCoeff);
          string eqnStr = Marshal.PtrToStringAnsi(eqn);

          // coefficients of the current result are fetched one at a time
          double[] coeff = new double[nCoeff];
          for (int iCoeff = 0; iCoeff < nCoeff; iCoeff++) {
            coeff[iCoeff] = GetCoeffResult();
          }
          log.LogMessage("Push/Pop (" + iResult + ", " + testScore + ") " + eqnStr + " coeff: " + string.Join(" ", coeff));

          if (!string.IsNullOrEmpty(eqnStr)) {
            var sol = CreateSolution(problemData, eqnStr, coeff, usableVariables);
            allSolutions.Add(sol);
            if (sol.TrainingMeanSquaredError < bestMSE) {
              bestMSE = sol.TrainingMeanSquaredError;
              bestMSEResult.Value = bestMSE;
              bestSolutionResult.Value = sol;
            }
          }
          // NOTE(review): despite its name, "Best test score" records the best *training* MSE seen so far,
          // while "Current test score" records the score reported by the native library.
          bestTestScoreRow.Values.Add(bestMSEResult.Value);
          curTestScoreRow.Values.Add(testScore);
        }

        if (cancellationToken.IsCancellationRequested) break;
      }
    }

    /// <summary>
    /// Sets the environment variables that configure the Go runtime embedded in go-pge.dll.
    /// </summary>
    /// <remarks>
    /// Without an explicit target, SetEnvironmentVariable changes only this process's environment, which
    /// child processes started afterwards inherit; other running processes are not affected.
    /// GOGC=off disables the Go garbage collector.
    /// NOTE(review): the Go runtime presumably reads these once, when go-pge.dll is first loaded (at the first
    /// P/Invoke call), so later changes within the same process would have no effect; confirm.
    /// </remarks>
    private static void ConfigureGoRuntimeEnvironment() {
      Environment.SetEnvironmentVariable("GOGC", "off");
      Environment.SetEnvironmentVariable("GODEBUG", "cgocheck=0");
      Environment.SetEnvironmentVariable("CGO_ENABLED", "1");
      Environment.SetEnvironmentVariable("PGEDEBUG", "0");
    }

    /// <summary>
    /// Copies the test and training rows into the native library.
    /// </summary>
    /// <param name="problemData">Source of the dataset, target variable and partitions.</param>
    /// <param name="inputVariables">Input columns in the order PGE refers to them as X_0, X_1, ...</param>
    private static void UploadData(IRegressionProblemData problemData, string[] inputVariables) {
      // Column layout: all input variables followed by the target variable.
      var variables = inputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();
      var trainRows = problemData.TrainingIndices.ToArray();
      var testRows = problemData.TestIndices.ToArray();

      double[] trainData = GetData(problemData.Dataset, variables, trainRows);
      double[] testData = GetData(problemData.Dataset, variables, testRows);

      var inputVariableNames = string.Join(" ", inputVariables);

      // Report the number of rows actually copied. The index sequences are not guaranteed to contain exactly
      // TrainingPartition.Size / TestPartition.Size rows (presumably rows of overlapping partitions are excluded;
      // confirm), so the partition sizes could disagree with the buffer length.
      AddTestData(inputVariableNames, problemData.TargetVariable, testData, testRows.Length);
      AddTrainData(inputVariableNames, problemData.TargetVariable, trainData, trainRows.Length);
    }

    // PGE names variables "X_<i>" and coefficients "C_<k>". The whole index is captured as group 1;
    // a quantified single-digit group such as (\d)+ would keep only the last digit.
    private static readonly Regex varRegex = new Regex(@"X_(\d+)");
    private static readonly Regex coeffRegex = new Regex(@"C_(\d+)");

    /// <summary>
    /// Converts a PGE equation into a HeuristicLab regression solution.
    /// </summary>
    /// <param name="problemData">Problem data; a clone is attached to the created solution.</param>
    /// <param name="eqnStr">Equation returned by the native library, using X_&lt;i&gt; / C_&lt;k&gt; placeholders.</param>
    /// <param name="coeff">Coefficient values; C_&lt;k&gt; refers to coeff[k].</param>
    /// <param name="usableVariables">Input variable names; X_&lt;i&gt; refers to usableVariables[i].</param>
    /// <returns>A solution built from the parsed infix expression.</returns>
    /// <exception cref="FormatException">A placeholder index is outside the given arrays.</exception>
    private IRegressionSolution CreateSolution(IRegressionProblemData problemData, string eqnStr, double[] coeff, string[] usableVariables) {
      // Regex.Replace makes one left-to-right pass and never rescans inserted text, so a variable whose
      // name itself looks like "X_3" is not substituted again (a repeated Match loop could even spin forever).
      var infix = coeffRegex.Replace(eqnStr, m => {
        var coeffIdx = ParseIndex(m, coeff.Length, "coefficient", eqnStr);
        // parenthesized so negative values stay a single operand; "R" round-trips the double exactly
        return "(" + coeff[coeffIdx].ToString("R", CultureInfo.InvariantCulture) + ")";
      });
      infix = varRegex.Replace(infix, m => {
        var varIdx = ParseIndex(m, usableVariables.Length, "variable", eqnStr);
        return "'" + usableVariables[varIdx] + "'";
      });

      var parser = new InfixExpressionParser();
      var tree = parser.Parse(infix);
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    }

    // Parses the index captured in group 1 of a placeholder match and checks it against the number of
    // available values, so a malformed equation fails with a descriptive message.
    private static int ParseIndex(Match match, int count, string kind, string eqnStr) {
      int idx = int.Parse(match.Groups[1].Value, NumberStyles.None, CultureInfo.InvariantCulture);
      if (idx >= count)
        throw new FormatException(string.Format(CultureInfo.InvariantCulture,
          "PGE expression '{0}' references {1} {2}, but only {3} are available.", eqnStr, kind, match.Value, count));
      return idx;
    }

    public override bool SupportsPause {
      get { return false; }
    }

    /// <summary>
    /// Copies the given rows and columns of a dataset into a flat row-major array.
    /// </summary>
    /// <param name="ds">Dataset to read from.</param>
    /// <param name="variableNames">Columns, in output order.</param>
    /// <param name="rows">Dataset row indices, in output order.</param>
    /// <returns>An array of length rows * columns where element [r * columns + c] holds column c of the r-th requested row.</returns>
    private static double[] GetData(IDataset ds, IEnumerable<string> variableNames, IEnumerable<int> rows) {
      // materialize once: both arguments may be lazy sequences
      var columns = variableNames.ToArray();
      var rowIndices = rows.ToArray();
      var dim = columns.Length;

      var val = new double[rowIndices.Length * dim];
      for (int r = 0; r < rowIndices.Length; r++) {
        for (int c = 0; c < dim; c++) {
          // read the requested dataset row, not the output position r
          val[r * dim + c] = ds.GetDoubleValue(columns[c], rowIndices[r]);
        }
      }
      return val;
    }
  }
}