Changeset 15433 for branches/MathNetNumerics-Exploration-2789
- Timestamp:
- 10/27/17 11:10:25 (7 years ago)
- Location:
- branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/GAM.cs
r15369 r15433 47 47 private GAM(bool deserializing) : base(deserializing) { } 48 48 [StorableHook(HookType.AfterDeserialization)] 49 private void AfterDeserialization() { 49 private void AfterDeserialization() { 50 50 } 51 51 … … 62 62 Parameters.Add(new ValueParameter<DoubleValue>("Lambda", "Regularization for smoothing splines", new DoubleValue(1.0))); 63 63 Parameters.Add(new ValueParameter<IntValue>("Max iterations", "", new IntValue(100))); 64 }65 64 Parameters.Add(new ValueParameter<IntValue>("Max interactions", "", new IntValue(1))); 65 } 66 66 67 67 protected override void Run(CancellationToken cancellationToken) { … … 69 69 double lambda = ((IValueParameter<DoubleValue>)Parameters["Lambda"]).Value.Value; 70 70 int maxIters = ((IValueParameter<IntValue>)Parameters["Max iterations"]).Value.Value; 71 int maxInteractions = ((IValueParameter<IntValue>)Parameters["Max interactions"]).Value.Value; 72 if (maxInteractions < 1 || maxInteractions > 5) throw new ArgumentException("Max interactions is outside the valid range [1 .. 5]"); 71 73 72 74 // calculates a GAM model using a linear representation + independent non-linear functions of each variable … … 77 79 var avgY = y.Average(); 78 80 var inputVars = Problem.ProblemData.AllowedInputVariables.ToArray(); 79 IRegressionModel[] f = new IRegressionModel[inputVars.Length * 2]; // linear(x) + nonlinear(x) 81 var nTerms = inputVars.Length; // LR 82 for(int i=1;i<=maxInteractions;i++) { 83 nTerms += inputVars.Combinations(i).Count(); 84 } 85 IRegressionModel[] f = new IRegressionModel[nTerms]; 80 86 for(int i=0;i<f.Length;i++) { 81 87 f[i] = new ConstantModel(0.0, problemData.TargetVariable); … … 90 96 Results.Add(new Result("RMSE", rmseTable)); 91 97 rmseRow.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TrainingIndices).StandardDeviation()); // -1 index to use all predictors 92 rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation()); 98 rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation()); 99 100 // for analytics 101 double[] rss = new double[f.Length]; 102 string[] terms = new string[f.Length]; 103 Results.Add(new Result("RSS Values", typeof(DoubleMatrix))); 93 104 94 105 // until convergence … … 99 110 foreach (var inputVar in inputVars) { 100 111 var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices); 112 rss[j] = res.Variance(); 113 terms[j] = inputVar; 101 114 f[j] = RegressLR(problemData, inputVar, res); 102 115 j++; 103 116 } 104 foreach (var inputVar in inputVars) { 105 var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices); 106 f[j] = RegressSpline(problemData, inputVar, res, lambda); 107 j++; 117 118 for(int interaction = 1; interaction <= maxInteractions;interaction++) { 119 var selectedVars = HeuristicLab.Common.EnumerableExtensions.Combinations(inputVars, interaction); 120 121 foreach (var element in selectedVars) { 122 var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices); 123 rss[j] = res.Variance(); 124 terms[j] = string.Format("f({0})", string.Join(",", element)); 125 f[j] = RegressSpline(problemData, element.ToArray(), res, lambda); 126 j++; 127 } 108 128 } 109 129 110 130 rmseRow.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TrainingIndices).StandardDeviation()); // -1 index to use all predictors 111 131 rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation()); 132 133 // calculate table with residual contributions of each term 134 var rssTable = new DoubleMatrix(rss.Length, 1, new string[] { "RSS" }, terms); 135 for (int i = 0; i < rss.Length; i++) rssTable[i, 0] = rss[i]; 136 Results["RSS Values"].Value = rssTable; 112 137 113 138 if (cancellationToken.IsCancellationRequested) break; … … 117 142 model.AverageModelEstimates = false; 118 143 Results.Add(new Result("Ensemble solution", model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()))); 144 119 145 } 120 146 … … 158 184 } else return new ConstantModel(target.Average(), problemData.TargetVariable); 159 185 } 186 private IRegressionModel RegressSpline(IRegressionProblemData problemData, string[] inputVars, double[] target, double lambda) { 187 if (inputVars.All(problemData.Dataset.VariableHasType<double>)) { 188 var product = problemData.Dataset.GetDoubleValues(inputVars.First(), problemData.TrainingIndices).ToArray(); 189 for(int i = 1;i<inputVars.Length;i++) { 190 product = product.Zip(problemData.Dataset.GetDoubleValues(inputVars[i], problemData.TrainingIndices), (pi, vi) => pi * vi).ToArray(); 191 } 192 // Umständlich! 193 return Splines.CalculatePenalizedRegressionSpline( 194 product, 195 (double[])target.Clone(), lambda, 196 problemData.TargetVariable, inputVars 197 ); 198 } else return new ConstantModel(target.Average(), problemData.TargetVariable); 199 } 200 160 201 private IRegressionModel RegressRF(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) { 161 202 if (problemData.Dataset.VariableHasType<double>(inputVar)) { -
branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/Splines.cs
r15369 r15433 763 763 764 764 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 765 foreach (var x in dataset.GetDoubleValues(VariablesUsedForPrediction.First(), rows)) { 765 var product = dataset.GetDoubleValues(VariablesUsedForPrediction.First(), rows).ToArray(); 766 foreach(var factor in VariablesUsedForPrediction.Skip(1)) { 767 product = product.Zip(dataset.GetDoubleValues(factor, rows), (pi, fi) => pi * fi).ToArray(); 768 } 769 770 foreach (var x in product) { 766 771 yield return alglib.spline1dcalc(interpolant, x); 767 772 }
Note: See TracChangeset
for help on using the changeset viewer.