- Timestamp:
- 04/09/21 19:51:38 (4 years ago)
- Location:
- branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4
- Files:
-
- 2 added
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/GeneralizedAdditiveModelAlgorithm.cs
r17888 r17933 47 47 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 48 48 private const string CreateSolutionParameterName = "CreateSolution"; 49 private const string BaseLearnerParameterName = "BaseLearner"; 50 private const string Spline1dBaseLearner = "1d cubic regression spline (penalized)"; 51 private const string Spline2dBaseLearner = "2d cubic spline"; 52 private const string Spline3dBaseLearner = "3d cubic spline"; 49 53 50 54 #endregion … … 72 76 } 73 77 78 public IConstrainedValueParameter<StringValue> BaseLearnerParameter { 79 get { return (IConstrainedValueParameter<StringValue>)Parameters[BaseLearnerParameterName]; } 80 } 74 81 #endregion 75 82 … … 101 108 } 102 109 110 public string BaseLearner { 111 get { return BaseLearnerParameter.Value.Value; } 112 set { BaseLearnerParameter.Value = BaseLearnerParameter.ValidValues.Single(vv => vv.Value == value); } 113 } 103 114 #endregion 115 116 public override bool SupportsStop => true; 104 117 105 118 [StorableConstructor] … … 122 135 "Number of iterations. Try a large value and check convergence of the error over iterations. Usually, only a few iterations (e.g. 10) are needed for convergence.", new IntValue(10))); 123 136 Parameters.Add(new FixedValueParameter<DoubleValue>(LambdaParameterName, 124 "The penalty parameter for the penalized regression splines. Set to a value between -8 (weak smoothing) and 8 (strong smoot ing). Usually, a value between -4 and 4 should be fine", new DoubleValue(3)));137 "The penalty parameter for the penalized regression splines. Set to a value between -8 (weak smoothing) and 8 (strong smoothing). Usually, a value between -4 and 4 should be fine", new DoubleValue(3))); 125 138 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, 126 139 "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); … … 130 143 "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 131 144 Parameters[CreateSolutionParameterName].Hidden = true; 145 var validBaseLearners = new ItemSet<StringValue>(new StringValue[] { new StringValue(Spline1dBaseLearner).AsReadOnly(), new StringValue(Spline2dBaseLearner).AsReadOnly(), new StringValue(Spline3dBaseLearner).AsReadOnly()}); 146 Parameters.Add(new ConstrainedValueParameter<StringValue>("BaseLearner", "The model to use for the additive functions", validBaseLearners, validBaseLearners.First())); 132 147 } 133 148 … … 148 163 var inputVars = problemData.AllowedInputVariables.ToArray(); 149 164 150 int nTerms = inputVars.Length; 165 var maxInteractions = 0; 166 if (BaseLearner == Spline1dBaseLearner) maxInteractions = 1; 167 else if (BaseLearner == Spline2dBaseLearner) maxInteractions = 2; 168 else if (BaseLearner == Spline3dBaseLearner) maxInteractions = 3; 169 else throw new ArgumentException("Unknown base learner."); 170 var varTuples = EnumerableExtensions.Combinations(inputVars, maxInteractions).ToArray(); 171 var varTupleNames = varTuples.Select(comb => string.Join(" ", comb)).ToArray(); 172 int nTerms = varTuples.Length; 151 173 152 174 #region init results … … 167 189 168 190 // calculate table with residual contributions of each term 169 var rssTable = new DoubleMatrix(nTerms, 1, new string[] { "RSS" }, inputVars);191 var rssTable = new DoubleMatrix(nTerms, 1, new string[] { "RSS" }, varTupleNames); 170 192 Results.Add(new Result("RSS Values", rssTable)); 171 193 #endregion … … 194 216 idx.ShuffleInPlace(rand); 195 217 foreach (var inputIdx in idx) { 196 var inputVar = inputVars[inputIdx]; 218 if (cancellationToken.IsCancellationRequested) break; 219 var fInputVars = varTuples[inputIdx]; 197 220 // first remove the effect of the previous model for the inputIdx (by adding the output of the current model to the residual) 198 221 AddInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows)); … … 200 223 201 224 rssTable[inputIdx, 0] = MSE(res); 202 f[inputIdx] = RegressSpline(problemData, inputVar, res, lambda); 225 if (fInputVars.Count() == 1) { 226 f[inputIdx] = Regress1dSpline(problemData, fInputVars.Single(), res, lambda); 227 } else if (fInputVars.Count() == 2) { 228 f[inputIdx] = Regress2dSpline(problemData, fInputVars, res, lambda); 229 } else if (fInputVars.Count() == 3) { 230 f[inputIdx] = Regress3dSpline(problemData, fInputVars, res, lambda); 231 } else { 232 throw new ArgumentException(); 233 } 203 234 204 235 SubtractInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows)); … … 217 248 var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) })); 218 249 model.AverageModelEstimates = false; 219 var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()); 250 var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()); 220 251 Results.Add(new Result("Ensemble solution", solution)); 221 252 } … … 223 254 224 255 public static double MSE(IEnumerable<double> residuals) { 225 var mse 256 var mse = residuals.Select(r => r * r).Average(); 226 257 return mse; 227 258 } … … 233 264 } 234 265 235 private IRegressionModel Regress Spline(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {266 private IRegressionModel Regress1dSpline(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) { 236 267 var x = problemData.Dataset.GetDoubleValues(inputVar, problemData.TrainingIndices).ToArray(); 237 268 var y = (double[])target.Clone(); 238 int info;239 269 alglib.spline1dinterpolant s; 240 alglib.spline1dfitreport rep;241 270 int numKnots = (int)Math.Min(50, 3 * Math.Sqrt(x.Length)); // heuristic for number of knots (Elements of Statistical Learning) 242 271 243 alglib.spline1dfitpenalized(x, y, numKnots, lambda, out info, out s, out rep);272 alglib.spline1dfitpenalized(x, y, numKnots, lambda, out _, out s, out _); 244 273 245 274 return new Spline1dModel(s.innerobj, problemData.TargetVariable, inputVar); 275 } 276 277 private IRegressionModel Regress2dSpline(IRegressionProblemData problemData, IEnumerable<string> inputVars, double[] target, double lambda) { 278 var x = problemData.Dataset.ToArray(inputVars, problemData.TrainingIndices); 279 var xy = x.HorzCat(target); 280 281 alglib.spline2dbuilder builder; 282 int d = 1; // scalar output 283 int numKnots = (int)Math.Min(50, 3 * Math.Sqrt(x.Length)); // heuristic for number of knots (Elements of Statistical Learning) 284 285 alglib.spline2dbuildercreate(d, out builder); 286 alglib.spline2dbuildersetpoints(builder, xy, xy.GetLength(0)); 287 alglib.spline2dbuildersetgrid(builder, numKnots, numKnots); 288 alglib.spline2dbuildersetalgoblocklls(builder, Math.Exp(lambda)); 289 290 // 291 // Now we are ready to fit and evaluate our results 292 // 293 alglib.spline2dinterpolant s; 294 alglib.spline2dfit(builder, out s, out _); 295 return new Spline2dModel(s, problemData.TargetVariable, inputVars.ElementAt(0), inputVars.ElementAt(1)); 296 } 297 298 private IRegressionModel Regress3dSpline(IRegressionProblemData problemData, IEnumerable<string> inputVars, double[] target, double lambda) { 299 var x = problemData.Dataset.ToArray(inputVars, problemData.TrainingIndices); 300 var xy = x.HorzCat(target); 301 302 var rbase = 100.0; 303 var nlayers = 3; 304 305 alglib.rbfmodel model; 306 alglib.rbfcreate(nx: 3, ny: 1, out model); 307 alglib.rbfsetpoints(model, xy); 308 alglib.rbfsetalgohierarchical(model, rbase, nlayers, Math.Exp(lambda)); 309 alglib.rbfbuildmodel(model, out _); 310 311 return new AlglibRbfModel(model, problemData.TargetVariable, inputVars.ElementAt(0), inputVars.ElementAt(1), inputVars.ElementAt(2)); 246 312 } 247 313 -
branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/Spline1dModel.cs
r17889 r17933 29 29 30 30 namespace HeuristicLab.Algorithms.DataAnalysis { 31 31 32 [Item("Spline model (1d)", 32 33 "Univariate spline model (wrapper for alglib.spline1dmodel)")] … … 40 41 set { 41 42 if (value.Length > 1) throw new ArgumentException("A one-dimensional spline model supports only one input variable."); 42 inputVariable= value[0];43 x1 = value[0]; 43 44 } 44 45 } 45 46 46 47 [Storable] 47 private string inputVariable;48 public override IEnumerable<string> VariablesUsedForPrediction => new[] { inputVariable};48 private string x1; 49 public override IEnumerable<string> VariablesUsedForPrediction => new[] { x1 }; 49 50 50 51 [StorableConstructor] … … 54 55 55 56 private Spline1dModel(Spline1dModel orig, Cloner cloner) : base(orig, cloner) { 56 this. inputVariable = orig.inputVariable;57 this.x1 = orig.x1; 57 58 this.interpolant = (alglib.spline1d.spline1dinterpolant)orig.interpolant.make_copy(); 58 59 } … … 60 61 : base(targetVar, $"Spline model ({inputVar})") { 61 62 this.interpolant = (alglib.spline1d.spline1dinterpolant)interpolant.make_copy(); 62 this. inputVariable = inputVar;63 this.x1 = inputVar; 63 64 } 64 65 … … 67 68 68 69 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 69 var solution = 70 solution.Name = $"Regression Spline ({ inputVariable})";70 var solution = new RegressionSolution(this, (IRegressionProblemData)problemData.Clone()); 71 solution.Name = $"Regression Spline ({x1})"; 71 72 72 73 return solution; 73 74 } 74 75 75 public double GetEstimatedValue(double x) => alglib.spline1d.spline1dcalc(interpolant, x );76 public double GetEstimatedValue(double x) => alglib.spline1d.spline1dcalc(interpolant, x, null); 76 77 77 78 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 78 return dataset.GetDoubleValues( inputVariable, rows).Select(GetEstimatedValue);79 return dataset.GetDoubleValues(x1, rows).Select(GetEstimatedValue); 79 80 } 80 81 -
branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj
r17931 r17933 142 142 <Compile Include="FixedDataAnalysisAlgorithm.cs" /> 143 143 <Compile Include="GAM\GeneralizedAdditiveModelAlgorithm.cs" /> 144 <Compile Include="GAM\AlglibRbfModel.cs" /> 145 <Compile Include="GAM\Spline2dModel.cs" /> 144 146 <Compile Include="GAM\Spline1dModel.cs" /> 145 147 <Compile Include="GaussianProcess\CovarianceFunctions\CovarianceSpectralMixture.cs" />
Note: See TracChangeset
for help on using the changeset viewer.