Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2929_PrioritizedGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.PGE/3.3/PGE.cs @ 16200

Last change on this file since 16200 was 16200, checked in by hmaislin, 6 years ago

#2929: Updated DLL to get coeff results and release memory

File size: 19.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Runtime.InteropServices;
5using System.Text;
6using System.Threading;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Optimization;
11using HeuristicLab.Parameters;
12using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
13using HeuristicLab.Problems.DataAnalysis;     
14
namespace PGE {
  /// <summary>
  /// HeuristicLab algorithm that drives the native (Go) implementation of Prioritized Grammar Enumeration
  /// through P/Invoke calls into go-pge.dll and reports the equations it finds as results.
  /// </summary>
  [Item(Name = "Prioritized Grammar Enumeration (PGE)", Description = "Prioritized grammar enumeration algorithm. Worm, T. and Chiu K., 'Prioritized Grammar Enumeration: Symbolic Regression by Dynamic Programming'. GECCO 2013")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableClass]
  public unsafe class PGE : BasicAlgorithm {

    #region native interface (go-pge.dll)
    // All string arguments are unmanaged ANSI strings allocated by the caller (see Run).

    /// <summary>Hands the test matrix (nEntries rows, laid out as produced by GetData) to the native side.</summary>
    [DllImport("go-pge.dll", EntryPoint = "addTestData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void AddTestData(IntPtr indepNames, IntPtr depndNames, IntPtr matrix, int nEntries);

    /// <summary>Hands the training matrix (nEntries rows, laid out as produced by GetData) to the native side.</summary>
    [DllImport("go-pge.dll", EntryPoint = "addTrainData", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void AddTrainData(IntPtr indepNames, IntPtr depndNames, IntPtr matrix, int nEntries);

    /// <summary>Configures the native search with the algorithm parameters.</summary>
    [DllImport("go-pge.dll", EntryPoint = "initSearch", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitSearch(int maxGen, int pgeRptEpoch, int pgeRptCount, int pgeArchiveCap, int peelCnt, int evalrCount, double zeroEpsilon, IntPtr initMethod, IntPtr growMethod, int sortType);

    /// <summary>Configures the expression grammar (space-separated node lists) and size/depth limits.</summary>
    [DllImport("go-pge.dll", EntryPoint = "initTreeParams", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitTreeParams(IntPtr roots, IntPtr nodes, IntPtr nonTrig, IntPtr leafs, IntPtr usableVars, int numUsableVars, int maxSize, int minSize, int maxDepth, int minDepth);

    /// <summary>Configures the problem; must be called after the data and search parameters are set (see Run).</summary>
    [DllImport("go-pge.dll", EntryPoint = "initProblem", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void InitProblem(IntPtr name, int maxIter, double hitRatio, int searchVar, IntPtr ProblemTypeString, int numProcs);

    /// <summary>Runs one search step on the native side.</summary>
    [DllImport("go-pge.dll", EntryPoint = "stepW", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern void StepW();

    /// <summary>
    /// Fetches one result of the last step. Returns a pointer to an ANSI equation string.
    /// noBestPush == 1 signals that no new best equation was produced.
    /// nCoeff is the number of coefficients to fetch afterwards with GetCoeffResult.
    /// NOTE(review): the returned string is owned by the native side and is never freed here — TODO confirm.
    /// </summary>
    [DllImport("go-pge.dll", EntryPoint = "getStepResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern IntPtr GetStepResult(out int noBestPush, out int bestNewMinErr, out int bestlen1, out int bestlen2, out int testscore, out int nCoeff);

    /// <summary>Returns the next coefficient of the last step result; Run calls it nCoeff times in a row.</summary>
    [DllImport("go-pge.dll", EntryPoint = "getCoeffResult", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.StdCall)]
    public static extern double GetCoeffResult();
    #endregion

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    #region parameter names
    private const string MaxIterationsParameterName = "MaxIterations";
    private const string MaxGenParameterName = "MaxGen";
    private const string EvalrCountParameterName = "EvalrCount";
    private const string MaxSizeParameterName = "MaxSize";
    private const string MinSizeParameterName = "MinSize";
    private const string MaxDepthParameterName = "MaxDepth";
    private const string SearchVarParameterName = "SearchVar";
    private const string MinDepthParameterName = "MinDepth";
    private const string PgeRptEpochParameterName = "PgeRptEpoch";
    private const string PgeRptCountParameterName = "PgeRptCount";
    private const string PgeArchiveCapParameterName = "PgeArchiveCap";
    private const string PeelCntParameterName = "PeelCnt";
    private const string ZeroEpsilonParameterName = "ZeroEpsilon";
    private const string HitRatioParameterName = "HitRatio";
    private const string InitMethodParameterName = "InitMethod";
    private const string GrowMethodParameterName = "GrowMethod";
    private const string RootsParameterName = "Roots";
    private const string NodesParameterName = "Nodes";
    private const string NonTrigParameterName = "NonTrig";
    private const string LeafsParameterName = "Leafs";
    #endregion

    #region parameters
    // Each parameter is a fixed-value parameter registered in the constructor; the typed property
    // reads/writes the wrapped value directly.

    private IFixedValueParameter<IntValue> MaxIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    }
    public int MaxIterations {
      get { return MaxIterationsParameter.Value.Value; }
      set { MaxIterationsParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxGenParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxGenParameterName]; }
    }
    public int MaxGen {
      get { return MaxGenParameter.Value.Value; }
      set { MaxGenParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> EvalrCountParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvalrCountParameterName]; }
    }
    public int EvalrCount {
      get { return EvalrCountParameter.Value.Value; }
      set { EvalrCountParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxSizeParameterName]; }
    }
    public int MaxSize {
      get { return MaxSizeParameter.Value.Value; }
      set { MaxSizeParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MinSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinSizeParameterName]; }
    }
    public int MinSize {
      get { return MinSizeParameter.Value.Value; }
      set { MinSizeParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MaxDepthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxDepthParameterName]; }
    }
    public int MaxDepth {
      get { return MaxDepthParameter.Value.Value; }
      set { MaxDepthParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> SearchVarParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SearchVarParameterName]; }
    }
    public int SearchVar {
      get { return SearchVarParameter.Value.Value; }
      set { SearchVarParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> MinDepthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinDepthParameterName]; }
    }
    public int MinDepth {
      get { return MinDepthParameter.Value.Value; }
      set { MinDepthParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeRptEpochParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptEpochParameterName]; }
    }
    public int PgeRptEpoch {
      get { return PgeRptEpochParameter.Value.Value; }
      set { PgeRptEpochParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeRptCountParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeRptCountParameterName]; }
    }
    public int PgeRptCount {
      get { return PgeRptCountParameter.Value.Value; }
      set { PgeRptCountParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PgeArchiveCapParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PgeArchiveCapParameterName]; }
    }
    public int PgeArchiveCap {
      get { return PgeArchiveCapParameter.Value.Value; }
      set { PgeArchiveCapParameter.Value.Value = value; }
    }

    private IFixedValueParameter<IntValue> PeelCntParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[PeelCntParameterName]; }
    }
    public int PeelCnt {
      get { return PeelCntParameter.Value.Value; }
      set { PeelCntParameter.Value.Value = value; }
    }

    private IFixedValueParameter<DoubleValue> ZeroEpsilonParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ZeroEpsilonParameterName]; }
    }
    public double ZeroEpsilon {
      get { return ZeroEpsilonParameter.Value.Value; }
      set { ZeroEpsilonParameter.Value.Value = value; }
    }

    private IFixedValueParameter<DoubleValue> HitRatioParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[HitRatioParameterName]; }
    }
    public double HitRatio {
      get { return HitRatioParameter.Value.Value; }
      set { HitRatioParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> InitMethodParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[InitMethodParameterName]; }
    }
    public string InitMethod {
      get { return InitMethodParameter.Value.Value; }
      set { InitMethodParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> GrowMethodParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[GrowMethodParameterName]; }
    }
    public string GrowMethod {
      get { return GrowMethodParameter.Value.Value; }
      set { GrowMethodParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> RootsParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[RootsParameterName]; }
    }
    public string Roots {
      get { return RootsParameter.Value.Value; }
      set { RootsParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> NodesParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[NodesParameterName]; }
    }
    public string Nodes {
      get { return NodesParameter.Value.Value; }
      set { NodesParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> NonTrigParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[NonTrigParameterName]; }
    }
    public string NonTrig {
      get { return NonTrigParameter.Value.Value; }
      set { NonTrigParameter.Value.Value = value; }
    }

    private IFixedValueParameter<StringValue> LeafsParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[LeafsParameterName]; }
    }
    public string Leafs {
      get { return LeafsParameter.Value.Value; }
      set { LeafsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>
    /// Creates the algorithm with an empty regression problem and registers all parameters with their defaults.
    /// </summary>
    public PGE() {
      base.Problem = new RegressionProblem();

      // algorithm parameters are shown in the GUI (registration order = display order)
      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(SearchVarParameterName, new IntValue(0)));
      Parameters.Add(new FixedValueParameter<IntValue>(MinDepthParameterName, new IntValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxDepthParameterName, new IntValue(6)));
      Parameters.Add(new FixedValueParameter<IntValue>(MinSizeParameterName, new IntValue(4)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvalrCountParameterName, new IntValue(2)));
      Parameters.Add(new FixedValueParameter<IntValue>(PeelCntParameterName, new IntValue(3)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeArchiveCapParameterName, new IntValue(256)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptCountParameterName, new IntValue(20)));
      Parameters.Add(new FixedValueParameter<IntValue>(PgeRptEpochParameterName, new IntValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxGenParameterName, new IntValue(200)));

      Parameters.Add(new FixedValueParameter<StringValue>(InitMethodParameterName, new StringValue("method1")));  // TODO Dropdown
      Parameters.Add(new FixedValueParameter<StringValue>(GrowMethodParameterName, new StringValue("method1")));

      // grammar node lists are space-separated strings
      Parameters.Add(new FixedValueParameter<StringValue>(RootsParameterName, new StringValue("Add")));    // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(NodesParameterName, new StringValue("Add Mul")));  // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(NonTrigParameterName, new StringValue("Add Mul"))); // TODO: checkeditemlist
      Parameters.Add(new FixedValueParameter<StringValue>(LeafsParameterName, new StringValue("Var ConstantF")));

      Parameters.Add(new FixedValueParameter<DoubleValue>(ZeroEpsilonParameterName, new DoubleValue(0.00001)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(HitRatioParameterName, new DoubleValue(0.01)));
    }

    [StorableConstructor]
    public PGE(bool deserializing) : base(deserializing) { }

    public PGE(PGE original, Cloner cloner) : base(original, cloner) {
      // nothing to clone: all state lives in the parameters, which the base class clones
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new PGE(this, cloner);
    }

    /// <summary>
    /// Marshals the problem data and parameters to go-pge.dll.
    /// Runs up to MaxIterations steps and records every new best equation (plus its coefficients) in Results.
    /// Stops early when cancellation is requested.
    /// All unmanaged buffers allocated here are released before the method returns, including when it throws.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      // TODO: the following is potentially problematic for other go processes run on the same machine at the same time
      Environment.SetEnvironmentVariable("GOGC", "off");
      Environment.SetEnvironmentVariable("GODEBUG", "cgocheck=0");
      Environment.SetEnvironmentVariable("CGO_ENABLED", "1");

      // fixed settings that are not (yet) exposed as parameters
      const int sortType = 0;
      const string problemTypeString = "benchmark";
      const int numProcs = 12;

      var problemData = Problem.ProblemData;
      var inputVariables = problemData.AllowedInputVariables.ToArray();
      // column layout of the marshalled matrices: all input variables followed by the target variable
      var variables = inputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();

      // The row count passed to the native side must equal the number of rows actually copied into the buffer.
      // TrainingIndices/TestIndices are not guaranteed to contain Partition.Size rows.
      // Example: training rows that overlap the test partition can be excluded.
      // The counts are therefore taken from the materialized index sets.
      var trainingRows = problemData.TrainingIndices.ToArray();
      var testRows = problemData.TestIndices.ToArray();

      // Every unmanaged buffer handed to go-pge.dll is recorded here and released in the finally block.
      // NOTE(review): assumes go-pge.dll does not access these buffers once Run has returned — TODO confirm.
      var unmanaged = new List<IntPtr>();
      try {
        IntPtr trainData = Track(unmanaged, GetData(problemData.Dataset, variables, trainingRows));
        IntPtr testData = Track(unmanaged, GetData(problemData.Dataset, variables, testRows));

        // Input names are joined with spaces and split again on the Go side.
        // TODO: does this work when input variables contain spaces?
        IntPtr cIndepNames = TrackAnsi(unmanaged, string.Join(" ", inputVariables));
        // TODO: is it ok to use any variable here?
        IntPtr cDependentNames = TrackAnsi(unmanaged, problemData.TargetVariable);

        IntPtr cInitMethod = TrackAnsi(unmanaged, InitMethod);
        IntPtr cGrowMethod = TrackAnsi(unmanaged, GrowMethod);

        IntPtr cRoots = TrackAnsi(unmanaged, Roots);
        IntPtr cNodes = TrackAnsi(unmanaged, Nodes);
        IntPtr cNonTrig = TrackAnsi(unmanaged, NonTrig);
        IntPtr cLeafs = TrackAnsi(unmanaged, Leafs);

        IntPtr cName = TrackAnsi(unmanaged, problemData.Name);
        IntPtr cProblemTypeString = TrackAnsi(unmanaged, problemTypeString);

        AddTestData(cIndepNames, cDependentNames, testData, testRows.Length);
        AddTrainData(cIndepNames, cDependentNames, trainData, trainingRows.Length);

        IntPtr cUseableVars = Track(unmanaged, GetUsableVars(inputVariables.Length));

        InitSearch(MaxGen, PgeRptEpoch, PgeRptCount, PgeArchiveCap, PeelCnt, EvalrCount, ZeroEpsilon, cInitMethod, cGrowMethod, sortType);
        InitTreeParams(cRoots, cNodes, cNonTrig, cLeafs, cUseableVars, inputVariables.Length, MaxSize, MinSize, MaxDepth, MinDepth);
        InitProblem(cName, MaxIterations, HitRatio, SearchVar, cProblemTypeString, numProcs);

        for (int iter = 1; iter <= MaxIterations; iter++) {
          StepW();  // TODO: alg crashes here

          for (int iPeel = 0; iPeel < PeelCnt; iPeel++) {
            CollectStepResult(iter, iPeel);
          }

          if (cancellationToken.IsCancellationRequested) break;
        }
      } finally {
        foreach (IntPtr ptr in unmanaged) {
          Marshal.FreeHGlobal(ptr);
        }
      }
    }

    // Fetches one step result from go-pge.dll.
    // If a new best equation was pushed, fetches its nCoeff coefficients and adds iteration, coefficient
    // and equation results to Results.
    private void CollectStepResult(int iter, int iPeel) {
      int noBestPush, bestNewMinError, bestLen1, bestLen2, testScore, nCoeff;
      IntPtr eqn = GetStepResult(out noBestPush, out bestNewMinError, out bestLen1, out bestLen2, out testScore, out nCoeff);
      string eqnStr = Marshal.PtrToStringAnsi(eqn);

      if (noBestPush == 1) {
        Console.WriteLine("No best push");
        return;
      }

      Console.WriteLine("Push/Pop (" + bestLen1 + "," + bestLen2 + ") " + eqnStr);

      var sb = new StringBuilder("");
      for (int iCoeff = 0; iCoeff < nCoeff; iCoeff++) {
        double coeffVal = GetCoeffResult();
        Console.WriteLine("Coeff: " + coeffVal);
        sb.Append(coeffVal + "; ");
      }

      string suffix = iter + " " + iPeel;
      // Each result gets its own IntValue.
      // A single shared, mutated instance would make every "Iteration" result show the last iteration.
      Results.Add(new Result("Iteration " + suffix, new IntValue(iter)));
      Results.Add(new Result("Coeff " + suffix, new StringValue(sb.ToString())));
      Results.Add(new Result("Best quality " + suffix, new StringValue(eqnStr)));
    }

    // Records ptr in owned so it can be released later; returns ptr unchanged.
    private static IntPtr Track(ICollection<IntPtr> owned, IntPtr ptr) {
      owned.Add(ptr);
      return ptr;
    }

    // Copies value into a newly allocated unmanaged ANSI string and records it in owned.
    private static IntPtr TrackAnsi(ICollection<IntPtr> owned, string value) {
      return Track(owned, Marshal.StringToHGlobalAnsi(value));
    }

    public override bool SupportsPause {
      get { return false; }
    }

    // Returns an unmanaged array of n 64-bit integers 0..n-1 (the indices of all input variables).
    // The caller must release it with Marshal.FreeHGlobal.
    private static IntPtr GetUsableVars(int n) {
      long[] vars = new long[n];

      for (int i = 0; i < n; i++) {
        vars[i] = i;
      }

      IntPtr usableVars = Marshal.AllocHGlobal(sizeof(long) * n);
      Marshal.Copy(vars, 0, usableVars, n);

      return usableVars;
    }

    // Copies the given rows of the dataset into a newly allocated unmanaged, row-major double matrix.
    // Each row holds one value per variable, in variableNames order.
    // The caller must release the buffer with Marshal.FreeHGlobal.
    private static IntPtr GetData(IDataset ds, IEnumerable<string> variableNames, IEnumerable<int> rows) {
      // materialize once: both arguments may be lazy LINQ queries and are needed repeatedly
      var names = variableNames.ToArray();
      var rowIndices = rows.ToArray();
      int dim = names.Length;

      double[] values = new double[rowIndices.Length * dim];
      for (int r = 0; r < rowIndices.Length; r++) {
        for (int c = 0; c < dim; c++) {
          // read dataset row rowIndices[r]; r is only the position inside the selected rows
          values[r * dim + c] = ds.GetDoubleValue(names[c], rowIndices[r]);
        }
      }

      // TODO: seems strange to marshal this explicitly, we can just send the data over to PGE
      IntPtr data = Marshal.AllocHGlobal(sizeof(double) * values.Length);
      Marshal.Copy(values, 0, data, values.Length);

      return data;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.