Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2929_PrioritizedGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.PGE/3.3/PGE.cs @ 16614

Last change on this file since 16614 was 16614, checked in by hmaislin, 5 years ago

#2929: Fixed result bug, added multitest C / Powershell, passed pagie_1

File size: 21.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Runtime.InteropServices;
5using System.Threading;
6using HeuristicLab.Analysis;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using System.Text.RegularExpressions;
11using HeuristicLab.Optimization;
12using HeuristicLab.Parameters;
13using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
14using HeuristicLab.Problems.DataAnalysis;
15using HeuristicLab.Problems.DataAnalysis.Symbolic;
16using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
17
18namespace PGE {
19  [Item(Name = "Priorizied Grammar Enumeration (PGE)", Description = "Priorizied grammar enumeration algorithm. Worm, T. and Chiu K., 'Prioritized Grammar Enumeration: Symbolic Regression by Dynamic Programming'. GECCO 2013")]
20
21  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
22
23  [StorableClass]
24  public unsafe class PGE : BasicAlgorithm {
25
26    [DllImport("go-pge.dll", EntryPoint = "addTestData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
27    public static extern void AddTestData(IntPtr indepNames, IntPtr depndNames, IntPtr matrix, int nEntries);
28
29    [DllImport("go-pge.dll", EntryPoint = "addTrainData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
30    public static extern void AddTrainData(IntPtr indepNames, IntPtr depndNames, IntPtr matrix, int nEntries);
31
32    [DllImport("go-pge.dll", EntryPoint = "initSearch", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
33    public static extern void InitSearch(int maxGen, int pgeRptEpoch, int pgeRptCount, int pgeArchiveCap, int peelCnt, int evalrCount, double zeroEpsilon, IntPtr initMethod, IntPtr growMethod, int sortType);
34
35    [DllImport("go-pge.dll", EntryPoint = "initTreeParams", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
36    public static extern void InitTreeParams(IntPtr roots, IntPtr nodes, IntPtr nonTrig, IntPtr leafs, int numUsableVars, int maxSize, int minSize, int maxDepth, int minDepth);
37
38    [DllImport("go-pge.dll", EntryPoint = "initProblem", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
39    public static extern void InitProblem(IntPtr name, int maxIter, double hitRatio, int searchVar, IntPtr ProblemTypeString, int numProcs);
40
41    [DllImport("go-pge.dll", EntryPoint = "stepW", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
42    public static extern int StepW();
43
44    [DllImport("go-pge.dll", EntryPoint = "getStepResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
45    public static extern IntPtr GetStepResult(out int testscore, out int nCoeff);
46
47    [DllImport("go-pge.dll", EntryPoint = "getCoeffResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
48    public static extern double GetCoeffResult();
49
50    public override Type ProblemType { get { return typeof(RegressionProblem); } }
51    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }
52
53    #region parameter names
54    private static readonly string MaxIterationsParameterName = "MaxIterations";
55    private static readonly string MaxGenParameterName = "MaxGen";
56    private static readonly string EvalrCountParameterName = "EvalrCount";
57    private static readonly string MaxSizeParameterName = "MaxSize";
58    private static readonly string MinSizeParameterName = "MinSize";
59    private static readonly string MaxDepthParameterName = "MaxDepth";
60    private static readonly string MinDepthParameterName = "MinDepth";
61    private static readonly string PgeRptEpochParameterName = "PgeRptEpoch";
62    private static readonly string PgeRptCountParameterName = "PgeRptCount";
63    private static readonly string PgeArchiveCapParameterName = "PgeArchiveCap";
64    private static readonly string PeelCntParameterName = "PeelCnt";
65    private static readonly string ZeroEpsilonParameterName = "ZeroEpsilon";
66    private static readonly string HitRatioParameterName = "HitRatio";
67    private static readonly string InitMethodParameterName = "InitMethod";
68    private static readonly string GrowMethodParameterName = "GrowMethod";
69    private static readonly string RootsParameterName = "Roots";
70    private static readonly string NodesParameterName = "Nodes";
71    private static readonly string NonTrigParameterName = "NonTrig";
72    private static readonly string LeafsParameterName = "Leafs";
73
74    #endregion
75
76    #region parameters                                           
77    private IFixedValueParameter<IntValue> MaxIterationsParameter {
78      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
79    }
80    public int MaxIterations {
81      get { return MaxIterationsParameter.Value.Value; }
82      set { MaxIterationsParameter.Value.Value = value; }
83    }
84
85    private IFixedValueParameter<IntValue> MaxGenParameter {
86      get { return (IFixedValueParameter<IntValue>)Parameters[MaxGenParameterName]; }
87    }
88    public int MaxGen {
89      get { return MaxGenParameter.Value.Value; }
90      set { MaxGenParameter.Value.Value = value; }
91    }
92
93    private IFixedValueParameter<IntValue> EvalrCountParameter {
94      get { return (IFixedValueParameter<IntValue>)Parameters[EvalrCountParameterName]; }
95    }
96    public int EvalrCount {
97      get { return EvalrCountParameter.Value.Value; }
98      set { EvalrCountParameter.Value.Value = value; }
99    }
100
101    private IFixedValueParameter<IntValue> MaxSizeParameter {
102      get { return (IFixedValueParameter<IntValue>)Parameters[MaxSizeParameterName]; }
103    }
104    public int MaxSize {
105      get { return MaxSizeParameter.Value.Value; }
106      set { MaxSizeParameter.Value.Value = value; }
107    }
108
109    private IFixedValueParameter<IntValue> MinSizeParameter {
110      get { return (IFixedValueParameter<IntValue>)Parameters[MinSizeParameterName]; }
111    }
112    public int MinSize {
113      get { return MinSizeParameter.Value.Value; }
114      set { MinSizeParameter.Value.Value = value; }
115    }
116
117    private IFixedValueParameter<IntValue> MaxDepthParameter {
118      get { return (IFixedValueParameter<IntValue>)Parameters[MaxDepthParameterName]; }
119    }
120    public int MaxDepth {
121      get { return MaxDepthParameter.Value.Value; }
122      set { MaxDepthParameter.Value.Value = value; }
123    }
124
125    private IFixedValueParameter<IntValue> MinDepthParameter {
126      get { return (IFixedValueParameter<IntValue>)Parameters[MinDepthParameterName]; }
127    }
128    public int MinDepth {
129      get { return MinDepthParameter.Value.Value; }
130      set { MinDepthParameter.Value.Value = value; }
131    }
132
133    private IFixedValueParameter<IntValue> PgeRptEpochParameter {
134      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptEpochParameterName]; }
135    }
136    public int PgeRptEpoch {
137      get { return PgeRptEpochParameter.Value.Value; }
138      set { PgeRptEpochParameter.Value.Value = value; }
139    }
140
141    private IFixedValueParameter<IntValue> PgeRptCountParameter {
142      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptCountParameterName]; }
143    }
144    public int PgeRptCount {
145      get { return PgeRptCountParameter.Value.Value; }
146      set { PgeRptCountParameter.Value.Value = value; }
147    }
148
149    private IFixedValueParameter<IntValue> PgeArchiveCapParameter {
150      get { return (IFixedValueParameter<IntValue>)Parameters[PgeArchiveCapParameterName]; }
151    }
152    public int PgeArchiveCap {
153      get { return PgeArchiveCapParameter.Value.Value; }
154      set { PgeArchiveCapParameter.Value.Value = value; }
155    }
156
157    private IFixedValueParameter<IntValue> PeelCntParameter {
158      get { return (IFixedValueParameter<IntValue>)Parameters[PeelCntParameterName]; }
159    }
160    public int PeelCnt {
161      get { return PeelCntParameter.Value.Value; }
162      set { PeelCntParameter.Value.Value = value; }
163    }
164
165    private IFixedValueParameter<DoubleValue> ZeroEpsilonParameter {
166      get { return (IFixedValueParameter<DoubleValue>)Parameters[ZeroEpsilonParameterName]; }
167    }
168    public double ZeroEpsilon {
169      get { return ZeroEpsilonParameter.Value.Value; }
170      set { ZeroEpsilonParameter.Value.Value = value; }
171    }
172
173    private IFixedValueParameter<DoubleValue> HitRatioParameter {
174      get { return (IFixedValueParameter<DoubleValue>)Parameters[HitRatioParameterName]; }
175    }
176    public double HitRatio {
177      get { return HitRatioParameter.Value.Value; }
178      set { HitRatioParameter.Value.Value = value; }
179    }
180
181    private IFixedValueParameter<StringValue> InitMethodParameter {
182      get { return (IFixedValueParameter<StringValue>)Parameters[InitMethodParameterName]; }
183    }
184    public string InitMethod {
185      get { return InitMethodParameter.Value.Value; }
186      set { InitMethodParameter.Value.Value = value; }
187    }
188
189    private IFixedValueParameter<StringValue> GrowMethodParameter {
190      get { return (IFixedValueParameter<StringValue>)Parameters[GrowMethodParameterName]; }
191    }
192    public string GrowMethod {
193      get { return GrowMethodParameter.Value.Value; }
194      set { GrowMethodParameter.Value.Value = value; }
195    }
196
197    private IFixedValueParameter<StringValue> RootsParameter {
198      get { return (IFixedValueParameter<StringValue>)Parameters[RootsParameterName]; }
199    }
200    public string Roots {
201      get { return RootsParameter.Value.Value; }
202      set { RootsParameter.Value.Value = value; }
203    }
204
205    private IFixedValueParameter<StringValue> NodesParameter {
206      get { return (IFixedValueParameter<StringValue>)Parameters[NodesParameterName]; }
207    }
208    public string Nodes {
209      get { return NodesParameter.Value.Value; }
210      set { NodesParameter.Value.Value = value; }
211    }
212
213    private IFixedValueParameter<StringValue> NonTrigParameter {
214      get { return (IFixedValueParameter<StringValue>)Parameters[NonTrigParameterName]; }
215    }
216    public string NonTrig {
217      get { return NonTrigParameter.Value.Value; }
218      set { NonTrigParameter.Value.Value = value; }
219    }
220
221    private IFixedValueParameter<StringValue> LeafsParameter {
222      get { return (IFixedValueParameter<StringValue>)Parameters[LeafsParameterName]; }
223    }
224    public string Leafs {
225      get { return LeafsParameter.Value.Value; }
226      set { LeafsParameter.Value.Value = value; }
227    }
228    #endregion
229
230    public PGE() {
231
232      base.Problem = new RegressionProblem();
233
234      // algorithm parameters are shown in the GUI
235      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, new IntValue(50)));
236      Parameters.Add(new FixedValueParameter<IntValue>(MinDepthParameterName, new IntValue(1)));
237      Parameters.Add(new FixedValueParameter<IntValue>(MaxDepthParameterName, new IntValue(6)));
238      Parameters.Add(new FixedValueParameter<IntValue>(MinSizeParameterName, new IntValue(4)));
239      Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, new IntValue(50)));
240      Parameters.Add(new FixedValueParameter<IntValue>(EvalrCountParameterName, new IntValue(2)));
241      Parameters.Add(new FixedValueParameter<IntValue>(PeelCntParameterName, new IntValue(3)));
242      Parameters.Add(new FixedValueParameter<IntValue>(PgeArchiveCapParameterName, new IntValue(256)));
243      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptCountParameterName, new IntValue(20)));
244      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptEpochParameterName, new IntValue(1)));
245      Parameters.Add(new FixedValueParameter<IntValue>(MaxGenParameterName, new IntValue(200)));
246
247      Parameters.Add(new FixedValueParameter<StringValue>(InitMethodParameterName, new StringValue("method1")));  // TODO Dropdown
248      Parameters.Add(new FixedValueParameter<StringValue>(GrowMethodParameterName, new StringValue("method1")));
249
250      Parameters.Add(new FixedValueParameter<StringValue>(RootsParameterName, new StringValue("Add")));    // TODO: checkeditemlist
251      Parameters.Add(new FixedValueParameter<StringValue>(NodesParameterName, new StringValue("Add Mul")));  // TODO: checkeditemlist
252      Parameters.Add(new FixedValueParameter<StringValue>(NonTrigParameterName, new StringValue("Add Mul"))); // TODO: checkeditemlist
253      Parameters.Add(new FixedValueParameter<StringValue>(LeafsParameterName, new StringValue("Var ConstantF")));
254
255      Parameters.Add(new FixedValueParameter<DoubleValue>(ZeroEpsilonParameterName, new DoubleValue(0.00001)));
256      Parameters.Add(new FixedValueParameter<DoubleValue>(HitRatioParameterName, new DoubleValue(0.01)));
257    }
258
259
260    [StorableConstructor]
261    public PGE(bool deserializing) : base(deserializing) { }
262
263
264    public PGE(PGE original, Cloner cloner) : base(original, cloner) {
265      // nothing to clone
266    }
267
268    public override IDeepCloneable Clone(Cloner cloner) {
269      return new PGE(this, cloner);
270    }
271
272    protected override void Run(CancellationToken cancellationToken) {
273      Log log = new Log();
274      Results.Add(new Result("Log", log));
275      var iterationsResult = new IntValue(0);
276      Results.Add(new Result("Iteration", iterationsResult));
277      var bestTestScoreResult = new IntValue(0); // TODO: why is test score an int?
278      Results.Add(new Result("Best test score", bestTestScoreResult));
279      var testScoresTable = new DataTable("Test scores");
280      var bestTestScoreRow = new DataRow("Best test score");
281      var curTestScoreRow = new DataRow("Current test score");
282      testScoresTable.Rows.Add(bestTestScoreRow);
283      testScoresTable.Rows.Add(curTestScoreRow);
284      Results.Add(new Result("Test scores", testScoresTable));
285      var lengthsTable = new DataTable("Lengths");
286      var len1Row = new DataRow("Length 1");
287      var len2Row = new DataRow("Length 2");
288      lengthsTable.Rows.Add(len1Row);
289      lengthsTable.Rows.Add(len2Row);
290      Results.Add(new Result("Lengths", lengthsTable));
291
292
293      var bestSolutionResult = new Result("Best solution", typeof(IRegressionSolution));
294      Results.Add(bestSolutionResult);
295
296      // TODO: the following is potentially problematic for other go processes run on the same machine at the same time
297      // shouldn't be problematic bc is inherited only, normally only child processes are affected
298      Environment.SetEnvironmentVariable("GOGC", "off");
299      Environment.SetEnvironmentVariable("GODEBUG", "cgocheck=0");
300      Environment.SetEnvironmentVariable("CGO_ENABLED", "1");
301      Environment.SetEnvironmentVariable("PGEDEBUG", "0");
302
303
304      //Constants
305      int sortType = 0; // TODO what's sort type?
306      string problemTypeString = "benchmark";
307      int numProc = 12;
308      string problemName = Problem.ProblemData.Name;
309
310
311      var problemData = Problem.ProblemData;
312      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
313      // no idea why the following are IntPtr, this should not be necessary for marshalling, it should be ok to just send the double[,]
314      int nTrainData;
315      int nTestData;
316      IntPtr trainData = GetData(problemData.Dataset, variables, problemData.TrainingIndices, out nTrainData);
317      IntPtr testData = GetData(problemData.Dataset, variables, problemData.TestIndices, out nTestData);
318
319      nTrainData = Problem.ProblemData.TrainingPartition.Size;
320      nTestData = Problem.ProblemData.TestPartition.Size;
321
322      if (problemData.AllowedInputVariables.Any(iv => iv.Contains(" ")))
323        throw new NotSupportedException("PGE does not support variable names which contain spaces");
324
325      var inputVariableNames = string.Join(" ", problemData.AllowedInputVariables);
326
327      IntPtr cIndepNames = Marshal.StringToHGlobalAnsi(inputVariableNames);
328      IntPtr cDependentNames = Marshal.StringToHGlobalAnsi(problemData.TargetVariable);
329      // Dependent- and Independentnames are the variables from the test/train data, e.g. from "Korns_02.trn" indep: x y z v w  dep: f(xs)
330
331      IntPtr cInitMethod = Marshal.StringToHGlobalAnsi(InitMethod);
332      IntPtr cGrowMethod = Marshal.StringToHGlobalAnsi(GrowMethod);
333
334      IntPtr cRoots = Marshal.StringToHGlobalAnsi(Roots);
335      IntPtr cNodes = Marshal.StringToHGlobalAnsi(Nodes);
336      IntPtr cNonTrig = Marshal.StringToHGlobalAnsi(NonTrig);
337      IntPtr cLeafs = Marshal.StringToHGlobalAnsi(Leafs);
338
339      IntPtr cName = Marshal.StringToHGlobalAnsi(problemName);
340      IntPtr cProblemTypeString = Marshal.StringToHGlobalAnsi(problemTypeString);
341
342
343      AddTestData(cIndepNames, cDependentNames, testData, nTestData);
344      AddTrainData(cIndepNames, cDependentNames, trainData, nTrainData);
345
346      int numberOfUseableVariables = problemData.AllowedInputVariables.Count();
347
348      InitSearch(MaxGen, PgeRptEpoch, PgeRptCount, PgeArchiveCap, PeelCnt, EvalrCount, ZeroEpsilon, cInitMethod, cGrowMethod, sortType);
349
350      // cUsableVars: list of indices into independent variables
351      InitTreeParams(cRoots, cNodes, cNonTrig, cLeafs, numberOfUseableVariables, MaxSize, MinSize, MaxDepth, MinDepth);
352
353      InitProblem(cName, MaxIterations, HitRatio,
354        searchVar: 0,  // SearchVar: index of dependent variables (this is always zero because we only have one target variable)
355        ProblemTypeString: cProblemTypeString, numProcs: numProc);
356
357      var bestTestScore = int.MaxValue;
358      for (int iter = 1; iter <= MaxIterations; iter++) {
359        iterationsResult.Value = iter;
360
361        int nResults = StepW();
362
363        for (int iResult = 0; iResult < nResults; iResult++) {
364          int nCoeff = 0;
365          int testScore = 0;
366
367          IntPtr eqn = GetStepResult(out testScore, out nCoeff);
368          string eqnStr = Marshal.PtrToStringAnsi(eqn);
369
370          double[] coeff = new double[nCoeff];
371          for (int iCoeff = 0; iCoeff < nCoeff; iCoeff++) {
372            coeff[iCoeff] = GetCoeffResult();
373          }
374          log.LogMessage("Push/Pop (" + iResult + ", " + testScore + ") " + eqnStr + " coeff: " + string.Join(" ", coeff));
375
376          if (!string.IsNullOrEmpty(eqnStr) && (testScore < bestTestScore)) {
377            // update best quality
378            bestTestScore = testScore;
379            bestTestScoreResult.Value = testScore;
380            var sol = CreateSolution(problemData, eqnStr, coeff, problemData.AllowedInputVariables.ToArray());
381            bestSolutionResult.Value = sol;
382          }
383          bestTestScoreRow.Values.Add(bestTestScoreResult.Value); // always add the current best test score to data row
384          curTestScoreRow.Values.Add(testScore);
385        }
386
387        if (cancellationToken.IsCancellationRequested) break;
388      }
389
390      Marshal.FreeHGlobal(trainData);
391      Marshal.FreeHGlobal(testData);
392
393      Marshal.FreeHGlobal(cIndepNames);
394      Marshal.FreeHGlobal(cDependentNames);
395
396      Marshal.FreeHGlobal(cInitMethod);
397      Marshal.FreeHGlobal(cGrowMethod);
398
399      Marshal.FreeHGlobal(cRoots);
400      Marshal.FreeHGlobal(cNodes);
401      Marshal.FreeHGlobal(cNonTrig);
402      Marshal.FreeHGlobal(cLeafs);
403
404      Marshal.FreeHGlobal(cName);
405      Marshal.FreeHGlobal(cProblemTypeString);
406
407      // Results.Add(new Result("Execution time", new TimeSpanValue(this.ExecutionTime)));
408    }
409
410    private static readonly Regex varRegex = new Regex(@"X_(\d)+");
411    private static readonly Regex coeffRegex = new Regex(@"C_(\d)+");
412
413    private IRegressionSolution CreateSolution(IRegressionProblemData problemData, string eqnStr, double[] coeff, string[] usableVariables) {
414      // coefficients are named e.g. "C_0" in the PGE expressions
415      // -> replace all patterns "C_\d" by the corresponding coefficients
416      var match = coeffRegex.Match(eqnStr);
417      while (match.Success) {
418        var coeffIdx = int.Parse(match.Groups[1].ToString());
419        eqnStr = eqnStr.Substring(0, match.Index) +
420          "(" + coeff[coeffIdx].ToString(System.Globalization.CultureInfo.InvariantCulture) + ")" +
421          eqnStr.Substring(match.Index + match.Length);
422        match = coeffRegex.Match(eqnStr);
423      }
424
425      // variables are named e.g. "X_0" in the PGE expressions
426      // -> replace all patterns "X_\d" by the corresponding variable name
427      match = varRegex.Match(eqnStr);
428      while (match.Success) {
429        var varIdx = int.Parse(match.Groups[1].ToString());
430        eqnStr = eqnStr.Substring(0, match.Index) +
431          "'" + usableVariables[varIdx] + "'" +
432          eqnStr.Substring(match.Index + match.Length);
433        match = varRegex.Match(eqnStr);
434      }
435
436      var parser = new InfixExpressionParser();
437      var tree = parser.Parse(eqnStr);
438      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
439      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
440    }
441
442    public override bool SupportsPause {
443      get { return false; }
444    }
445
446    private static IntPtr GetData(IDataset ds, IEnumerable<string> variableNames, IEnumerable<int> rows, out int n) {
447
448      var dim = variableNames.Count();
449      double[] val = new double[rows.Count() * dim];
450      int r = 0;
451      foreach (var row in rows) {
452        int c = 0;
453        foreach (var var in variableNames) {
454          val[r * dim + c] = ds.GetDoubleValue(var, r);
455          c++;
456        }
457        r++;
458      }
459
460      n = val.Length;
461
462      // TODO: seems strange to marshal this explicitly, we can just send the data over to PGE
463      IntPtr data = Marshal.AllocHGlobal(sizeof(double) * val.Length);
464      Marshal.Copy(val, 0, data, val.Length);
465
466      return data;
467    }
468  }
469}
Note: See TracBrowser for help on using the repository browser.