Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2893_BNLR/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/BayesianNonlinearRegressionModel.cs

Last change on this file was 15748, checked in by gkronber, 7 years ago

#2893: implemented a first version of Bayesian non-linear regression using HMC sampling

File size: 5.5 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression model that evaluates a symbolic expression tree under a set of
  /// sampled parameter vectors. Predictions are the mean, and predictive variances
  /// the population variance, of the tree outputs over <see cref="SAMPLE_SIZE"/>
  /// parameter vectors drawn from the stored empirical distribution.
  /// </summary>
  [StorableClass]
  [Item("BayesianNonlinearRegressionModel", "")]
  public sealed class BayesianNonlinearRegressionModel : RegressionModel, IConfidenceRegressionModel {
    // Number of parameter vectors drawn (with replacement) for every prediction.
    private const int SAMPLE_SIZE = 100;

    /// <summary>
    /// The expression tree whose constants (and factor-variable weights) are the model parameters.
    /// </summary>
    [Storable]
    public ISymbolicExpressionTree Tree {
      get; private set;
    }

    // One parameter vector per entry. Must be persisted, otherwise a deserialized
    // model has nothing to sample from.
    [Storable]
    private double[][] parameterEmpiricalDistribution;

    /// <summary>
    /// The stored parameter vectors; each element supplies the values written into
    /// <see cref="Tree"/> for one sample.
    /// </summary>
    public IEnumerable<double[]> ParameterEmpiricalDistribution {
      get { return parameterEmpiricalDistribution; }
    }

    /// <summary>
    /// The interpreter used to evaluate <see cref="Tree"/> on a dataset.
    /// </summary>
    [Storable]
    public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter {
      get; private set;
    }

    /// <summary>
    /// The input variables that were passed to the constructor.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;

    [StorableConstructor]
    private BayesianNonlinearRegressionModel(bool deserializing)
      : base(deserializing) {
    }

    // Cloning constructor: every member is copied so that the clone shares no
    // mutable state with the original.
    private BayesianNonlinearRegressionModel(BayesianNonlinearRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      this.Tree = cloner.Clone(original.Tree);
      this.parameterEmpiricalDistribution = CopyDistribution(original.parameterEmpiricalDistribution);
      this.Interpreter = cloner.Clone(original.Interpreter);
    }

    /// <summary>
    /// Creates a new model.
    /// </summary>
    /// <param name="tree">Expression tree whose constants are replaced by sampled parameters.</param>
    /// <param name="parameterEmpiricalDistribution">Parameter vectors to sample from; copied by value.</param>
    /// <param name="interpreter">Interpreter used to evaluate the tree.</param>
    /// <param name="targetVariable">Name of the target variable, forwarded to the base class.</param>
    /// <param name="allowedInputVariables">Names reported by <see cref="VariablesUsedForPrediction"/>.</param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="parameterEmpiricalDistribution"/> or
    /// <paramref name="allowedInputVariables"/> is null.
    /// </exception>
    public BayesianNonlinearRegressionModel(
      ISymbolicExpressionTree tree,
      double[][] parameterEmpiricalDistribution,
      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      string targetVariable,
      IEnumerable<string> allowedInputVariables
      )
      : base(targetVariable) {
      if (parameterEmpiricalDistribution == null) throw new ArgumentNullException("parameterEmpiricalDistribution");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");

      this.name = ItemName;
      this.description = ItemDescription;
      this.Tree = tree;
      // Deep copy: a shallow Clone() of a jagged array would leave the inner
      // vectors aliased with the caller's arrays.
      this.parameterEmpiricalDistribution = CopyDistribution(parameterEmpiricalDistribution);
      this.Interpreter = interpreter;
      this.allowedInputVariables = allowedInputVariables.ToArray();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new BayesianNonlinearRegressionModel(this, cloner);
    }

    /// <summary>
    /// Returns, for each row, the mean of the tree outputs over the sampled parameter vectors.
    /// </summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      var y = Sample(dataset, rows);
      return y.Select(yi => yi.Average());
    }

    /// <summary>
    /// Returns, for each row, the population variance of the tree outputs over the
    /// sampled parameter vectors.
    /// </summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      var y = Sample(dataset, rows);
      return y.Select(yi => yi.VariancePop());
    }

    /// <summary>
    /// Evaluates the tree for <see cref="SAMPLE_SIZE"/> randomly chosen parameter vectors.
    /// </summary>
    /// <returns>One list per row (in the order of <paramref name="rows"/>) holding the
    /// tree output for every drawn parameter vector.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the model holds no parameter vectors to sample from.
    /// </exception>
    private IList<double>[] Sample(IDataset dataset, IEnumerable<int> rows) {
      if (parameterEmpiricalDistribution == null || parameterEmpiricalDistribution.Length == 0)
        throw new InvalidOperationException("The model does not contain any parameter samples.");

      // rows is consumed once here; the original enumerated it once per sample plus
      // once for sizing, which is wasteful and unsafe for non-repeatable sequences.
      var rowList = rows.ToList();
      var y = new List<double>[rowList.Count];
      for (int r = 0; r < y.Length; r++) y[r] = new List<double>(SAMPLE_SIZE);

      // Work on a private copy so that predicting does not overwrite the constants of
      // the stored tree (which would also make concurrent predictions interfere).
      var workingTree = new Cloner().Clone(Tree);

      // Fixed seed: GetEstimatedValues and GetEstimatedVariances therefore draw the
      // same parameter vectors, and repeated calls give identical results.
      var rand = new System.Random(1234); // TODO: make the seed configurable
      for (int s = 0; s < SAMPLE_SIZE; s++) {
        var paramIdx = rand.Next(parameterEmpiricalDistribution.Length);
        UpdateConstants(workingTree, parameterEmpiricalDistribution[paramIdx]);
        int predRow = 0;
        foreach (var pred in Interpreter.GetSymbolicExpressionTreeValues(workingTree, dataset, rowList)) {
          y[predRow].Add(pred);
          predRow++;
        }
      }
      return y;
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new ConfidenceRegressionSolution(this, new RegressionProblemData(problemData));
    }

    // Returns a copy of the jagged array in which the inner vectors are copied as
    // well; null (outer or inner) is passed through unchanged.
    private static double[][] CopyDistribution(double[][] source) {
      if (source == null) return null;
      var copy = new double[source.Length][];
      for (int i = 0; i < source.Length; i++)
        copy[i] = source[i] == null ? null : (double[])source[i].Clone();
      return copy;
    }

    #region taken from ConstantOptEval TODO
    // Writes the values of constants, in order, into the terminal nodes of the tree
    // visited in prefix order: one value per constant node and one value per weight
    // of each factor-variable node. Weights of other variable nodes are not touched.
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = constants[i++];
        else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = constants[i++];
        }
      }
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.