#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;
using HeuristicLab.Persistence;

namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
  public static class MctsSymbolicRegressionStatic {
    // TODO: SGD with adagrad instead of lbfgs?
    // TODO: check Taylor expansion capabilities (ln(x), sqrt(x), exp(x)) in combination with GBT
    // TODO: optimize for multiple targets concurrently (y, 1/y, exp(y), and log(y))? Would reduce the number of possible expressions again
    #region static API

    [StorableType("c370f13f-cfd9-4388-be24-09a979d74740")]
    public interface IState {
      bool Done { get; }
      ISymbolicRegressionModel BestModel { get; }
      double BestSolutionTrainingQuality { get; }
      double BestSolutionTestQuality { get; }
      int TotalRollouts { get; }
      int EffectiveRollouts { get; }
      int FuncEvaluations { get; }
      int GradEvaluations { get; } // number of gradient evaluations (multiplied by the number of parameters) so that the value is comparable to the number of function evaluations
      // TODO: other stats on the LM optimizer might be interesting here
    }

    // created through factory method
    private class State : IState {
      private const int MaxParams = 100;

      // state variables used by MCTS
      internal readonly Automaton automaton;
      internal IRandom random { get; private set; }
      internal readonly Tree tree;
      internal readonly Func<byte[], int, double> evalFun;
      internal readonly IPolicy treePolicy;
      // MCTS might get stuck. Track statistics on the number of effective rollouts
      internal int totalRollouts;
      internal int effectiveRollouts;


      // state variables used only internally (for eval function)
      private readonly IRegressionProblemData problemData;
      private readonly double[][] x;
      private readonly double[] y;
      private readonly double[][] testX;
      private readonly double[] testY;
      private readonly double[] scalingFactor;
      private readonly double[] scalingOffset;
      private readonly int constOptIterations;
      private readonly double lowerEstimationLimit, upperEstimationLimit;

      private readonly ExpressionEvaluator evaluator, testEvaluator;

      // values for best solution
      private double bestRSq;
      private byte[] bestCode;
      private int bestNParams;
      private double[] bestConsts;

      // stats
      private int funcEvaluations;
      private int gradEvaluations;

      // buffers
      private readonly double[] ones; // vector of ones (as default params)
      private readonly double[] constsBuf;
      private readonly double[] predBuf, testPredBuf;
      private readonly double[][] gradBuf;

      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, int constOptIterations,
        IPolicy treePolicy = null,
        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
        bool allowProdOfVars = true,
        bool allowExp = true,
        bool allowLog = true,
        bool allowInv = true,
        bool allowMultipleTerms = false) {

        this.problemData = problemData;
        this.constOptIterations = constOptIterations;
        this.evalFun = this.Eval;
        this.lowerEstimationLimit = lowerEstimationLimit;
        this.upperEstimationLimit = upperEstimationLimit;

        random = new MersenneTwister(randSeed);

        // prepare data for evaluation
        double[][] x;
        double[] y;
        double[][] testX;
        double[] testY;
        double[] scalingFactor;
        double[] scalingOffset;
        // get training and test datasets (scale linearly based on training set if required)
        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
        this.x = x;
        this.y = y;
        this.testX = testX;
        this.testY = testY;
        this.scalingFactor = scalingFactor;
        this.scalingOffset = scalingOffset;
        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
        // we need a separate evaluator because the vector length for the test dataset might differ
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
        this.treePolicy = treePolicy ?? new Ucb();
        this.tree = new Tree() { state = automaton.CurrentState, actionStatistics = treePolicy.CreateActionStatistics() };

        // reset best solution
        this.bestRSq = 0;
        // code for default solution (constant model)
        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
        this.bestNParams = 0;
        this.bestConsts = null;

        // init buffers
        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
        constsBuf = new double[MaxParams];
        this.predBuf = new double[y.Length];
        this.testPredBuf = new double[testY.Length];

        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
      }

      #region IState interface
      public bool Done { get { return tree != null && tree.Done; } }

      public double BestSolutionTrainingQuality {
        get {
          evaluator.Exec(bestCode, x, bestConsts, predBuf);
          return RSq(y, predBuf);
        }
      }

      public double BestSolutionTestQuality {
        get {
          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
          return RSq(testY, testPredBuf);
        }
      }

      // takes the code of the best solution and creates an equivalent symbolic regression model
      public ISymbolicRegressionModel BestModel {
        get {
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);

          // model has already been scaled linearly in Eval
          return model;
        }
      }

      public int TotalRollouts { get { return totalRollouts; } }
      public int EffectiveRollouts { get { return effectiveRollouts; } }
      public int FuncEvaluations { get { return funcEvaluations; } }
      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (multiplied by the number of parameters) so that the value is comparable to the number of function evaluations

      #endregion

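      // evaluates the instruction sequence produced by the automaton and updates the incumbent best solution if the achieved quality (R²) is better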
      private double Eval(byte[] code, int nParams) {
        double[] optConsts;
        double q;
        Eval(code, nParams, out q, out optConsts);

        if (q > bestRSq) {
          bestRSq = q;
          bestNParams = nParams;
          this.bestCode = new byte[code.Length];
          this.bestConsts = new double[bestNParams];

          Array.Copy(code, bestCode, code.Length);
          Array.Copy(optConsts, bestConsts, bestNParams);
        }

        return q;
      }

      private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
        // we make a first pass to determine a valid starting configuration for all constants
        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
        // scale and offset are set to optimal starting configuration
        // assumes scale is the first param and offset is the last param
        double alpha;
        double beta;

        // reset constants
        Array.Copy(ones, constsBuf, nParams);
        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
        funcEvaluations++;

        // calc opt scaling (alpha*f(x) + beta)
        OnlineCalculatorError error;
        OnlineLinearScalingParameterCalculator.Calculate(predBuf, y, out alpha, out beta, out error);
        if (error == OnlineCalculatorError.None) {
          constsBuf[0] *= beta;
          constsBuf[nParams - 1] = constsBuf[nParams - 1] * beta + alpha;
        }
        if (nParams <= 2 || constOptIterations <= 0) {
          // if we don't need to optimize parameters then we are done
          // changing scale and offset does not influence r²
          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        } else {
          // optimize constants using the starting point calculated above
          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

          evaluator.Exec(code, x, constsBuf, predBuf);
          funcEvaluations++;

          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        }
      }

      #region helpers
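      // squared Pearson correlation coefficient: the R² that is achievable with optimal linear scaling of the predictions,
      // which is why changing the scale and offset parameters does not affect this quality measure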
      private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
        return error == OnlineCalculatorError.None ? r * r : 0.0;
      }

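      // fits the numeric constants in 'code' to the training data with the Levenberg-Marquardt optimizer from alglib
      // (minlmcreatevj: LM with analytic Jacobian, supplied by FuncAndJacobian below)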
      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
        Array.Copy(consts, optConsts, nParams);

        alglib.minlmstate state;
        alglib.minlmreport rep = null;
        alglib.minlmcreatevj(y.Length, optConsts, out state);
        alglib.minlmsetcond(state, 0.0, epsF, 0.0, nIters);
        //alglib.minlmsetgradientcheck(state, 0.000001);
        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
        alglib.minlmresults(state, out optConsts, out rep);
        funcEvaluations += rep.nfunc;
        gradEvaluations += rep.njac * nParams;

        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);

        // only use optimized constants if successful (the exception above covers the failure case)
        Array.Copy(optConsts, consts, optConsts.Length);
      }

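      // callbacks for the alglib LM optimizer: fill fi with the residuals f(x) - y for the current constants in arg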
      private void Func(double[] arg, double[] fi, object obj) {
        var code = (byte[])obj;
        evaluator.Exec(code, x, arg, predBuf);
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;
        }
      }
      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
        int nParams = arg.Length;
        var code = (byte[])obj;
        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;

          for (int k = 0; k < nParams; k++) {
            jac[r, k] = gradBuf[k][r];
          }
        }
      }
      #endregion
    }

    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
      bool scaleVariables = true, int constOptIterations = 0,
      IPolicy policy = null,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations,
        policy,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }

    // returns the quality of the evaluated solution
    public static double MakeStep(IState state) {
      var mctsState = state as State;
      if (mctsState == null) throw new ArgumentException("state");
      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

      return TreeSearch(mctsState);
    }
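
    // Minimal usage sketch (hypothetical driver; assumes 'problemData' is an IRegressionProblemData instance):
    //   var state = MctsSymbolicRegressionStatic.CreateState(problemData, randSeed: 31415);
    //   for (int i = 0; i < 1000 && !state.Done; i++)
    //     MctsSymbolicRegressionStatic.MakeStep(state);
    //   var model = state.BestModel;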
    #endregion

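    // repeats rollouts from the root until one reaches a final state (an "effective" rollout) or the search space is exhausted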
    private static double TreeSearch(State mctsState) {
      var automaton = mctsState.automaton;
      var tree = mctsState.tree;
      var eval = mctsState.evalFun;
      var rand = mctsState.random;
      var treePolicy = mctsState.treePolicy;
      double q = 0;
      bool success = false;
      do {
        automaton.Reset();
        success = TryTreeSearchRec(rand, tree, automaton, eval, treePolicy, out q);
        mctsState.totalRollouts++;
      } while (!success && !tree.Done);
      mctsState.effectiveRollouts++;
      return q;
    }

    // tree search might fail because of constraints on expressions;
    // in this case we are stuck and simply restart
    // see ConstraintHandler.cs for more info
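    // a single recursive descent covers all MCTS phases: the tree policy selects among already expanded children,
    // unvisited nodes are expanded with the follow states of the automaton, a final automaton state triggers
    // evaluation, and qualities are backpropagated via treePolicy.Update on the way out of the recursion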
    private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
      out double q) {
      Tree selectedChild = null;
      Contract.Assert(tree.state == automaton.CurrentState);
      Contract.Assert(!tree.Done);
      if (tree.children == null) {
        if (automaton.IsFinalState(tree.state)) {
          // final state
          tree.Done = true;

          // EVALUATE
          byte[] code; int nParams;
          automaton.GetCode(out code, out nParams);
          q = eval(code, nParams);

          treePolicy.Update(tree.actionStatistics, q);
          return true; // we reached a final state
        } else {
          // EXPAND
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            q = 0;
            tree.Done = true;
            tree.children = null;
            return false;
          }
          tree.children = new Tree[nFs];
          for (int i = 0; i < tree.children.Length; i++)
            tree.children[i] = new Tree() { children = null, state = possibleFollowStates[i], actionStatistics = treePolicy.CreateActionStatistics() };

          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
        }
      } else {
        // tree.children != null
        // SELECT: choose among the expanded children using the tree policy (UCB by default)
        int selectedIdx = 0;
        if (tree.children.Length > 1) {
          selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
        }
        selectedChild = tree.children[selectedIdx];
      }
      // make selected step and recurse
      automaton.Goto(selectedChild.state);
      var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
      if (success) {
        // only update if successful
        treePolicy.Update(tree.actionStatistics, q);
      }

      tree.Done = tree.children.All(ch => ch.Done);
      if (tree.Done) {
        tree.children = null; // cut off the sub-branch if it has been fully explored
      }
      return success;
    }

    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
      // if one of the new children leads to a final state then go there
      // otherwise select the first child
      int selectedChildIdx = -1;
      // find first final state if there is one
      for (int i = 0; i < tree.children.Length; i++) {
        if (automaton.IsFinalState(tree.children[i].state)) {
          selectedChildIdx = i;
          break;
        }
      }
      // no final state -> select the first child
      if (selectedChildIdx == -1) {
        selectedChildIdx = 0;
      }
      return tree.children[selectedChildIdx];
    }

    // scales data and extracts values from dataset into arrays
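    // each input variable is scaled linearly to [0, 1] based on its training range: scaledX = (x - min) / (max - min)
    // note: this assumes the variable is not constant on the training rows; range == 0 would produce non-finite scaling factors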
    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      var i = 0;
      if (scaleVariables) {
        scalingFactor = new double[xs.Length];
        scalingOffset = new double[xs.Length];
      } else {
        scalingFactor = null;
        scalingOffset = null;
      }
      foreach (var var in problemData.AllowedInputVariables) {
        if (scaleVariables) {
          var minX = problemData.Dataset.GetDoubleValues(var, rows).Min();
          var maxX = problemData.Dataset.GetDoubleValues(var, rows).Max();
          var range = maxX - minX;

          // scaledX = (x - min) / range
          var sf = 1.0 / range;
          var offset = -minX / range;
          scalingFactor[i] = sf;
          scalingOffset[i] = offset;
          i++;
        }
      }

      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
    }

    // extracts values from the dataset into arrays, applying the given linear scaling if provided
    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
     out double[][] xs, out double[] y) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      int i = 0;
      foreach (var var in problemData.AllowedInputVariables) {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingFactor == null ? 0.0 : scalingOffset[i];
        xs[i++] =
          problemData.Dataset.GetDoubleValues(var, rows).Select(xi => xi * sf + offset).ToArray();
      }

      y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
    }
  }
}