Changeset 14872 for branches/RBFRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/RadialBasisFunctions/RadialBasisFunctionModel.cs
- Timestamp:
- 04/14/17 17:53:30 (8 years ago)
- File:
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/RBFRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/RadialBasisFunctions/RadialBasisFunctionModel.cs
r14386 r14872 1 #region License Information1 #region License Information 2 2 /* HeuristicLab 3 3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL) … … 22 22 using System; 23 23 using System.Collections.Generic; 24 using System.Diagnostics; 24 25 using System.Linq; 25 26 using HeuristicLab.Common; 26 27 using HeuristicLab.Core; 27 using HeuristicLab.Data;28 28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 29 using HeuristicLab.Problems.DataAnalysis; … … 31 31 namespace HeuristicLab.Algorithms.DataAnalysis { 32 32 /// <summary> 33 /// Represents a Radial Basis Functionregression model.33 /// Represents an RBF regression model. 34 34 /// </summary> 35 35 [StorableClass] 36 [Item("RBFModel", " Represents a Gaussian process posterior.")]36 [Item("RBFModel", "An RBF regression model")] 37 37 public sealed class RadialBasisFunctionModel : RegressionModel, IConfidenceRegressionModel { 38 public override IEnumerable<string> VariablesUsedForPrediction 39 { 38 public override IEnumerable<string> VariablesUsedForPrediction { 40 39 get { return allowedInputVariables; } 41 40 } 42 41 43 42 [Storable] 44 private string[] allowedInputVariables; 45 public string[] AllowedInputVariables 46 { 43 private readonly string[] allowedInputVariables; 44 public string[] AllowedInputVariables { 47 45 get { return allowedInputVariables; } 48 46 } 49 47 50 48 [Storable] 51 private double[] alpha; 52 [Storable] 53 private IDataset trainingDataset; // it is better to store the original training dataset completely because this is more efficient in persistence 54 [Storable] 55 private int[] trainingRows; 56 [Storable] 57 private IKernelFunction<double[]> kernel; 58 [Storable] 59 private DoubleMatrix gramInvert; 60 49 private readonly double[] alpha; 50 51 [Storable] 52 private readonly double[,] trainX; // it is better to store the original training dataset completely because this is more efficient in persistence 53 54 [Storable] 55 private 
readonly ITransformation<double>[] scaling; 56 57 [Storable] 58 private readonly ICovarianceFunction kernel; 59 60 private double[,] gramInvert; // not storable as it can be large (recreate after deserialization as required) 61 62 [Storable] 63 private readonly double meanOffset; // implementation works for zero-mean target variables 61 64 62 65 [StorableConstructor] … … 65 68 : base(original, cloner) { 66 69 // shallow copies of arrays because they cannot be modified 67 trainingRows = original.trainingRows;68 70 allowedInputVariables = original.allowedInputVariables; 69 71 alpha = original.alpha; 70 trainingDataset = original.trainingDataset; 71 kernel = original.kernel; 72 } 73 public RadialBasisFunctionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows, IKernelFunction<double[]> kernel) 72 trainX = original.trainX; 73 gramInvert = original.gramInvert; 74 scaling = original.scaling; 75 76 meanOffset = original.meanOffset; 77 if (original.kernel != null) 78 kernel = cloner.Clone(original.kernel); 79 } 80 public RadialBasisFunctionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows, 81 bool scaleInputs, ICovarianceFunction kernel) 74 82 : base(targetVariable) { 83 if (kernel.GetNumberOfParameters(allowedInputVariables.Count()) > 0) throw new ArgumentException("All parameters in the kernel function must be specified."); 75 84 name = ItemName; 76 85 description = ItemDescription; 77 86 this.allowedInputVariables = allowedInputVariables.ToArray(); 78 trainingRows = rows.ToArray(); 79 trainingDataset = dataset; 80 this.kernel = (IKernelFunction<double[]>)kernel.Clone(); 87 var trainingRows = rows.ToArray(); 88 this.kernel = (ICovarianceFunction)kernel.Clone(); 81 89 try { 82 var data = ExtractData(dataset, trainingRows); 83 var qualities = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray(); 90 if (scaleInputs) 91 scaling = 
CreateScaling(dataset, trainingRows); 92 trainX = ExtractData(dataset, trainingRows, scaling); 93 var y = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray(); 94 meanOffset = y.Average(); 95 for (int i = 0; i < y.Length; i++) y[i] -= meanOffset; 84 96 int info; 97 // TODO: improve efficiency by decomposing matrix once instead of solving the system and then inverting the matrix 85 98 alglib.densesolverlsreport denseSolveRep; 86 var gr = BuildGramMatrix(data); 87 alglib.rmatrixsolvels(gr, data.Length + 1, data.Length + 1, qualities.Concat(new[] { 0.0 }).ToArray(), 0.0, out info, out denseSolveRep, out alpha); 88 if (info != 1) throw new ArgumentException("Could not create Model."); 89 gramInvert = new DoubleMatrix(gr).Invert(); 90 } 91 catch (alglib.alglibexception ae) { 99 gramInvert = BuildGramMatrix(trainX); 100 int n = trainX.GetLength(0); 101 alglib.rmatrixsolvels(gramInvert, n, n, y, 0.0, out info, out denseSolveRep, out alpha); 102 if (info != 1) throw new ArgumentException("Could not create model."); 103 104 alglib.matinvreport report; 105 alglib.rmatrixinverse(ref gramInvert, out info, out report); 106 if (info != 1) throw new ArgumentException("Could not invert matrix. 
Is it quadratic symmetric positive definite?"); 107 108 } catch (alglib.alglibexception ae) { 92 109 // wrap exception so that calling code doesn't have to know about alglib implementation 93 throw new ArgumentException("There was a problem in the calculation of the RBF process model", ae); 94 } 95 } 96 private double[][] ExtractData(IDataset dataset, IEnumerable<int> rows) { 97 return rows.Select(r => allowedInputVariables.Select(v => dataset.GetDoubleValue(v, r)).ToArray()).ToArray(); 110 throw new ArgumentException("There was a problem in the calculation of the RBF model", ae); 111 } 112 } 113 114 private ITransformation<double>[] CreateScaling(IDataset dataset, int[] rows) { 115 var trans = new ITransformation<double>[allowedInputVariables.Length]; 116 int i = 0; 117 foreach (var variable in allowedInputVariables) { 118 var lin = new LinearTransformation(allowedInputVariables); 119 var max = dataset.GetDoubleValues(variable, rows).Max(); 120 var min = dataset.GetDoubleValues(variable, rows).Min(); 121 lin.Multiplier = 1.0 / (max - min); 122 lin.Addend = -min / (max - min); 123 trans[i] = lin; 124 i++; 125 } 126 return trans; 127 } 128 129 private double[,] ExtractData(IDataset dataset, IEnumerable<int> rows, ITransformation<double>[] scaling = null) { 130 double[][] variables; 131 if (scaling != null) { 132 variables = 133 allowedInputVariables.Select((var, i) => scaling[i].Apply(dataset.GetDoubleValues(var, rows)).ToArray()) 134 .ToArray(); 135 } else { 136 variables = 137 allowedInputVariables.Select(var => dataset.GetDoubleValues(var, rows).ToArray()).ToArray(); 138 } 139 int n = variables.First().Length; 140 var res = new double[n, variables.Length]; 141 for (int r = 0; r < n; r++) 142 for (int c = 0; c < variables.Length; c++) { 143 res[r, c] = variables[c][r]; 144 } 145 return res; 98 146 } 99 147 … … 104 152 #region IRegressionModel Members 105 153 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 106 var 
solutions = ExtractData(dataset, rows); 107 var data = ExtractData(trainingDataset, trainingRows); 108 return solutions.Select(solution => alpha.Zip(data, (a, d) => a * kernel.Get(solution, d)).Sum() + 1 * alpha[alpha.Length - 1]).ToArray(); 154 var newX = ExtractData(dataset, rows, scaling); 155 var dim = newX.GetLength(1); 156 var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray()); 157 158 var pred = new double[newX.GetLength(0)]; 159 for (int i = 0; i < pred.Length; i++) { 160 double sum = meanOffset; 161 for (int j = 0; j < alpha.Length; j++) { 162 sum += alpha[j] * cov.CrossCovariance(trainX, newX, j, i); 163 } 164 pred[i] = sum; 165 } 166 return pred; 109 167 } 110 168 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 111 return new RadialBasisFunctionRegressionSolution(this, new RegressionProblemData(problemData));169 return new ConfidenceRegressionSolution(this, new RegressionProblemData(problemData)); 112 170 } 113 171 #endregion 114 172 115 173 public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) { 116 var data = ExtractData(trainingDataset, trainingRows); 117 return ExtractData(dataset, rows).Select(x => GetVariance(x, data)); 174 if (gramInvert == null) CreateAndInvertGramMatrix(); 175 int n = gramInvert.GetLength(0); 176 var newData = ExtractData(dataset, rows, scaling); 177 var dim = newData.GetLength(1); 178 var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray()); 179 180 // TODO perf (matrix matrix multiplication) 181 for (int i = 0; i < newData.GetLength(0); i++) { 182 double[] p = new double[n]; 183 184 for (int j = 0; j < trainX.GetLength(0); j++) { 185 p[j] = cov.CrossCovariance(trainX, newData, j, i); 186 } 187 188 var Ap = new double[n]; 189 alglib.ablas.rmatrixmv(n, n, gramInvert, 0, 0, 0, p, 0, ref Ap, 0); 190 var res = 0.0; 191 // dot product 192 for (int j = 0; 
j < p.Length; j++) res += p[j] * Ap[j]; 193 yield return res > 0 ? res : 0; 194 } 118 195 } 119 196 public double LeaveOneOutCrossValidationRootMeanSquaredError() { 120 return Math.Sqrt(alpha.Select((t, i) => t / gramInvert[i, i]).Sum(d => d * d) / gramInvert.Rows); 121 } 122 197 if (gramInvert == null) CreateAndInvertGramMatrix(); 198 var n = gramInvert.GetLength(0); 199 var s = 1.0 / n; 200 201 var sum = 0.0; 202 for (int i = 0; i < alpha.Length; i++) { 203 var x = alpha[i] / gramInvert[i, i]; 204 sum += x * x; 205 } 206 sum *= s; 207 return Math.Sqrt(sum); 208 } 209 210 private void CreateAndInvertGramMatrix() { 211 try { 212 gramInvert = BuildGramMatrix(trainX); 213 int info = 0; 214 alglib.matinvreport report; 215 alglib.rmatrixinverse(ref gramInvert, out info, out report); 216 if (info != 1) 217 throw new ArgumentException("Could not invert matrix. Is it quadratic symmetric positive definite?"); 218 } catch (alglib.alglibexception) { 219 // wrap exception so that calling code doesn't have to know about alglib implementation 220 throw new ArgumentException("Could not invert matrix. 
Is it quadratic symmetric positive definite?"); 221 } 222 } 123 223 #region helpers 124 private double[,] BuildGramMatrix(double[ ][] data) {125 var size = data.Length + 1;126 var gram = new double[size, size];127 for (var i = 0; i < size; i++)128 for (var j = i; j < size; j++) {129 if (j == size - 1 && i == size - 1) gram[i, j] = 0;130 else if (j == size - 1 || i == size - 1) gram[j, i] = gram[i, j] = 1;131 else gram[j, i] = gram[i, j] = kernel.Get(data[i], data[j]); //symmteric Matrix --> half of the work224 private double[,] BuildGramMatrix(double[,] data) { 225 var n = data.GetLength(0); 226 var dim = data.GetLength(1); 227 var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray()); 228 var gram = new double[n, n]; 229 for (var i = 0; i < n; i++) 230 for (var j = i; j < n; j++) { 231 gram[j, i] = gram[i, j] = cov.Covariance(data, i, j); // symmetric matrix --> half of the work 132 232 } 133 233 return gram; 134 234 } 135 private double GetVariance(double[] solution, IEnumerable<double[]> data) { 136 var phiT = data.Select(x => kernel.Get(x, solution)).Concat(new[] { 1.0 }).ToColumnVector(); 137 var res = phiT.Transpose().Mul(gramInvert.Mul(phiT))[0, 0]; 138 return res > 0 ? res : 0; 139 } 235 140 236 #endregion 141 237 }
Note: See TracChangeset
for help on using the changeset viewer.