Changeset 17991 for branches/3128_Prediction_Intervals/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SymbolicRegressionModel.cs
- Timestamp: 06/16/21 21:35:37 (3 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
branches/3128_Prediction_Intervals/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SymbolicRegressionModel.cs
r17180 r17991 26 26 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 27 27 using HEAL.Attic; 28 using HeuristicLab.Data; 29 using System.Linq; 28 30 29 31 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression { … … 45 47 } 46 48 49 [Storable] 50 private readonly double[,] parameterCovariance; 51 [Storable] 52 private readonly double sigma; 53 47 54 [StorableConstructor] 48 55 protected SymbolicRegressionModel(StorableConstructorFlag _) : base(_) { … … 53 60 : base(original, cloner) { 54 61 this.targetVariable = original.targetVariable; 62 this.parameterCovariance = original.parameterCovariance; // immutable 63 this.sigma = original.sigma; 55 64 } 56 65 57 66 public SymbolicRegressionModel(string targetVariable, ISymbolicExpressionTree tree, 58 67 ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, 59 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue )68 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, double[,] parameterCovariance = null, double sigma = 0.0) 60 69 : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) { 61 70 this.targetVariable = targetVariable; 71 if (parameterCovariance != null) 72 this.parameterCovariance = (double[,])parameterCovariance.Clone(); 73 this.sigma = sigma; 62 74 } 63 75 … … 69 81 return Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows) 70 82 .LimitToRange(LowerEstimationLimit, UpperEstimationLimit); 83 } 84 85 public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) { 86 // must work with a copy because we change tree nodes 87 var treeCopy = (ISymbolicExpressionTree)SymbolicExpressionTree.Clone(); 88 // uses sampling to produce prediction intervals 89 alglib.hqrndseed(31415, 926535, out var state); 90 var cov = parameterCovariance; 91 if (cov == null || cov.Length == 0) return rows.Select(_ => 0.0); 92 var n = 30; 93 var M = 
rows.Select(_ => new double[n]).ToArray(); 94 var paramNodes = new List<ISymbolicExpressionTreeNode>(); 95 var coeffList = new List<double>(); 96 // HACK: skip linear scaling parameters because the analyzer doesn't use them (and they are likely correlated with the remaining parameters) 97 // only works with linear scaling 98 if (!(treeCopy.Root.GetSubtree(0).GetSubtree(0).Symbol is Addition) || 99 !(treeCopy.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0).Symbol is Multiplication)) 100 throw new NotImplementedException("prediction intervals are implemented only for linear scaling"); 101 102 foreach (var node in treeCopy.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix()) { 103 if (node is ConstantTreeNode constNode) { 104 paramNodes.Add(constNode); 105 coeffList.Add(constNode.Value); 106 } else if (node is VariableTreeNode varNode) { 107 paramNodes.Add(varNode); 108 coeffList.Add(varNode.Weight); 109 } 110 } 111 var coeff = coeffList.ToArray(); 112 var numParams = coeff.Length; 113 if (cov.GetLength(0) != numParams) throw new InvalidProgramException(); 114 115 // TODO: probably we do not need to sample but can instead use a first-order or second-order approximation of f 116 // see http://sia.webpopix.org/nonlinearRegression.html 117 // also see https://rmazing.wordpress.com/2013/08/26/predictnls-part-2-taylor-approximation-confidence-intervals-for-nls-models/ 118 // https://www.rdocumentation.org/packages/propagate/versions/1.0-4/topics/predictNLS 119 double[] p = new double[numParams]; 120 for (int i = 0; i < 30; i++) { 121 // sample and update parameter vector delta is 122 alglib.hqrndnormalv(state, numParams, out var delta); 123 alglib.rmatrixmv(numParams, numParams, cov, 0, 0, 0, delta, 0, ref p, 0); 124 for (int j = 0; j < numParams; j++) { 125 if (paramNodes[j] is ConstantTreeNode constNode) constNode.Value = coeff[j] + p[j]; 126 else if (paramNodes[j] is VariableTreeNode varNode) varNode.Weight = coeff[j] + p[j]; 127 } 128 var r = 0; 129 var 
estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(treeCopy, dataset, rows).LimitToRange(LowerEstimationLimit, UpperEstimationLimit); 130 131 foreach (var pred in estimatedValues) { 132 M[r++][i] = pred; 133 } 134 } 135 136 // reset parameters 137 for (int j = 0; j < numParams; j++) { 138 if (paramNodes[j] is ConstantTreeNode constNode) constNode.Value = coeff[j]; 139 else if (paramNodes[j] is VariableTreeNode varNode) varNode.Weight = coeff[j]; 140 } 141 var sigma2 = sigma * sigma; 142 return M.Select(M_i => M_i.Variance() + sigma2).ToArray(); 71 143 } 72 144
Note: See TracChangeset
for help on using the changeset viewer.