Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2929_PrioritizedGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.PGE/3.3/PGE.cs @ 16824

Last change on this file since 16824 was 16621, checked in by gkronber, 6 years ago

#2929: changed code to update the best found solution

File size: 20.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Runtime.InteropServices;
5using System.Threading;
6using HeuristicLab.Analysis;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using System.Text.RegularExpressions;
11using HeuristicLab.Optimization;
12using HeuristicLab.Parameters;
13using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
14using HeuristicLab.Problems.DataAnalysis;
15using HeuristicLab.Problems.DataAnalysis.Symbolic;
16using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
17
namespace PGE {
  /// <summary>
  /// Prioritized Grammar Enumeration (PGE) for symbolic regression. The search itself runs in the native
  /// go-pge.dll library: <see cref="Run"/> copies the problem's training and test rows into the library,
  /// configures the search from this algorithm's parameters and then repeatedly calls <c>stepW</c>,
  /// converting every equation the library reports into an <see cref="IRegressionSolution"/>.
  /// </summary>
  [Item(Name = "Prioritized Grammar Enumeration (PGE)", Description = "Prioritized grammar enumeration algorithm. Worm, T. and Chiu K., 'Prioritized Grammar Enumeration: Symbolic Regression by Dynamic Programming'. GECCO 2013")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableClass]
  public class PGE : BasicAlgorithm {

    #region native go-pge interface
    // Entry points exported by go-pge.dll. P/Invoke loads the DLL lazily on the first call into it; Run sets
    // the Go-related environment variables before that first call (presumably so the Go runtime sees them).

