
source: branches/2893_BNLR/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/BayesianNonlinearRegression.cs @ 15750

Last change on this file since 15750 was 15750, checked in by gkronber, 6 years ago

#2893: added scaling by number of rows to make leapfrog integration independent from the number of rows

File size: 17.4 KB
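
Note on this change: the data-fit term of the negative log-likelihood is scaled by the number of training rows N, so the gradient magnitude (and therefore the effect of a fixed leapfrog step size) does not grow with the size of the dataset. A minimal sketch of the scaled objective implemented in CreateNegLogLikelihoodFunction below, assuming \sigma denotes NoiseSigma and the constant terms of the Gaussian log-likelihood are dropped:

  L(p) = \frac{1}{2 \sigma^2 N} \sum_{i=1}^{N} \big( f(x_i; p) - y_i \big)^2,
  \qquad
  \frac{\partial L}{\partial p_j} = \frac{1}{\sigma^2 N} \sum_{i=1}^{N} \big( f(x_i; p) - y_i \big) \frac{\partial f(x_i; p)}{\partial p_j}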
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Bayesian non-linear regression data analysis algorithm.
  /// </summary>
  [Item("Bayesian Nonlinear Regression (BNLR)", "Nonlinear regression algorithm which uses Hamiltonian Monte Carlo (HMC) to draw samples from the posterior distribution of the model parameters.")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableClass]
  public sealed class BayesianNonlinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RegressionSolutionResultName = "Regression solution";
    private const string ModelStructureParameterName = "Model structure";
    private const string IterationsParameterName = "Iterations";
    private const string RestartsParameterName = "Restarts";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string SeedParameterName = "Seed";
    private const string InitParamsRandomlyParameterName = "InitializeParametersRandomly";

    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }

    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }

    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }

    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }

    public IFixedValueParameter<BoolValue> InitParametersRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[InitParamsRandomlyParameterName]; }
    }

    public IFixedValueParameter<IntValue> LeapFrogStepsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters["LeapFrogSteps"]; }
    }
    public IFixedValueParameter<DoubleValue> LeapFrogStepSizeParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["LeapFrogStepSize"]; }
    }
    public IFixedValueParameter<DoubleValue> NoiseSigmaParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["NoiseSigma"]; }
    }

    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }

    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }

    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }

    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }

    public bool InitializeParametersRandomly {
      get { return InitParametersRandomlyParameter.Value.Value; }
      set { InitParametersRandomlyParameter.Value.Value = value; }
    }

    public int LeapFrogSteps {
      get { return LeapFrogStepsParameter.Value.Value; }
      set { LeapFrogStepsParameter.Value.Value = value; }
    }
    public double LeapFrogStepSize {
      get { return LeapFrogStepSizeParameter.Value.Value; }
      set { LeapFrogStepSizeParameter.Value.Value = value; }
    }
    public double NoiseSigma {
      get { return NoiseSigmaParameter.Value.Value; }
      set { NoiseSigmaParameter.Value.Value = value; }
    }

    [StorableConstructor]
    private BayesianNonlinearRegression(bool deserializing) : base(deserializing) { }
    private BayesianNonlinearRegression(BayesianNonlinearRegression original, Cloner cloner)
      : base(original, cloner) {
    }
    public BayesianNonlinearRegression()
      : base() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The number of HMC samples to draw (length of the sampled chain).", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts (>0).", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the real-valued model parameters should be initialized randomly in each restart.", new BoolValue(false)));

      Parameters.Add(new FixedValueParameter<IntValue>("LeapFrogSteps", "The number of leapfrog integration steps per HMC proposal.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>("LeapFrogStepSize", "The step size for leapfrog integration.", new DoubleValue(0.1)));
      Parameters.Add(new FixedValueParameter<DoubleValue>("NoiseSigma", "The standard deviation of the Gaussian observation noise.", new DoubleValue(0.1)));

      SetParameterHiddenState();

      InitParametersRandomlyParameter.Value.ValueChanged += (sender, args) => {
        SetParameterHiddenState();
      };
    }

    private void SetParameterHiddenState() {
      var hide = !InitializeParametersRandomly;
      RestartsParameter.Hidden = hide;
      SeedParameter.Hidden = hide;
      SetSeedRandomlyParameter.Hidden = hide;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new BayesianNonlinearRegression(this, cloner);
    }

    #region nonlinear regression
    protected override void Run(CancellationToken cancellationToken) {
      IRegressionSolution bestSolution = null;
      if (SetSeedRandomly) Seed = (new System.Random()).Next();
      var rand = new MersenneTwister((uint)Seed);

      double[][] chain;
      if (InitializeParametersRandomly) {
        var qualityTable = new DataTable("RMSE table");
        qualityTable.VisualProperties.YAxisLogScale = true;
        var trainRMSERow = new DataRow("RMSE (train)");
        trainRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
        var testRMSERow = new DataRow("RMSE (test)");
        testRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;

        qualityTable.Rows.Add(trainRMSERow);
        qualityTable.Rows.Add(testRMSERow);
        Results.Add(new Result(qualityTable.Name, qualityTable.Name + " for all restarts", qualityTable));
        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand, LeapFrogSteps, LeapFrogStepSize, NoiseSigma, out chain);
        trainRMSERow.Values.Add(bestSolution.TrainingRootMeanSquaredError);
        testRMSERow.Values.Add(bestSolution.TestRootMeanSquaredError);
        for (int r = 0; r < Restarts; r++) {
          var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand, LeapFrogSteps, LeapFrogStepSize, NoiseSigma, out chain);
          trainRMSERow.Values.Add(solution.TrainingRootMeanSquaredError);
          testRMSERow.Values.Add(solution.TestRootMeanSquaredError);
          if (solution.TrainingRootMeanSquaredError < bestSolution.TrainingRootMeanSquaredError) {
            bestSolution = solution;
          }
        }
      } else {
        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand, LeapFrogSteps, LeapFrogStepSize, NoiseSigma, out chain);
        // one data row per model parameter; each row traces that parameter over the sampled chain
        var nParams = chain.First().Length;
        var chainTable = new DataTable("Chain");
        var rows = new DataRow[nParams];
        for (int i = 0; i < nParams; i++) {
          rows[i] = new DataRow(i.ToString());
        }
        foreach (var sample in chain) {
          for (int i = 0; i < sample.Length; i++) {
            rows[i].Values.Add(sample[i]);
          }
        }
        for (int i = 0; i < nParams; i++) {
          chainTable.Rows.Add(rows[i]);
        }

        Results.Add(new Result("Chain", "The sampled chain of model parameters.", chainTable));
      }

      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", bestSolution));
      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(bestSolution.TrainingRootMeanSquaredError)));
      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(bestSolution.TestRootMeanSquaredError)));
    }

    /// <summary>
    /// Determines the posterior distribution of the model parameters using Hamiltonian Monte Carlo (HMC).
    /// The model is specified as an infix expression containing variable names and numeric constants.
    /// The prior distribution for the parameters is N(0, \lambda I).
    /// </summary>
    /// <param name="problemData">Training and test data</param>
    /// <param name="modelStructure">The function as infix expression</param>
    /// <param name="maxIterations">Number of samples for HMC</param>
    /// <param name="random">The PRNG used for HMC sampling</param>
    /// <param name="leapFrogSteps">Number of leapfrog integration steps per HMC proposal</param>
    /// <param name="leapFrogStepSize">Step size for leapfrog integration</param>
    /// <param name="noiseSigma">Standard deviation of the Gaussian observation noise</param>
    /// <param name="chain">The sampled parameter vectors</param>
    public static IRegressionSolution CreateRegressionSolution(
      IRegressionProblemData problemData, string modelStructure, int maxIterations, IRandom random,
      int leapFrogSteps, double leapFrogStepSize, double noiseSigma, out double[][] chain
      ) {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(modelStructure);
      // the parser handles double and string variables uniformly by creating VariableTreeNodes;
      // post-process the tree to replace VariableTreeNodes with FactorVariableTreeNodes for all string variables
      var factorSymbol = new FactorVariable();
      factorSymbol.VariableNames =
        problemData.AllowedInputVariables.Where(name => problemData.Dataset.VariableHasType<string>(name));
      factorSymbol.AllVariableNames = factorSymbol.VariableNames;
      factorSymbol.VariableValues =
        factorSymbol.VariableNames.Select(name =>
        new KeyValuePair<string, Dictionary<string, int>>(name,
        problemData.Dataset.GetReadOnlyStringValues(name).Distinct()
        .Select((n, i) => Tuple.Create(n, i))
        .ToDictionary(tup => tup.Item1, tup => tup.Item2)));

      foreach (var parent in tree.IterateNodesPrefix().ToArray()) {
        for (int i = 0; i < parent.SubtreeCount; i++) {
          var varChild = parent.GetSubtree(i) as VariableTreeNode;
          var factorVarChild = parent.GetSubtree(i) as FactorVariableTreeNode;
          if (varChild != null && factorSymbol.VariableNames.Contains(varChild.VariableName)) {
            parent.RemoveSubtree(i);
            var factorTreeNode = (FactorVariableTreeNode)factorSymbol.CreateTreeNode();
            factorTreeNode.VariableName = varChild.VariableName;
            factorTreeNode.Weights =
              factorTreeNode.Symbol.GetVariableValues(factorTreeNode.VariableName).Select(_ => 1.0).ToArray();
            // weight = 1.0 for each value
            parent.InsertSubtree(i, factorTreeNode);
          } else if (factorVarChild != null && factorSymbol.VariableNames.Contains(factorVarChild.VariableName)) {
            if (factorSymbol.GetVariableValues(factorVarChild.VariableName).Count() != factorVarChild.Weights.Length)
              throw new ArgumentException(
                string.Format("Factor variable {0} needs exactly {1} weights",
                factorVarChild.VariableName,
                factorSymbol.GetVariableValues(factorVarChild.VariableName).Count()));
            parent.RemoveSubtree(i);
            var factorTreeNode = (FactorVariableTreeNode)factorSymbol.CreateTreeNode();
            factorTreeNode.VariableName = factorVarChild.VariableName;
            factorTreeNode.Weights = factorVarChild.Weights;
            parent.InsertSubtree(i, factorTreeNode);
          }
        }
      }

      // TODO: useful?
      // initialize constants randomly
      // if (random != null) {
      //   foreach (var node in tree.IterateNodesPrefix().OfType<ConstantTreeNode>()) {
      //     double f = Math.Exp(NormalDistributedRandom.NextDouble(random, 0, 1));
      //     double s = random.NextDouble() < 0.5 ? -1 : 1;
      //     node.Value = s * node.Value * f;
      //   }
      // }

      double[] initialConstants;
      var negLogLikelihood = CreateNegLogLikelihoodFunction(problemData, tree, noiseSigma, out initialConstants);

      // create parameter sample
      var sampledParameters = HamiltonianMonteCarlo.SampleChain(initialConstants, negLogLikelihood, random, leapFrogStepSize, leapFrogSteps)
        .Take(maxIterations);

      chain = sampledParameters.ToArray();

      var model = new BayesianNonlinearRegressionModel(tree, chain,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter(), problemData.TargetVariable, problemData.AllowedInputVariables);
      var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());

      solution.Model.Name = "Regression Model";
      solution.Name = "Regression Solution";
      return solution;
    }
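
    // Illustrative usage sketch (not part of the original file; "problemData" is assumed to
    // be any IRegressionProblemData instance, e.g. Problem.ProblemData):
    //
    //   double[][] chain;
    //   var rand = new MersenneTwister(31415u);
    //   var solution = CreateRegressionSolution(problemData, "1.0 * x*x + 0.0", 200, rand,
    //                                           10, 0.1, 0.1, out chain);
    //   // chain[k] is the k-th sampled parameter vector of the posterior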

    private static Func<double[], Tuple<double, double[]>> CreateNegLogLikelihoodFunction(
      IRegressionProblemData problemData, ISymbolicExpressionTree tree, double noiseSigma,
      out double[] initialConstants) {
      List<TreeToAutoDiffTermConverter.DataForVariable> parameters;
      TreeToAutoDiffTermConverter.ParametricFunction func;
      TreeToAutoDiffTermConverter.ParametricFunctionGradient funcGrad;
      TreeToAutoDiffTermConverter.TryConvertToAutoDiff(tree, false, false, out parameters, out initialConstants, out func, out funcGrad);

      double variance = noiseSigma * noiseSigma;

      IDataset ds = problemData.Dataset;
      var rows = problemData.TrainingIndices;
      int N = rows.Count();
      var xs = new double[N][];
      int row = 0;
      foreach (var r in rows) {
        int col = 0;
        xs[row] = new double[parameters.Count];
        foreach (var info in parameters) {
          if (ds.VariableHasType<double>(info.variableName)) {
            xs[row][col] = ds.GetDoubleValue(info.variableName, r + info.lag);
          } else if (ds.VariableHasType<string>(info.variableName)) {
            xs[row][col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0;
          } else throw new InvalidProgramException("found a variable of unknown type");
          col++;
        }
        row++;
      }
      var ys = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();

      return (double[] p) => {
        var logProbSum = 0.0;
        // scaling by the number of rows N makes leapfrog integration independent of the number of rows (see r15750)
        var scalingFactor = 1.0 / (2.0 * variance * N);
        double[] gSum = new double[p.Length];
        for (int i = 0; i < N; i++) {
          var fg = funcGrad(p, xs[i]);
          // accumulate the squared error (the unscaled contribution would be err * err / (2 * variance))
          var err = fg.Item2 - ys[i];
          logProbSum += scalingFactor * err * err;

          // accumulate the gradient
          for (int j = 0; j < gSum.Length; j++) gSum[j] += scalingFactor * 2 * err * fg.Item1[j];
        }

        double f = logProbSum /* + N / 2.0 * Math.Log(variance) + N / 2.0 * Math.Log(2 * Math.PI) (constant terms dropped) */;
        double[] g = gSum; // the dropped constant terms do not depend on p and therefore do not contribute to the gradient
        return Tuple.Create(f, g);
      };
    }
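
    // Illustrative sketch (not part of the original file): a central finite-difference
    // approximation that could be used to cross-check the analytic gradient returned by
    // the negative log-likelihood function above.
    private static double[] ApproximateGradient(Func<double[], Tuple<double, double[]>> negLogLik, double[] p, double h = 1e-6) {
      var g = new double[p.Length];
      for (int j = 0; j < p.Length; j++) {
        var pPlus = (double[])p.Clone(); pPlus[j] += h;
        var pMinus = (double[])p.Clone(); pMinus[j] -= h;
        // central difference: dL/dp_j ~ (L(p + h e_j) - L(p - h e_j)) / (2h)
        g[j] = (negLogLik(pPlus).Item1 - negLogLik(pMinus).Item1) / (2 * h);
      }
      return g;
    }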

    #endregion
  }
}