Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessRegression.cs @ 8366

Last change on this file since 8366 was 8325, checked in by gkronber, 12 years ago

#1902 changed return value type for parameter properties

File size: 10.1 KB
Line 
1
2#region License Information
3/* HeuristicLab
4 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Analysis;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.PluginInfrastructure;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Random;
35
namespace HeuristicLab.Algorithms.DataAnalysis.GaussianProcess {
  /// <summary>
  /// Gaussian process regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// The hyperparameters of the selected mean and covariance functions are fitted on the training
  /// partition by minimizing the negative log likelihood of a <see cref="GaussianProcessModel"/>
  /// with alglib's L-BFGS optimizer. The resulting model is published as a
  /// <see cref="GaussianProcessRegressionSolution"/> together with its training/test R² and
  /// negative log likelihood.
  /// </remarks>
  [Item("Gaussian Process Regression", "Gaussian process regression data analysis algorithm.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class GaussianProcessRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string MeanFunctionParameterName = "MeanFunction";
    private const string CovarianceFunctionParameterName = "CovarianceFunction";
    private const string MinimizationIterationsParameterName = "MinimizationIterations";
    // The two names below are used as result names for the progress tables filled by Report.
    private const string NegativeLogLikelihoodTableParameterName = "NegativeLogLikelihoodTable";
    private const string HyperParametersTableParameterName = "HyperParametersTable";

    #region parameter properties
    /// <summary>
    /// Parameter holding the mean function. Its allowed values are the <see cref="IMeanFunction"/>
    /// instances discovered through the plugin infrastructure when the algorithm is constructed.
    /// </summary>
    public IConstrainedValueParameter<IMeanFunction> MeanFunctionParameter {
      get { return (IConstrainedValueParameter<IMeanFunction>)Parameters[MeanFunctionParameterName]; }
    }
    /// <summary>
    /// Parameter holding the covariance function. Its allowed values are the
    /// <see cref="ICovarianceFunction"/> instances discovered through the plugin infrastructure
    /// when the algorithm is constructed.
    /// </summary>
    public IConstrainedValueParameter<ICovarianceFunction> CovarianceFunctionParameter {
      get { return (IConstrainedValueParameter<ICovarianceFunction>)Parameters[CovarianceFunctionParameterName]; }
    }
    /// <summary>
    /// Parameter holding the iteration limit passed to the L-BFGS likelihood optimization.
    /// </summary>
    public IValueParameter<IntValue> MinimizationIterationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MinimizationIterationsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Gets or sets the currently selected mean function.</summary>
    public IMeanFunction MeanFunction {
      get { return MeanFunctionParameter.Value; }
      set { MeanFunctionParameter.Value = value; }
    }
    /// <summary>Gets or sets the currently selected covariance function.</summary>
    public ICovarianceFunction CovarianceFunction {
      get { return CovarianceFunctionParameter.Value; }
      set { CovarianceFunctionParameter.Value = value; }
    }
    /// <summary>Gets or sets the iteration limit of the likelihood optimization.</summary>
    public int MinimizationIterations {
      get { return MinimizationIterationsParameter.Value.Value; }
      set { MinimizationIterationsParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private GaussianProcessRegression(bool deserializing) : base(deserializing) { }
    private GaussianProcessRegression(GaussianProcessRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with an empty <see cref="RegressionProblem"/>. The first discovered
    /// mean and covariance functions are preselected, and the iteration limit defaults to 20.
    /// </summary>
    public GaussianProcessRegression()
      : base() {
      Problem = new RegressionProblem();

      List<IMeanFunction> meanFunctions = ApplicationManager.Manager.GetInstances<IMeanFunction>().ToList();
      List<ICovarianceFunction> covFunctions = ApplicationManager.Manager.GetInstances<ICovarianceFunction>().ToList();

      Parameters.Add(new ConstrainedValueParameter<IMeanFunction>(MeanFunctionParameterName, "The mean function to use.",
        new ItemSet<IMeanFunction>(meanFunctions), meanFunctions.First()));
      Parameters.Add(new ConstrainedValueParameter<ICovarianceFunction>(CovarianceFunctionParameterName, "The covariance function to use.",
        new ItemSet<ICovarianceFunction>(covFunctions), covFunctions.First()));
      Parameters.Add(new ValueParameter<IntValue>(MinimizationIterationsParameterName, "The number of iterations for likelihood optimization.", new IntValue(20)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GaussianProcessRegression(this, cloner);
    }

    #region Gaussian process regression
    /// <summary>
    /// Fits the hyperparameters on the training partition and adds the solution and its
    /// quality measures to <c>Results</c>.
    /// </summary>
    protected override void Run() {
      IRegressionProblemData problemData = Problem.ProblemData;
      // Read the parameter values once, so the optimization and the final model are guaranteed
      // to use the same function instances and iteration limit.
      IMeanFunction meanFunction = MeanFunction;
      ICovarianceFunction covarianceFunction = CovarianceFunction;
      int maxIterations = MinimizationIterations;

      int nAllowedVariables = problemData.AllowedInputVariables.Count();
      var random = new MersenneTwister();

      // Random starting point drawn from MersenneTwister.NextDouble(). The leading "1 +" reserves one
      // hyperparameter beyond those of the mean and covariance functions.
      // NOTE(review): presumably the likelihood/noise term - confirm against GaussianProcessModel.
      int nHyperparameters = 1
        + meanFunction.GetNumberOfParameters(nAllowedVariables)
        + covarianceFunction.GetNumberOfParameters(nAllowedVariables);
      double[] hyp0 = Enumerable.Range(0, nHyperparameters)
        .Select(i => random.NextDouble())
        .ToArray();

      // Stopping criteria handed to alglib.minlbfgssetcond. epsg and epsx are 0, so in practice the
      // relative function change (epsf) and the iteration limit govern termination.
      // NOTE(review): per the alglib docs, maxits == 0 means "no iteration limit" - confirm.
      double epsg = 0;
      double epsf = 0.00001;
      double epsx = 0;

      alglib.minlbfgsstate state;
      alglib.minlbfgsreport rep;
      double[] hyp;

      // NOTE(review): the first argument is alglib's M (number of L-BFGS correction pairs). alglib
      // recommends 3..7; it is kept at 1 here to preserve the existing optimization behavior.
      alglib.minlbfgscreate(1, hyp0, out state);
      alglib.minlbfgssetcond(state, epsg, epsf, epsx, maxIterations);
      // Enable progress reports so that Report records every iterate.
      alglib.minlbfgssetxrep(state, true);
      alglib.minlbfgsoptimize(state, OptimizeGaussianProcessParameters, Report,
        new object[] { meanFunction, covarianceFunction, problemData });
      alglib.minlbfgsresults(state, out hyp, out rep);

      double trainR2, testR2, negativeLogLikelihood;
      var solution = CreateGaussianProcessSolution(problemData, hyp, meanFunction, covarianceFunction,
        out negativeLogLikelihood, out trainR2, out testR2);

      Results.Add(new Result("Gaussian process regression solution", "The Gaussian process regression solution.", solution));
      Results.Add(new Result("Training R²", "The Pearson's R² of the Gaussian process solution on the training partition.", new DoubleValue(trainR2)));
      Results.Add(new Result("Test R²", "The Pearson's R² of the Gaussian process solution on the test partition.", new DoubleValue(testR2)));
      Results.Add(new Result("Negative log likelihood", "The negative log likelihood of the Gaussian process.", new DoubleValue(negativeLogLikelihood)));
    }

    /// <summary>
    /// Builds a <see cref="GaussianProcessModel"/> on the training rows of <paramref name="problemData"/>
    /// with the given hyperparameters. The model is wrapped in a solution bound to a clone of the
    /// problem data.
    /// </summary>
    /// <param name="problemData">Data providing the dataset, target, inputs and training rows.</param>
    /// <param name="hyp">Hyperparameter vector for the model.</param>
    /// <param name="mean">Mean function of the model.</param>
    /// <param name="cov">Covariance function of the model.</param>
    /// <param name="negativeLogLikelihood">The model's negative log likelihood.</param>
    /// <param name="trainingR2">R² of the solution on the training partition.</param>
    /// <param name="testR2">R² of the solution on the test partition.</param>
    /// <returns>The created solution.</returns>
    public static GaussianProcessRegressionSolution CreateGaussianProcessSolution(IRegressionProblemData problemData,
      IEnumerable<double> hyp, IMeanFunction mean, ICovarianceFunction cov,
      out double negativeLogLikelihood, out double trainingR2, out double testR2) {

      Dataset dataset = problemData.Dataset;
      var allowedInputVariables = problemData.AllowedInputVariables;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<int> rows = problemData.TrainingIndices;

      var model = new GaussianProcessModel(dataset, targetVariable, allowedInputVariables, rows, hyp, mean, cov);
      var solution = new GaussianProcessRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
      negativeLogLikelihood = model.NegativeLogLikelihood;
      trainingR2 = solution.TrainingRSquared;
      testR2 = solution.TestRSquared;
      return solution;
    }

    /// <summary>
    /// Objective callback for alglib's L-BFGS. It builds a model on the training rows for the
    /// candidate hyperparameters <paramref name="hyp"/>, then writes the negative log likelihood into
    /// <paramref name="func"/> and the hyperparameter gradients into <paramref name="grad"/>.
    /// </summary>
    /// <param name="obj">
    /// The state array passed by <see cref="Run"/>: { IMeanFunction, ICovarianceFunction, IRegressionProblemData }.
    /// </param>
    private static void OptimizeGaussianProcessParameters(double[] hyp, ref double func, double[] grad, object obj) {
      var objArr = (object[])obj;
      var meanFunction = (IMeanFunction)objArr[0];
      var covarianceFunction = (ICovarianceFunction)objArr[1];
      // Cast to the interface: Run passes Problem.ProblemData, typed as IRegressionProblemData, which
      // need not be a RegressionProblemData instance. Casting to the concrete class threw for other
      // implementations.
      var problemData = (IRegressionProblemData)objArr[2];

      var model = new GaussianProcessModel(problemData.Dataset, problemData.TargetVariable,
        problemData.AllowedInputVariables, problemData.TrainingIndices, hyp, meanFunction, covarianceFunction);

      int i = 0;
      foreach (var gradient in model.GetHyperparameterGradients()) {
        grad[i++] = gradient;
      }
      func = model.NegativeLogLikelihood;
    }

    /// <summary>
    /// Progress callback for alglib's L-BFGS (enabled via minlbfgssetxrep in <see cref="Run"/>).
    /// Appends <paramref name="func"/> to the "Negative log likelihood" row of the
    /// NegativeLogLikelihoodTable result. Appends each hyperparameter value to the row named by its
    /// index in the HyperParametersTable result. Tables and rows are created on first use.
    /// </summary>
    /// <param name="arg">Current hyperparameter vector.</param>
    /// <param name="func">Objective value (negative log likelihood) reported for <paramref name="arg"/>.</param>
    /// <param name="obj">State object supplied by alglib; not used here.</param>
    public void Report(double[] arg, double func, object obj) {
      if (!Results.ContainsKey(NegativeLogLikelihoodTableParameterName)) {
        Results.Add(new Result(NegativeLogLikelihoodTableParameterName, new DataTable()));
      }
      if (!Results.ContainsKey(HyperParametersTableParameterName)) {
        Results.Add(new Result(HyperParametersTableParameterName, new DataTable()));
      }

      var nllTable = (DataTable)Results[NegativeLogLikelihoodTableParameterName].Value;
      if (!nllTable.Rows.ContainsKey("Negative log likelihood")) {
        nllTable.Rows.Add(new DataRow("Negative log likelihood"));
      }
      nllTable.Rows["Negative log likelihood"].Values.Add(func);

      var hypTable = (DataTable)Results[HyperParametersTableParameterName].Value;
      // One row per hyperparameter, created on the first report only.
      if (hypTable.Rows.Count == 0) {
        for (int i = 0; i < arg.Length; i++) {
          hypTable.Rows.Add(new DataRow(i.ToString()));
        }
      }
      for (int i = 0; i < arg.Length; i++) {
        hypTable.Rows[i.ToString()].Values.Add(arg[i]);
      }
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.