    // Hands the test matrix to the library. Run passes a row-major array of nEntries rows, each holding the
    // inputs named (space-separated) in indepNames followed by the target named in depndNames.
    [DllImport("go-pge.dll", EntryPoint = "addTestData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void AddTestData([MarshalAs(UnmanagedType.AnsiBStr)] string indepNames, [MarshalAs(UnmanagedType.AnsiBStr)] string depndNames, double[] matrix, int nEntries);

    // Same layout as AddTestData, for the training matrix.
    [DllImport("go-pge.dll", EntryPoint = "addTrainData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void AddTrainData([MarshalAs(UnmanagedType.AnsiBStr)] string indepNames, [MarshalAs(UnmanagedType.AnsiBStr)] string depndNames, double[] matrix, int nEntries);

    [DllImport("go-pge.dll", EntryPoint = "initSearch", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitSearch(int maxGen, int pgeRptEpoch, int pgeRptCount, int pgeArchiveCap, int peelCnt, int evalrCount, double zeroEpsilon, [MarshalAs(UnmanagedType.AnsiBStr)] string initMethod, [MarshalAs(UnmanagedType.AnsiBStr)] string growMethod, int sortType);

    [DllImport("go-pge.dll", EntryPoint = "initTreeParams", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitTreeParams([MarshalAs(UnmanagedType.AnsiBStr)] string roots, [MarshalAs(UnmanagedType.AnsiBStr)] string nodes, [MarshalAs(UnmanagedType.AnsiBStr)] string nonTrig, [MarshalAs(UnmanagedType.AnsiBStr)] string leafs, int numUsableVars, int maxSize, int minSize, int maxDepth, int minDepth);

    [DllImport("go-pge.dll", EntryPoint = "initProblem", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitProblem([MarshalAs(UnmanagedType.AnsiBStr)] string name, int maxIter, double hitRatio, int searchVar, [MarshalAs(UnmanagedType.AnsiBStr)] string ProblemTypeString, int numProcs);

    // Performs one search step and returns the number of results that can then be fetched with GetStepResult.
    [DllImport("go-pge.dll", EntryPoint = "stepW", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern int StepW();

    // Returns a pointer to the ANSI equation string of the next result; the result's nCoeff coefficients must
    // then be read with nCoeff calls to GetCoeffResult.
    // NOTE(review): the returned string is never released on the managed side — confirm ownership in go-pge.
    [DllImport("go-pge.dll", EntryPoint = "getStepResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern IntPtr GetStepResult(out int testscore, out int nCoeff);

    // Returns the next coefficient of the result last obtained via GetStepResult.
    [DllImport("go-pge.dll", EntryPoint = "getCoeffResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern double GetCoeffResult();
    #endregion

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    #region parameter names
    private static readonly string MaxIterationsParameterName = "MaxIterations";
    private static readonly string MaxGenParameterName = "MaxGen";
    private static readonly string EvalrCountParameterName = "EvalrCount";
    private static readonly string MaxSizeParameterName = "MaxSize";
    private static readonly string MinSizeParameterName = "MinSize";
    private static readonly string MaxDepthParameterName = "MaxDepth";
    private static readonly string MinDepthParameterName = "MinDepth";
    private static readonly string PgeRptEpochParameterName = "PgeRptEpoch";
    private static readonly string PgeRptCountParameterName = "PgeRptCount";
    private static readonly string PgeArchiveCapParameterName = "PgeArchiveCap";
    private static readonly string PeelCntParameterName = "PeelCnt";
    private static readonly string ZeroEpsilonParameterName = "ZeroEpsilon";
    private static readonly string HitRatioParameterName = "HitRatio";
    private static readonly string InitMethodParameterName = "InitMethod";
    private static readonly string GrowMethodParameterName = "GrowMethod";
    private static readonly string RootsParameterName = "Roots";
    private static readonly string NodesParameterName = "Nodes";
    private static readonly string NonTrigParameterName = "NonTrig";
    private static readonly string LeafsParameterName = "Leafs";
    #endregion

    #region parameters
    // Typed accessors for the parameters registered in the constructor. Each value is forwarded unchanged to
    // go-pge in Run: MaxIterations bounds the number of StepW calls and is passed to InitProblem together with
    // HitRatio; MaxGen, PgeRpt*, PgeArchiveCap, PeelCnt, EvalrCount, ZeroEpsilon, InitMethod and GrowMethod go
    // to InitSearch; Roots, Nodes, NonTrig, Leafs and the size/depth limits go to InitTreeParams.
    private IFixedValueParameter<IntValue> MaxIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    }
    public int MaxIterations {
      get { return MaxIterationsParameter.Value.Value; }
      set { MaxIterationsParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxGenParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxGenParameterName]; }
    }
    public int MaxGen {
      get { return MaxGenParameter.Value.Value; }
      set { MaxGenParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> EvalrCountParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvalrCountParameterName]; }
    }
    public int EvalrCount {
      get { return EvalrCountParameter.Value.Value; }
      set { EvalrCountParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxSizeParameterName]; }
    }
    public int MaxSize {
      get { return MaxSizeParameter.Value.Value; }
      set { MaxSizeParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MinSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinSizeParameterName]; }
    }
    public int MinSize {
      get { return MinSizeParameter.Value.Value; }
      set { MinSizeParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxDepthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxDepthParameterName]; }
    }
    public int MaxDepth {
      get { return MaxDepthParameter.Value.Value; }
      set { MaxDepthParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MinDepthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinDepthParameterName]; }
    }
    public int MinDepth {
      get { return MinDepthParameter.Value.Value; }
      set { MinDepthParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeRptEpochParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptEpochParameterName]; }
    }
    public int PgeRptEpoch {
      get { return PgeRptEpochParameter.Value.Value; }
      set { PgeRptEpochParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeRptCountParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptCountParameterName]; }
    }
    public int PgeRptCount {
      get { return PgeRptCountParameter.Value.Value; }
      set { PgeRptCountParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeArchiveCapParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeArchiveCapParameterName]; }
    }
    public int PgeArchiveCap {
      get { return PgeArchiveCapParameter.Value.Value; }
      set { PgeArchiveCapParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PeelCntParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PeelCntParameterName]; }
    }
    public int PeelCnt {
      get { return PeelCntParameter.Value.Value; }
      set { PeelCntParameter.Value.Value = value; }
    }

    private IFixedValueParameter<DoubleValue> ZeroEpsilonParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ZeroEpsilonParameterName]; }
    }
    public double ZeroEpsilon {
      get { return ZeroEpsilonParameter.Value.Value; }
      set { ZeroEpsilonParameter.Value.Value = value; }
    }

    private IFixedValueParameter<DoubleValue> HitRatioParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[HitRatioParameterName]; }
    }
    public double HitRatio {
      get { return HitRatioParameter.Value.Value; }
      set { HitRatioParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> InitMethodParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[InitMethodParameterName]; }
    }
    public string InitMethod {
      get { return InitMethodParameter.Value.Value; }
      set { InitMethodParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> GrowMethodParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[GrowMethodParameterName]; }
    }
    public string GrowMethod {
      get { return GrowMethodParameter.Value.Value; }
      set { GrowMethodParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> RootsParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[RootsParameterName]; }
    }
    public string Roots {
      get { return RootsParameter.Value.Value; }
      set { RootsParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> NodesParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[NodesParameterName]; }
    }
    public string Nodes {
      get { return NodesParameter.Value.Value; }
      set { NodesParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> NonTrigParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[NonTrigParameterName]; }
    }
    public string NonTrig {
      get { return NonTrigParameter.Value.Value; }
      set { NonTrigParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> LeafsParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[LeafsParameterName]; }
    }
    public string Leafs {
      get { return LeafsParameter.Value.Value; }
      set { LeafsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>
    /// Creates the algorithm with an empty <see cref="RegressionProblem"/> and registers all parameters with
    /// their default values (these parameters are shown in the GUI).
    /// </summary>
    public PGE() {
      base.Problem = new RegressionProblem();

      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(MinDepthParameterName, new IntValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxDepthParameterName, new IntValue(6)));
      Parameters.Add(new FixedValueParameter<IntValue>(MinSizeParameterName, new IntValue(4)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvalrCountParameterName, new IntValue(2)));
      Parameters.Add(new FixedValueParameter<IntValue>(PeelCntParameterName, new IntValue(3)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeArchiveCapParameterName, new IntValue(256)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptCountParameterName, new IntValue(20)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptEpochParameterName, new IntValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxGenParameterName, new IntValue(200)));

      Parameters.Add(new FixedValueParameter<StringValue>(InitMethodParameterName, new StringValue("method1")));  // TODO Dropdown
      Parameters.Add(new FixedValueParameter<StringValue>(GrowMethodParameterName, new StringValue("method1")));

      // Grammar symbols are passed to go-pge as space-separated lists.
      Parameters.Add(new FixedValueParameter<StringValue>(RootsParameterName, new StringValue("Add")));    // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(NodesParameterName, new StringValue("Add Mul")));  // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(NonTrigParameterName, new StringValue("Add Mul"))); // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(LeafsParameterName, new StringValue("Var ConstantF")));

      Parameters.Add(new FixedValueParameter<DoubleValue>(ZeroEpsilonParameterName, new DoubleValue(0.00001)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(HitRatioParameterName, new DoubleValue(0.01)));
    }

    [StorableConstructor]
    public PGE(bool deserializing) : base(deserializing) { }

    public PGE(PGE original, Cloner cloner) : base(original, cloner) {
      // nothing to clone: all state lives in Parameters, which the base class clones
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new PGE(this, cloner);
    }

    /// <summary>
    /// Executes the PGE search. It validates the variable names, publishes the result objects and hands the
    /// training/test data and all parameters to go-pge.dll. It then performs up to <see cref="MaxIterations"/>
    /// native steps, keeping the reported solution with the lowest training MSE as "Best solution".
    /// </summary>
    /// <param name="cancellationToken">Checked after every native step; a cancellation request ends the loop.</param>
    /// <exception cref="NotSupportedException">An input or the target variable name contains a space.</exception>
    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;
      // Materialized once so every use below sees the same order: the equations returned by the library
      // refer to inputs as X_0, X_1, ... and CreateSolution resolves X_i to inputVariables[i].
      var inputVariables = problemData.AllowedInputVariables.ToArray();
      var targetVariable = problemData.TargetVariable;

      // Variable names are passed to the library as space-separated strings, so a name containing a space
      // would be split apart. Validate up front, before any results are created or data is copied.
      if (inputVariables.Any(iv => iv.Contains(" ")) || targetVariable.Contains(" "))
        throw new NotSupportedException("PGE does not support variable names which contain spaces");

      Log log = new Log();
      Results.Add(new Result("Log", log));
      var iterationsResult = new IntValue(0);
      Results.Add(new Result("Iteration", iterationsResult));
      var bestMSEResult = new DoubleValue();
      Results.Add(new Result("Best MSE", bestMSEResult));
      var testScoresTable = new DataTable("Test scores");
      var bestTestScoreRow = new DataRow("Best test score");
      var curTestScoreRow = new DataRow("Current test score");
      testScoresTable.Rows.Add(bestTestScoreRow);
      testScoresTable.Rows.Add(curTestScoreRow);
      Results.Add(new Result("Test scores", testScoresTable));
      // NOTE(review): len1Row/len2Row are never populated in this method, so "Lengths" stays empty.
      var lengthsTable = new DataTable("Lengths");
      var len1Row = new DataRow("Length 1");
      var len2Row = new DataRow("Length 2");
      lengthsTable.Rows.Add(len1Row);
      lengthsTable.Rows.Add(len2Row);
      Results.Add(new Result("Lengths", lengthsTable));

      var bestSolutionResult = new Result("Best solution", typeof(IRegressionSolution));
      Results.Add(bestSolutionResult);

      var allSolutions = new ItemList<IRegressionSolution>();
      var allSolutionsResult = new Result("Solutions", allSolutions);
      Results.Add(allSolutionsResult);

      // These must be set before the first call into go-pge.dll (see the native interface region).
      // TODO: the following is potentially problematic for other go processes run on the same machine at the same time
      // shouldn't be problematic bc is inherited only, normally only child processes are affected
      Environment.SetEnvironmentVariable("GOGC", "off");
      Environment.SetEnvironmentVariable("GODEBUG", "cgocheck=0");
      Environment.SetEnvironmentVariable("CGO_ENABLED", "1");
      Environment.SetEnvironmentVariable("PGEDEBUG", "0");

      // Fixed settings for the native library.
      int sortType = 0; // TODO what's sort type? //
      //1 = PESORT_PARETO_TRN_ERR
      //0 = PESORT_PARETO_TST_ERR
      string problemTypeString = "benchmark";
      int numProc = 12;
      string problemName = problemData.Name;

      // Row-major matrices: one row per sample, input variables first and the target variable last.
      var variables = inputVariables.Concat(new[] { targetVariable }).ToArray();
      double[] trainData = GetData(problemData.Dataset, variables, problemData.TrainingIndices);
      double[] testData = GetData(problemData.Dataset, variables, problemData.TestIndices);

      // Row counts are derived from the data actually copied, so they always match the array lengths.
      // The partition sizes are not used: TrainingIndices can yield fewer rows than TrainingPartition.Size
      // (presumably when the training and test partitions overlap — verify in DataAnalysisProblemData).
      int nTrainData = trainData.Length / variables.Length;
      int nTestData = testData.Length / variables.Length;

      var inputVariableNames = string.Join(" ", inputVariables);

      AddTestData(inputVariableNames, targetVariable, testData, nTestData);
      AddTrainData(inputVariableNames, targetVariable, trainData, nTrainData);

      InitSearch(MaxGen, PgeRptEpoch, PgeRptCount, PgeArchiveCap, PeelCnt, EvalrCount, ZeroEpsilon, InitMethod, GrowMethod, sortType);

      // numUsableVars: number of independent variables (all allowed inputs are usable)
      InitTreeParams(Roots, Nodes, NonTrig, Leafs, inputVariables.Length, MaxSize, MinSize, MaxDepth, MinDepth);

      InitProblem(problemName, MaxIterations, HitRatio,
        searchVar: 0,  // SearchVar: index of dependent variables (this is always zero because we only have one target variable)
        ProblemTypeString: problemTypeString, numProcs: numProc);

      var bestMSE = double.MaxValue;
      for (int iter = 1; iter <= MaxIterations; iter++) {
        iterationsResult.Value = iter;

        int nResults = StepW();

        for (int iResult = 0; iResult < nResults; iResult++) {
          int testScore;
          int nCoeff;
          IntPtr eqn = GetStepResult(out testScore, out nCoeff);
          string eqnStr = Marshal.PtrToStringAnsi(eqn);

          // The coefficients of this result are delivered one per call, in index order.
          double[] coeff = new double[nCoeff];
          for (int iCoeff = 0; iCoeff < nCoeff; iCoeff++) {
            coeff[iCoeff] = GetCoeffResult();
          }
          log.LogMessage("Push/Pop (" + iResult + ", " + testScore + ") " + eqnStr + " coeff: " + string.Join(" ", coeff));

          if (!string.IsNullOrEmpty(eqnStr)) {
            var sol = CreateSolution(problemData, eqnStr, coeff, inputVariables);
            allSolutions.Add(sol);
            if (sol.TrainingMeanSquaredError < bestMSE) {
              // update best quality
              bestMSE = sol.TrainingMeanSquaredError;
              bestMSEResult.Value = bestMSE;
              bestSolutionResult.Value = sol;
            }
          }
          // NOTE(review): despite the row name, this records the best *training MSE* found so far,
          // while curTestScoreRow records the integer test score reported by the library.
          bestTestScoreRow.Values.Add(bestMSEResult.Value);
          curTestScoreRow.Values.Add(testScore);
        }

        if (cancellationToken.IsCancellationRequested) break;
      }
    }

    // PGE writes coefficients as C_<index> and input variables as X_<index>. (\d+) captures the whole index;
    // the previous form (\d)+ captured only the last digit, so e.g. X_12 was resolved as X_2.
    private static readonly Regex varRegex = new Regex(@"X_(\d+)");
    private static readonly Regex coeffRegex = new Regex(@"C_(\d+)");

    /// <summary>
    /// Converts an equation string reported by go-pge into a symbolic regression solution.
    /// </summary>
    /// <param name="problemData">Problem data; a clone of it is attached to the returned solution.</param>
    /// <param name="eqnStr">Equation using C_i for coefficients and X_i for input variables.</param>
    /// <param name="coeff">Coefficient values; C_i is replaced by "(coeff[i])" in invariant culture.</param>
    /// <param name="usableVariables">Input variable names; X_i is replaced by the quoted name usableVariables[i].</param>
    /// <returns>A solution wrapping a <see cref="SymbolicRegressionModel"/> built from the parsed expression.</returns>
    private IRegressionSolution CreateSolution(IRegressionProblemData problemData, string eqnStr, double[] coeff, string[] usableVariables) {
      // Each substitution is a single left-to-right Regex.Replace, so inserted text is never scanned again.
      // The previous re-match loop could substitute a variable whose own name looks like "X_1" over and over.
      // Coefficients are replaced first, so inserted variable names are never treated as coefficients.
      string withCoefficients = coeffRegex.Replace(eqnStr, m =>
        "(" + coeff[ParseIndex(m)].ToString(System.Globalization.CultureInfo.InvariantCulture) + ")");
      string infix = varRegex.Replace(withCoefficients, m =>
        "'" + usableVariables[ParseIndex(m)] + "'");

      var parser = new InfixExpressionParser();
      var tree = parser.Parse(infix);
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    }

    // Parses the numeric index captured by group 1 of varRegex / coeffRegex.
    private static int ParseIndex(Match match) {
      return int.Parse(match.Groups[1].Value, System.Globalization.CultureInfo.InvariantCulture);
    }

    public override bool SupportsPause {
      get { return false; }
    }

    /// <summary>
    /// Copies the given rows of the dataset into a flat row-major array:
    /// element [r * variableCount + c] holds variable c of the r-th requested row.
    /// </summary>
    /// <param name="ds">Source dataset.</param>
    /// <param name="variableNames">Columns to copy, in output order.</param>
    /// <param name="rows">Dataset row indices to copy, in output order.</param>
    /// <returns>An array of length rowCount * variableCount.</returns>
    private static double[] GetData(IDataset ds, IEnumerable<string> variableNames, IEnumerable<int> rows) {
      // Materialize both sequences once; the caller may pass lazily evaluated LINQ queries.
      var names = variableNames.ToArray();
      var rowIndices = rows.ToArray();
      int dim = names.Length;
      var val = new double[rowIndices.Length * dim];
      for (int r = 0; r < rowIndices.Length; r++) {
        for (int c = 0; c < dim; c++) {
          // Read the actual dataset row (rowIndices[r]), not the output position r; the previous code read
          // rows 0..n-1 regardless of which rows were requested.
          val[r * dim + c] = ds.GetDoubleValue(names[c], rowIndices[r]);
        }
      }
      return val;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.