Free cookie consent management tool by TermsFeed Policy Generator

source: branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs @ 14777

Last change on this file since 14777 was 14185, checked in by swagner, 9 years ago

#2526: Updated year of copyrights in license headers

File size: 19.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics.Contracts;
25using System.Linq;
26using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
  /// <summary>
  /// Static API for symbolic regression via Monte-Carlo tree search (MCTS) over the expressions
  /// generated by <see cref="Automaton"/>. Create a search state with <see cref="CreateState"/> and
  /// call <see cref="MakeStep"/> repeatedly.
  /// </summary>
  public static class MctsSymbolicRegressionStatic {
    // TODO: SGD with adagrad instead of lbfgs?
    // TODO: check Taylor expansion capabilities (ln(x), sqrt(x), exp(x)) in combination with GBT
    // TODO: optimize for 3 targets concurrently (y, 1/y, exp(y), and log(y))? Would simplify the number of possible expressions again
    #region static API

    /// <summary>
    /// Search state created by <see cref="CreateState"/>. Exposes the best solution found so far and search statistics.
    /// </summary>
    public interface IState {
      /// <summary>True when the search tree has been fully explored.</summary>
      bool Done { get; }
      /// <summary>Symbolic regression model equivalent to the best expression found so far.</summary>
      ISymbolicRegressionModel BestModel { get; }
      /// <summary>Squared Pearson correlation (R²) of the best solution on the training partition.</summary>
      double BestSolutionTrainingQuality { get; }
      /// <summary>Squared Pearson correlation (R²) of the best solution on the test partition.</summary>
      double BestSolutionTestQuality { get; }
      /// <summary>Number of rollouts started, including rollouts that ended in a dead end.</summary>
      int TotalRollouts { get; }
      /// <summary>Number of rollouts that reached a final state and produced an evaluated expression.</summary>
      int EffectiveRollouts { get; }
      /// <summary>Number of function evaluations (including those done by the constant optimizer).</summary>
      int FuncEvaluations { get; }
      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
      // TODO other stats on LM optimizer might be interesting here
    }

    // created through factory method
    private class State : IState {
      private const int MaxParams = 100;

      // state variables used by MCTS
      internal readonly Automaton automaton;
      internal IRandom random { get; private set; }
      internal readonly Tree tree;
      internal readonly Func<byte[], int, double> evalFun;
      internal readonly IPolicy treePolicy;
      // MCTS might get stuck. Track statistics on the number of effective rollouts
      internal int totalRollouts;
      internal int effectiveRollouts;


      // state variables used only internally (for eval function)
      private readonly IRegressionProblemData problemData;
      private readonly double[][] x;
      private readonly double[] y;
      private readonly double[][] testX;
      private readonly double[] testY;
      private readonly double[] scalingFactor;
      private readonly double[] scalingOffset;
      private readonly int constOptIterations;
      private readonly double lowerEstimationLimit, upperEstimationLimit;

      private readonly ExpressionEvaluator evaluator, testEvaluator;

      // values for best solution
      private double bestRSq;
      private byte[] bestCode;
      private int bestNParams;
      private double[] bestConsts;

      // stats
      private int funcEvaluations;
      private int gradEvaluations;

      // buffers
      private readonly double[] ones; // vector of ones (as default params)
      private readonly double[] constsBuf;
      private readonly double[] predBuf, testPredBuf;
      private readonly double[][] gradBuf;

      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, int constOptIterations,
        IPolicy treePolicy = null,
        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
        bool allowProdOfVars = true,
        bool allowExp = true,
        bool allowLog = true,
        bool allowInv = true,
        bool allowMultipleTerms = false) {

        this.problemData = problemData;
        this.constOptIterations = constOptIterations;
        this.evalFun = this.Eval;
        this.lowerEstimationLimit = lowerEstimationLimit;
        this.upperEstimationLimit = upperEstimationLimit;

        random = new MersenneTwister(randSeed);

        // prepare data for evaluation
        double[][] x;
        double[] y;
        double[][] testX;
        double[] testY;
        double[] scalingFactor;
        double[] scalingOffset;
        // get training and test datasets (scale linearly based on training set if required)
        GenerateData(problemData, scaleVariables, problemData.TrainingIndices, out x, out y, out scalingFactor, out scalingOffset);
        GenerateData(problemData, problemData.TestIndices, scalingFactor, scalingOffset, out testX, out testY);
        this.x = x;
        this.y = y;
        this.testX = testX;
        this.testY = testY;
        this.scalingFactor = scalingFactor;
        this.scalingOffset = scalingOffset;
        this.evaluator = new ExpressionEvaluator(y.Length, lowerEstimationLimit, upperEstimationLimit);
        // we need a separate evaluator because the vector length for the test dataset might differ
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
        this.treePolicy = treePolicy ?? new Ucb();
        // use the field, not the parameter: the parameter is null when the default policy is requested
        this.tree = new Tree() { state = automaton.CurrentState, actionStatistics = this.treePolicy.CreateActionStatistics() };

        // reset best solution
        this.bestRSq = 0;
        // code for default solution (constant model)
        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
        this.bestNParams = 0;
        this.bestConsts = null;

        // init buffers
        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
        constsBuf = new double[MaxParams];
        this.predBuf = new double[y.Length];
        this.testPredBuf = new double[testY.Length];

        this.gradBuf = Enumerable.Range(0, MaxParams).Select(_ => new double[y.Length]).ToArray();
      }

      #region IState inferface
      public bool Done { get { return tree != null && tree.Done; } }

      // re-evaluates the best solution on the training set (reuses predBuf)
      public double BestSolutionTrainingQuality {
        get {
          evaluator.Exec(bestCode, x, bestConsts, predBuf);
          return RSq(y, predBuf);
        }
      }

      // re-evaluates the best solution on the test set (reuses testPredBuf)
      public double BestSolutionTestQuality {
        get {
          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
          return RSq(testY, testPredBuf);
        }
      }

      // takes the code of the best solution and creates an equivalent symbolic regression model
      public ISymbolicRegressionModel BestModel {
        get {
          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
          var model = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerEstimationLimit, upperEstimationLimit);

          // model has already been scaled linearly in Eval
          return model;
        }
      }

      public int TotalRollouts { get { return totalRollouts; } }
      public int EffectiveRollouts { get { return effectiveRollouts; } }
      public int FuncEvaluations { get { return funcEvaluations; } }
      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations

      #endregion

      // evaluates the expression (code with nParams constants), records it as the new best solution
      // if its training R² improves on the best so far, and returns the R²
      private double Eval(byte[] code, int nParams) {
        double[] optConsts;
        double q;
        Eval(code, nParams, out q, out optConsts);

        if (q > bestRSq) {
          bestRSq = q;
          bestNParams = nParams;
          this.bestCode = new byte[code.Length];
          this.bestConsts = new double[bestNParams];

          Array.Copy(code, bestCode, code.Length);
          Array.Copy(optConsts, bestConsts, bestNParams);
        }

        return q;
      }

      // computes training R² of the expression and the (optionally optimized) constants.
      // optConsts refers to the shared constsBuf; callers must copy it if they want to keep the values.
      private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
        // we make a first pass to determine a valid starting configuration for all constants
        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
        // scale and offset are set to optimal starting configuration
        // assumes scale is the first param and offset is the last param
        double alpha;
        double beta;

        // reset constants
        Array.Copy(ones, constsBuf, nParams);
        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
        funcEvaluations++;

        // calc opt scaling (alpha + beta*f(x)) and fold it into the scale and offset constants
        OnlineCalculatorError error;
        OnlineLinearScalingParameterCalculator.Calculate(predBuf, y, out alpha, out beta, out error);
        if (error == OnlineCalculatorError.None) {
          constsBuf[0] *= beta;
          constsBuf[nParams - 1] = constsBuf[nParams - 1] * beta + alpha;
        }
        if (nParams <= 2 || constOptIterations <= 0) {
          // if we don't need to optimize parameters then we are done
          // changing scale and offset does not influence r²
          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        } else {
          // optimize constants using the starting point calculated above
          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);

          evaluator.Exec(code, x, constsBuf, predBuf);
          funcEvaluations++;

          rsq = RSq(y, predBuf);
          optConsts = constsBuf;
        }
      }



      #region helpers
      // squared Pearson correlation; 0 if the correlation cannot be calculated (e.g. constant predictions)
      private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
        return error == OnlineCalculatorError.None ? r * r : 0.0;
      }


      // optimizes the first nParams entries of consts in place using Levenberg-Marquardt (alglib minlm).
      // If the optimizer reports a failure the starting values are left untouched.
      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
        Array.Copy(consts, optConsts, nParams);

        alglib.minlmstate state;
        alglib.minlmreport rep = null;
        alglib.minlmcreatevj(y.Length, optConsts, out state);
        alglib.minlmsetcond(state, 0.0, epsF, 0.0, nIters);
        //alglib.minlmsetgradientcheck(state, 0.000001);
        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
        alglib.minlmresults(state, out optConsts, out rep);
        funcEvaluations += rep.nfunc;
        gradEvaluations += rep.njac * nParams;

        // only use optimized constants if successful (a negative termination type signals failure).
        // On failure we keep the (already linearly scaled) starting point instead of throwing, so that
        // a single badly behaved expression does not abort the whole search.
        if (rep.terminationtype >= 0) {
          Array.Copy(optConsts, consts, optConsts.Length);
        }
      }

      // residual function for LM: fi = f(x; arg) - y
      private void Func(double[] arg, double[] fi, object obj) {
        var code = (byte[])obj;
        evaluator.Exec(code, x, arg, predBuf); // predictions for all training rows
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;
        }
      }

      // residuals and Jacobian for LM: fi = f(x; arg) - y, jac[r, k] = d f(x_r) / d arg_k
      private void FuncAndJacobian(double[] arg, double[] fi, double[,] jac, object obj) {
        int nParams = arg.Length;
        var code = (byte[])obj;
        evaluator.ExecGradient(code, x, arg, predBuf, gradBuf); // gradients are nParams x vLen
        for (int r = 0; r < predBuf.Length; r++) {
          var res = predBuf[r] - y[r];
          fi[r] = res;

          for (int k = 0; k < nParams; k++) {
            jac[r, k] = gradBuf[k][r];
          }
        }
      }
      #endregion
    }

    /// <summary>
    /// Creates a new MCTS search state for the given problem.
    /// </summary>
    /// <param name="problemData">Problem data; training rows are used for search, test rows for <see cref="IState.BestSolutionTestQuality"/>.</param>
    /// <param name="randSeed">Seed for the random number generator.</param>
    /// <param name="maxVariables">Maximum number of variable references in an expression (passed to the automaton).</param>
    /// <param name="scaleVariables">If true, inputs are min-max scaled based on the training rows.</param>
    /// <param name="constOptIterations">Number of Levenberg-Marquardt iterations for constant optimization (0 disables it).</param>
    /// <param name="policy">Tree policy; defaults to <see cref="Ucb"/> when null.</param>
    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
      bool scaleVariables = true, int constOptIterations = 0,
      IPolicy policy = null,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool allowProdOfVars = true,
      bool allowExp = true,
      bool allowLog = true,
      bool allowInv = true,
      bool allowMultipleTerms = false
      ) {
      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations,
        policy,
        lowerEstimationLimit, upperEstimationLimit,
        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    }

    /// <summary>
    /// Performs rollouts until one reaches a final state (or the search tree is exhausted).
    /// </summary>
    /// <param name="state">A state created by <see cref="CreateState"/>.</param>
    /// <returns>The quality (training R²) of the evaluated solution, or 0 if no solution could be evaluated.</returns>
    /// <exception cref="ArgumentException">If <paramref name="state"/> was not created by <see cref="CreateState"/>.</exception>
    /// <exception cref="NotSupportedException">If the search tree has already been fully explored.</exception>
    public static double MakeStep(IState state) {
      var mctsState = state as State;
      if (mctsState == null) throw new ArgumentException("The state must be created via CreateState.", "state");
      if (mctsState.Done) throw new NotSupportedException("The tree search has enumerated all possible solutions.");

      return TreeSearch(mctsState);
    }
    #endregion

    // repeats rollouts until one succeeds or the tree is exhausted
    private static double TreeSearch(State mctsState) {
      var automaton = mctsState.automaton;
      var tree = mctsState.tree;
      var eval = mctsState.evalFun;
      var rand = mctsState.random;
      var treePolicy = mctsState.treePolicy;
      double q = 0;
      bool success = false;
      do {
        automaton.Reset();
        success = TryTreeSearchRec(rand, tree, automaton, eval, treePolicy, out q);
        mctsState.totalRollouts++;
      } while (!success && !tree.Done);
      // the loop may also end because the tree has been exhausted without reaching a final state;
      // such a rollout did not evaluate anything and must not be counted as effective
      if (success) {
        mctsState.effectiveRollouts++;
      }
      return q;
    }

    // tree search might fail because of constraints for expressions
    // in this case we get stuck we just restart
    // see ConstraintHandler.cs for more info
    // returns true if a final state was reached and evaluated (q is its quality), false on a dead end (q = 0)
    private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
      out double q) {
      Tree selectedChild = null;
      Contract.Assert(tree.state == automaton.CurrentState);
      Contract.Assert(!tree.Done);
      if (tree.children == null) {
        if (automaton.IsFinalState(tree.state)) {
          // final state
          tree.Done = true;

          // EVALUATE
          byte[] code; int nParams;
          automaton.GetCode(out code, out nParams);
          q = eval(code, nParams);

          treePolicy.Update(tree.actionStatistics, q);
          return true; // we reached a final state
        } else {
          // EXPAND
          int[] possibleFollowStates;
          int nFs;
          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            q = 0;
            tree.Done = true;
            tree.children = null;
            return false;
          }
          tree.children = new Tree[nFs];
          for (int i = 0; i < tree.children.Length; i++)
            tree.children[i] = new Tree() { children = null, state = possibleFollowStates[i], actionStatistics = treePolicy.CreateActionStatistics() };

          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
        }
      } else {
        // tree.children != null
        // UCT selection within tree
        int selectedIdx = 0;
        if (tree.children.Length > 1) {
          selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
        }
        selectedChild = tree.children[selectedIdx];
      }
      // make selected step and recurse
      automaton.Goto(selectedChild.state);
      var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
      if (success) {
        // only update if successful
        treePolicy.Update(tree.actionStatistics, q);
      }

      tree.Done = tree.children.All(ch => ch.Done);
      if (tree.Done) {
        tree.children = null; // cut off the sub-branch if it has been fully explored
      }
      return success;
    }

    // selects the child to follow right after an expansion:
    // the first child that is a final state if there is one, otherwise the first child.
    // NOTE: despite the name, the choice is deterministic; rand is currently unused.
    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
      for (int i = 0; i < tree.children.Length; i++) {
        if (automaton.IsFinalState(tree.children[i].state)) {
          return tree.children[i];
        }
      }
      // no final state -> select the first child
      return tree.children[0];
    }

    // extracts the input and target values for the given rows into arrays.
    // If scaleVariables is set, each input is min-max scaled to [0, 1] based on these rows and the
    // scaling parameters are returned; otherwise scalingFactor and scalingOffset are null.
    private static void GenerateData(IRegressionProblemData problemData, bool scaleVariables, IEnumerable<int> rows,
      out double[][] xs, out double[] y, out double[] scalingFactor, out double[] scalingOffset) {
      if (scaleVariables) {
        var nVars = problemData.AllowedInputVariables.Count();
        scalingFactor = new double[nVars];
        scalingOffset = new double[nVars];

        var i = 0;
        foreach (var variableName in problemData.AllowedInputVariables) {
          var minX = problemData.Dataset.GetDoubleValues(variableName, rows).Min();
          var maxX = problemData.Dataset.GetDoubleValues(variableName, rows).Max();
          var range = maxX - minX;

          double sf, offset;
          if (range > 0.0) {
            // scaledX = (x - min) / range
            sf = 1.0 / range;
            offset = -minX / range;
          } else {
            // constant variable: dividing by a zero range would produce inf/NaN values,
            // so only shift the variable (all training values are mapped to 0)
            sf = 1.0;
            offset = -minX;
          }
          scalingFactor[i] = sf;
          scalingOffset[i] = offset;
          i++;
        }
      } else {
        scalingFactor = null;
        scalingOffset = null;
      }

      GenerateData(problemData, rows, scalingFactor, scalingOffset, out xs, out y);
    }

    // extracts values from the dataset into arrays, applying x * scalingFactor[i] + scalingOffset[i]
    // to input i (no scaling / no offset if the corresponding array is null)
    private static void GenerateData(IRegressionProblemData problemData, IEnumerable<int> rows, double[] scalingFactor, double[] scalingOffset,
     out double[][] xs, out double[] y) {
      xs = new double[problemData.AllowedInputVariables.Count()][];

      int i = 0;
      foreach (var variableName in problemData.AllowedInputVariables) {
        var sf = scalingFactor == null ? 1.0 : scalingFactor[i];
        var offset = scalingOffset == null ? 0.0 : scalingOffset[i];
        xs[i++] =
          problemData.Dataset.GetDoubleValues(variableName, rows).Select(xi => xi * sf + offset).ToArray();
      }

      y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.