Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2929_PrioritizedGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.PGE/3.3/PGE.cs @ 16620

Last change on this file since 16620 was 16620, checked in by hmaislin, 5 years ago

#2929: Reorganized folder structure for make script, removed explicit marshalling, erased go-side logging

File size: 20.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Runtime.InteropServices;
5using System.Threading;
6using HeuristicLab.Analysis;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using System.Text.RegularExpressions;
11using HeuristicLab.Optimization;
12using HeuristicLab.Parameters;
13using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
14using HeuristicLab.Problems.DataAnalysis;
15using HeuristicLab.Problems.DataAnalysis.Symbolic;
16using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
17
18namespace PGE {
19  [Item(Name = "Priorizied Grammar Enumeration (PGE)", Description = "Priorizied grammar enumeration algorithm. Worm, T. and Chiu K., 'Prioritized Grammar Enumeration: Symbolic Regression by Dynamic Programming'. GECCO 2013")]
20
21  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
22
23  [StorableClass]
24  public unsafe class PGE : BasicAlgorithm {
25
26    [DllImport("go-pge.dll", EntryPoint = "addTestData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
27    public static extern void AddTestData([MarshalAs(UnmanagedType.AnsiBStr)] string indepNames, [MarshalAs(UnmanagedType.AnsiBStr)] string depndNames, double[] matrix, int nEntries);
28
29    [DllImport("go-pge.dll", EntryPoint = "addTrainData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
30    public static extern void AddTrainData([MarshalAs(UnmanagedType.AnsiBStr)] string indepNames, [MarshalAs(UnmanagedType.AnsiBStr)] string depndNames, double[] matrix, int nEntries);
31
32    [DllImport("go-pge.dll", EntryPoint = "initSearch", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
33    public static extern void InitSearch(int maxGen, int pgeRptEpoch, int pgeRptCount, int pgeArchiveCap, int peelCnt, int evalrCount, double zeroEpsilon, [MarshalAs(UnmanagedType.AnsiBStr)] string initMethod, [MarshalAs(UnmanagedType.AnsiBStr)] string growMethod, int sortType);
34
35    [DllImport("go-pge.dll", EntryPoint = "initTreeParams", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
36    public static extern void InitTreeParams([MarshalAs(UnmanagedType.AnsiBStr)] string roots, [MarshalAs(UnmanagedType.AnsiBStr)] string nodes, [MarshalAs(UnmanagedType.AnsiBStr)] string nonTrig, [MarshalAs(UnmanagedType.AnsiBStr)] string leafs, int numUsableVars, int maxSize, int minSize, int maxDepth, int minDepth);
37
38    [DllImport("go-pge.dll", EntryPoint = "initProblem", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
39    public static extern void InitProblem([MarshalAs(UnmanagedType.AnsiBStr)] string name, int maxIter, double hitRatio, int searchVar, [MarshalAs(UnmanagedType.AnsiBStr)] string ProblemTypeString, int numProcs);
40
41    [DllImport("go-pge.dll", EntryPoint = "stepW", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
42    public static extern int StepW();
43
44    [DllImport("go-pge.dll", EntryPoint = "getStepResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
45    public static extern IntPtr GetStepResult(out int testscore, out int nCoeff);
46
47    [DllImport("go-pge.dll", EntryPoint = "getCoeffResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
48    public static extern double GetCoeffResult();
49
50    public override Type ProblemType { get { return typeof(RegressionProblem); } }
51    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }
52
53    #region parameter names
54    private static readonly string MaxIterationsParameterName = "MaxIterations";
55    private static readonly string MaxGenParameterName = "MaxGen";
56    private static readonly string EvalrCountParameterName = "EvalrCount";
57    private static readonly string MaxSizeParameterName = "MaxSize";
58    private static readonly string MinSizeParameterName = "MinSize";
59    private static readonly string MaxDepthParameterName = "MaxDepth";
60    private static readonly string MinDepthParameterName = "MinDepth";
61    private static readonly string PgeRptEpochParameterName = "PgeRptEpoch";
62    private static readonly string PgeRptCountParameterName = "PgeRptCount";
63    private static readonly string PgeArchiveCapParameterName = "PgeArchiveCap";
64    private static readonly string PeelCntParameterName = "PeelCnt";
65    private static readonly string ZeroEpsilonParameterName = "ZeroEpsilon";
66    private static readonly string HitRatioParameterName = "HitRatio";
67    private static readonly string InitMethodParameterName = "InitMethod";
68    private static readonly string GrowMethodParameterName = "GrowMethod";
69    private static readonly string RootsParameterName = "Roots";
70    private static readonly string NodesParameterName = "Nodes";
71    private static readonly string NonTrigParameterName = "NonTrig";
72    private static readonly string LeafsParameterName = "Leafs";
73
74    #endregion
75
76    #region parameters                                           
77    private IFixedValueParameter<IntValue> MaxIterationsParameter {
78      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
79    }
80    public int MaxIterations {
81      get { return MaxIterationsParameter.Value.Value; }
82      set { MaxIterationsParameter.Value.Value = value; }
83    }
84
85    private IFixedValueParameter<IntValue> MaxGenParameter {
86      get { return (IFixedValueParameter<IntValue>)Parameters[MaxGenParameterName]; }
87    }
88    public int MaxGen {
89      get { return MaxGenParameter.Value.Value; }
90      set { MaxGenParameter.Value.Value = value; }
91    }
92
93    private IFixedValueParameter<IntValue> EvalrCountParameter {
94      get { return (IFixedValueParameter<IntValue>)Parameters[EvalrCountParameterName]; }
95    }
96    public int EvalrCount {
97      get { return EvalrCountParameter.Value.Value; }
98      set { EvalrCountParameter.Value.Value = value; }
99    }
100
101    private IFixedValueParameter<IntValue> MaxSizeParameter {
102      get { return (IFixedValueParameter<IntValue>)Parameters[MaxSizeParameterName]; }
103    }
104    public int MaxSize {
105      get { return MaxSizeParameter.Value.Value; }
106      set { MaxSizeParameter.Value.Value = value; }
107    }
108
109    private IFixedValueParameter<IntValue> MinSizeParameter {
110      get { return (IFixedValueParameter<IntValue>)Parameters[MinSizeParameterName]; }
111    }
112    public int MinSize {
113      get { return MinSizeParameter.Value.Value; }
114      set { MinSizeParameter.Value.Value = value; }
115    }
116
117    private IFixedValueParameter<IntValue> MaxDepthParameter {
118      get { return (IFixedValueParameter<IntValue>)Parameters[MaxDepthParameterName]; }
119    }
120    public int MaxDepth {
121      get { return MaxDepthParameter.Value.Value; }
122      set { MaxDepthParameter.Value.Value = value; }
123    }
124
125    private IFixedValueParameter<IntValue> MinDepthParameter {
126      get { return (IFixedValueParameter<IntValue>)Parameters[MinDepthParameterName]; }
127    }
128    public int MinDepth {
129      get { return MinDepthParameter.Value.Value; }
130      set { MinDepthParameter.Value.Value = value; }
131    }
132
133    private IFixedValueParameter<IntValue> PgeRptEpochParameter {
134      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptEpochParameterName]; }
135    }
136    public int PgeRptEpoch {
137      get { return PgeRptEpochParameter.Value.Value; }
138      set { PgeRptEpochParameter.Value.Value = value; }
139    }
140
141    private IFixedValueParameter<IntValue> PgeRptCountParameter {
142      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptCountParameterName]; }
143    }
144    public int PgeRptCount {
145      get { return PgeRptCountParameter.Value.Value; }
146      set { PgeRptCountParameter.Value.Value = value; }
147    }
148
149    private IFixedValueParameter<IntValue> PgeArchiveCapParameter {
150      get { return (IFixedValueParameter<IntValue>)Parameters[PgeArchiveCapParameterName]; }
151    }
152    public int PgeArchiveCap {
153      get { return PgeArchiveCapParameter.Value.Value; }
154      set { PgeArchiveCapParameter.Value.Value = value; }
155    }
156
157    private IFixedValueParameter<IntValue> PeelCntParameter {
158      get { return (IFixedValueParameter<IntValue>)Parameters[PeelCntParameterName]; }
159    }
160    public int PeelCnt {
161      get { return PeelCntParameter.Value.Value; }
162      set { PeelCntParameter.Value.Value = value; }
163    }
164
165    private IFixedValueParameter<DoubleValue> ZeroEpsilonParameter {
166      get { return (IFixedValueParameter<DoubleValue>)Parameters[ZeroEpsilonParameterName]; }
167    }
168    public double ZeroEpsilon {
169      get { return ZeroEpsilonParameter.Value.Value; }
170      set { ZeroEpsilonParameter.Value.Value = value; }
171    }
172
173    private IFixedValueParameter<DoubleValue> HitRatioParameter {
174      get { return (IFixedValueParameter<DoubleValue>)Parameters[HitRatioParameterName]; }
175    }
176    public double HitRatio {
177      get { return HitRatioParameter.Value.Value; }
178      set { HitRatioParameter.Value.Value = value; }
179    }
180
181    private IFixedValueParameter<StringValue> InitMethodParameter {
182      get { return (IFixedValueParameter<StringValue>)Parameters[InitMethodParameterName]; }
183    }
184    public string InitMethod {
185      get { return InitMethodParameter.Value.Value; }
186      set { InitMethodParameter.Value.Value = value; }
187    }
188
189    private IFixedValueParameter<StringValue> GrowMethodParameter {
190      get { return (IFixedValueParameter<StringValue>)Parameters[GrowMethodParameterName]; }
191    }
192    public string GrowMethod {
193      get { return GrowMethodParameter.Value.Value; }
194      set { GrowMethodParameter.Value.Value = value; }
195    }
196
197    private IFixedValueParameter<StringValue> RootsParameter {
198      get { return (IFixedValueParameter<StringValue>)Parameters[RootsParameterName]; }
199    }
200    public string Roots {
201      get { return RootsParameter.Value.Value; }
202      set { RootsParameter.Value.Value = value; }
203    }
204
205    private IFixedValueParameter<StringValue> NodesParameter {
206      get { return (IFixedValueParameter<StringValue>)Parameters[NodesParameterName]; }
207    }
208    public string Nodes {
209      get { return NodesParameter.Value.Value; }
210      set { NodesParameter.Value.Value = value; }
211    }
212
213    private IFixedValueParameter<StringValue> NonTrigParameter {
214      get { return (IFixedValueParameter<StringValue>)Parameters[NonTrigParameterName]; }
215    }
216    public string NonTrig {
217      get { return NonTrigParameter.Value.Value; }
218      set { NonTrigParameter.Value.Value = value; }
219    }
220
221    private IFixedValueParameter<StringValue> LeafsParameter {
222      get { return (IFixedValueParameter<StringValue>)Parameters[LeafsParameterName]; }
223    }
224    public string Leafs {
225      get { return LeafsParameter.Value.Value; }
226      set { LeafsParameter.Value.Value = value; }
227    }
228    #endregion
229
230    public PGE() {
231
232      base.Problem = new RegressionProblem();
233
234      // algorithm parameters are shown in the GUI
235      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, new IntValue(50)));
236      Parameters.Add(new FixedValueParameter<IntValue>(MinDepthParameterName, new IntValue(1)));
237      Parameters.Add(new FixedValueParameter<IntValue>(MaxDepthParameterName, new IntValue(6)));
238      Parameters.Add(new FixedValueParameter<IntValue>(MinSizeParameterName, new IntValue(4)));
239      Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, new IntValue(50)));
240      Parameters.Add(new FixedValueParameter<IntValue>(EvalrCountParameterName, new IntValue(2)));
241      Parameters.Add(new FixedValueParameter<IntValue>(PeelCntParameterName, new IntValue(3)));
242      Parameters.Add(new FixedValueParameter<IntValue>(PgeArchiveCapParameterName, new IntValue(256)));
243      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptCountParameterName, new IntValue(20)));
244      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptEpochParameterName, new IntValue(1)));
245      Parameters.Add(new FixedValueParameter<IntValue>(MaxGenParameterName, new IntValue(200)));
246
247      Parameters.Add(new FixedValueParameter<StringValue>(InitMethodParameterName, new StringValue("method1")));  // TODO Dropdown
248      Parameters.Add(new FixedValueParameter<StringValue>(GrowMethodParameterName, new StringValue("method1")));
249
250      Parameters.Add(new FixedValueParameter<StringValue>(RootsParameterName, new StringValue("Add")));    // TODO: checkeditemlist
251      Parameters.Add(new FixedValueParameter<StringValue>(NodesParameterName, new StringValue("Add Mul")));  // TODO: checkeditemlist
252      Parameters.Add(new FixedValueParameter<StringValue>(NonTrigParameterName, new StringValue("Add Mul"))); // TODO: checkeditemlist
253      Parameters.Add(new FixedValueParameter<StringValue>(LeafsParameterName, new StringValue("Var ConstantF")));
254
255      Parameters.Add(new FixedValueParameter<DoubleValue>(ZeroEpsilonParameterName, new DoubleValue(0.00001)));
256      Parameters.Add(new FixedValueParameter<DoubleValue>(HitRatioParameterName, new DoubleValue(0.01)));
257    }
258
259
260    [StorableConstructor]
261    public PGE(bool deserializing) : base(deserializing) { }
262
263
264    public PGE(PGE original, Cloner cloner) : base(original, cloner) {
265      // nothing to clone
266    }
267
268    public override IDeepCloneable Clone(Cloner cloner) {
269      return new PGE(this, cloner);
270    }
271
272    protected override void Run(CancellationToken cancellationToken) {
273      Log log = new Log();
274      Results.Add(new Result("Log", log));
275      var iterationsResult = new IntValue(0);
276      Results.Add(new Result("Iteration", iterationsResult));
277      var bestTestScoreResult = new IntValue(0); // TODO: why is test score an int?
278      Results.Add(new Result("Best test score", bestTestScoreResult));
279      var testScoresTable = new DataTable("Test scores");
280      var bestTestScoreRow = new DataRow("Best test score");
281      var curTestScoreRow = new DataRow("Current test score");
282      testScoresTable.Rows.Add(bestTestScoreRow);
283      testScoresTable.Rows.Add(curTestScoreRow);
284      Results.Add(new Result("Test scores", testScoresTable));
285      var lengthsTable = new DataTable("Lengths");
286      var len1Row = new DataRow("Length 1");
287      var len2Row = new DataRow("Length 2");
288      lengthsTable.Rows.Add(len1Row);
289      lengthsTable.Rows.Add(len2Row);
290      Results.Add(new Result("Lengths", lengthsTable));
291
292
293      var bestSolutionResult = new Result("Best solution", typeof(IRegressionSolution));
294      Results.Add(bestSolutionResult);
295
296      // TODO: the following is potentially problematic for other go processes run on the same machine at the same time
297      // shouldn't be problematic bc is inherited only, normally only child processes are affected
298      Environment.SetEnvironmentVariable("GOGC", "off");
299      Environment.SetEnvironmentVariable("GODEBUG", "cgocheck=0");
300      Environment.SetEnvironmentVariable("CGO_ENABLED", "1");
301      Environment.SetEnvironmentVariable("PGEDEBUG", "0");
302
303
304      //Constants
305      int sortType = 0; // TODO what's sort type? //
306      //1 = PESORT_PARETO_TRN_ERR
307      //0 = PESORT_PARETO_TST_ERR
308
309      string problemTypeString = "benchmark";
310      int numProc = 12;
311      string problemName = Problem.ProblemData.Name;
312
313
314      var problemData = Problem.ProblemData;
315      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
316      // no idea why the following are IntPtr, this should not be necessary for marshalling, it should be ok to just send the double[,]
317     
318      double[] trainData = GetData(problemData.Dataset, variables, problemData.TrainingIndices);
319      double[] testData = GetData(problemData.Dataset, variables, problemData.TestIndices);
320
321      int nTrainData = Problem.ProblemData.TrainingPartition.Size;
322      int nTestData = Problem.ProblemData.TestPartition.Size;
323
324      if (problemData.AllowedInputVariables.Any(iv => iv.Contains(" ")))
325        throw new NotSupportedException("PGE does not support variable names which contain spaces");
326
327      var inputVariableNames = string.Join(" ", problemData.AllowedInputVariables);
328
329      AddTestData(inputVariableNames, problemData.TargetVariable, testData, nTestData);
330      AddTrainData(inputVariableNames, problemData.TargetVariable, trainData, nTrainData);
331
332      int numberOfUseableVariables = problemData.AllowedInputVariables.Count();
333
334      InitSearch(MaxGen, PgeRptEpoch, PgeRptCount, PgeArchiveCap, PeelCnt, EvalrCount, ZeroEpsilon, InitMethod, GrowMethod, sortType);
335
336      // cUsableVars: list of indices into independent variables
337      InitTreeParams(Roots, Nodes, NonTrig, Leafs, numberOfUseableVariables, MaxSize, MinSize, MaxDepth, MinDepth);
338
339      InitProblem(Name, MaxIterations, HitRatio,
340        searchVar: 0,  // SearchVar: index of dependent variables (this is always zero because we only have one target variable)
341        ProblemTypeString: problemTypeString, numProcs: numProc);
342
343      var bestTestScore = int.MaxValue;
344      for (int iter = 1; iter <= MaxIterations; iter++) {
345        iterationsResult.Value = iter;
346
347        int nResults = StepW();
348
349        for (int iResult = 0; iResult < nResults; iResult++) {
350          int nCoeff = 0;
351          int testScore = 0;
352
353          IntPtr eqn = GetStepResult(out testScore, out nCoeff);
354          string eqnStr = Marshal.PtrToStringAnsi(eqn);
355
356          double[] coeff = new double[nCoeff];
357          for (int iCoeff = 0; iCoeff < nCoeff; iCoeff++) {
358            coeff[iCoeff] = GetCoeffResult();
359          }
360          log.LogMessage("Push/Pop (" + iResult + ", " + testScore + ") " + eqnStr + " coeff: " + string.Join(" ", coeff));
361
362          if (!string.IsNullOrEmpty(eqnStr) && (testScore < bestTestScore)) {
363            // update best quality
364            bestTestScore = testScore;
365            bestTestScoreResult.Value = testScore;
366            var sol = CreateSolution(problemData, eqnStr, coeff, problemData.AllowedInputVariables.ToArray());
367            bestSolutionResult.Value = sol;
368          }
369          bestTestScoreRow.Values.Add(bestTestScoreResult.Value); // always add the current best test score to data row
370          curTestScoreRow.Values.Add(testScore);
371        }
372
373        if (cancellationToken.IsCancellationRequested) break;
374      }
375
376      // Results.Add(new Result("Execution time", new TimeSpanValue(this.ExecutionTime)));
377    }
378
379    private static readonly Regex varRegex = new Regex(@"X_(\d)+");
380    private static readonly Regex coeffRegex = new Regex(@"C_(\d)+");
381
382    private IRegressionSolution CreateSolution(IRegressionProblemData problemData, string eqnStr, double[] coeff, string[] usableVariables) {
383      // coefficients are named e.g. "C_0" in the PGE expressions
384      // -> replace all patterns "C_\d" by the corresponding coefficients
385      var match = coeffRegex.Match(eqnStr);
386      while (match.Success) {
387        var coeffIdx = int.Parse(match.Groups[1].ToString());
388        eqnStr = eqnStr.Substring(0, match.Index) +
389          "(" + coeff[coeffIdx].ToString(System.Globalization.CultureInfo.InvariantCulture) + ")" +
390          eqnStr.Substring(match.Index + match.Length);
391        match = coeffRegex.Match(eqnStr);
392      }
393
394      // variables are named e.g. "X_0" in the PGE expressions
395      // -> replace all patterns "X_\d" by the corresponding variable name
396      match = varRegex.Match(eqnStr);
397      while (match.Success) {
398        var varIdx = int.Parse(match.Groups[1].ToString());
399        eqnStr = eqnStr.Substring(0, match.Index) +
400          "'" + usableVariables[varIdx] + "'" +
401          eqnStr.Substring(match.Index + match.Length);
402        match = varRegex.Match(eqnStr);
403      }
404
405      var parser = new InfixExpressionParser();
406      var tree = parser.Parse(eqnStr);
407      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
408      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
409    }
410
411    public override bool SupportsPause {
412      get { return false; }
413    }
414
415    private static double[] GetData(IDataset ds, IEnumerable<string> variableNames, IEnumerable<int> rows) {
416      var dim = variableNames.Count();
417      double[] val = new double[rows.Count() * dim];
418      int r = 0;
419      foreach (var row in rows) {
420        int c = 0;
421        foreach (var var in variableNames) {
422          val[r * dim + c] = ds.GetDoubleValue(var, r);
423          c++;
424        }
425        r++;
426      }
427      return val;
428    }
429  }
430}
Note: See TracBrowser for help on using the repository browser.