Free cookie consent management tool by TermsFeed Policy Generator

Changeset 17933


Ignore:
Timestamp:
04/09/21 19:51:38 (3 years ago)
Author:
gkronber
Message:

#3116: added GAM base-learners for 2d cubic splines and 3d functions (using alglib RBF model)

Location:
branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 added
3 edited

Legend:

Unmodified
Added
Removed
  • branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/GeneralizedAdditiveModelAlgorithm.cs

    r17888 r17933  
    4747    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    4848    private const string CreateSolutionParameterName = "CreateSolution";
     49    private const string BaseLearnerParameterName = "BaseLearner";
     50    private const string Spline1dBaseLearner = "1d cubic regression spline (penalized)";
     51    private const string Spline2dBaseLearner = "2d cubic spline";
     52    private const string Spline3dBaseLearner = "3d cubic spline";
    4953
    5054    #endregion
     
    7276    }
    7377
     78    public IConstrainedValueParameter<StringValue> BaseLearnerParameter {
     79      get { return (IConstrainedValueParameter<StringValue>)Parameters[BaseLearnerParameterName]; }
     80    }
    7481    #endregion
    7582
     
    101108    }
    102109
     110    public string BaseLearner {
     111      get { return BaseLearnerParameter.Value.Value; }
     112      set { BaseLearnerParameter.Value = BaseLearnerParameter.ValidValues.Single(vv => vv.Value == value); }
     113    }
    103114    #endregion
     115
     116    public override bool SupportsStop => true;
    104117
    105118    [StorableConstructor]
     
    122135        "Number of iterations. Try a large value and check convergence of the error over iterations. Usually, only a few iterations (e.g. 10) are needed for convergence.", new IntValue(10)));
    123136      Parameters.Add(new FixedValueParameter<DoubleValue>(LambdaParameterName,
    124         "The penalty parameter for the penalized regression splines. Set to a value between -8 (weak smoothing) and 8 (strong smooting). Usually, a value between -4 and 4 should be fine", new DoubleValue(3)));
     137        "The penalty parameter for the penalized regression splines. Set to a value between -8 (weak smoothing) and 8 (strong smoothing). Usually, a value between -4 and 4 should be fine", new DoubleValue(3)));
    125138      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName,
    126139        "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
     
    130143        "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    131144      Parameters[CreateSolutionParameterName].Hidden = true;
     145      var validBaseLearners = new ItemSet<StringValue>(new StringValue[] { new StringValue(Spline1dBaseLearner).AsReadOnly(), new StringValue(Spline2dBaseLearner).AsReadOnly(), new StringValue(Spline3dBaseLearner).AsReadOnly()});
     146      Parameters.Add(new ConstrainedValueParameter<StringValue>("BaseLearner", "The model to use for the additive functions", validBaseLearners, validBaseLearners.First()));
    132147    }
    133148
     
    148163      var inputVars = problemData.AllowedInputVariables.ToArray();
    149164
    150       int nTerms = inputVars.Length;
     165      var maxInteractions = 0;
     166      if (BaseLearner == Spline1dBaseLearner) maxInteractions = 1;
     167      else if (BaseLearner == Spline2dBaseLearner) maxInteractions = 2;
     168      else if (BaseLearner == Spline3dBaseLearner) maxInteractions = 3;
     169      else throw new ArgumentException("Unknown base learner.");
     170      var varTuples = EnumerableExtensions.Combinations(inputVars, maxInteractions).ToArray();
     171      var varTupleNames = varTuples.Select(comb => string.Join(" ", comb)).ToArray();
     172      int nTerms = varTuples.Length;
    151173
    152174      #region init results
     
    167189
    168190      // calculate table with residual contributions of each term
    169       var rssTable = new DoubleMatrix(nTerms, 1, new string[] { "RSS" }, inputVars);
     191      var rssTable = new DoubleMatrix(nTerms, 1, new string[] { "RSS" }, varTupleNames);
    170192      Results.Add(new Result("RSS Values", rssTable));
    171193      #endregion
     
    194216        idx.ShuffleInPlace(rand);
    195217        foreach (var inputIdx in idx) {
    196           var inputVar = inputVars[inputIdx];
     218          if (cancellationToken.IsCancellationRequested) break;
     219          var fInputVars = varTuples[inputIdx];
    197220          // first remove the effect of the previous model for the inputIdx (by adding the output of the current model to the residual)
    198221          AddInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows));
     
    200223
    201224          rssTable[inputIdx, 0] = MSE(res);
    202           f[inputIdx] = RegressSpline(problemData, inputVar, res, lambda);
     225          if (fInputVars.Count() == 1) {
     226            f[inputIdx] = Regress1dSpline(problemData, fInputVars.Single(), res, lambda);
     227          } else if (fInputVars.Count() == 2) {
     228            f[inputIdx] = Regress2dSpline(problemData, fInputVars, res, lambda);
     229          } else if (fInputVars.Count() == 3) {
     230            f[inputIdx] = Regress3dSpline(problemData, fInputVars, res, lambda);
     231          } else {
     232            throw new ArgumentException();
     233          }
    203234
    204235          SubtractInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows));
     
    217248        var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
    218249        model.AverageModelEstimates = false;
    219         var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());       
     250        var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    220251        Results.Add(new Result("Ensemble solution", solution));
    221252      }
     
    223254
    224255    public static double MSE(IEnumerable<double> residuals) {
    225       var mse  = residuals.Select(r => r * r).Average();
     256      var mse = residuals.Select(r => r * r).Average();
    226257      return mse;
    227258    }
     
    233264    }
    234265
    235     private IRegressionModel RegressSpline(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {
     266    private IRegressionModel Regress1dSpline(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {
    236267      var x = problemData.Dataset.GetDoubleValues(inputVar, problemData.TrainingIndices).ToArray();
    237268      var y = (double[])target.Clone();
    238       int info;
    239269      alglib.spline1dinterpolant s;
    240       alglib.spline1dfitreport rep;
    241270      int numKnots = (int)Math.Min(50, 3 * Math.Sqrt(x.Length)); // heuristic for number of knots  (Elements of Statistical Learning)
    242271
    243       alglib.spline1dfitpenalized(x, y, numKnots, lambda, out info, out s, out rep);
     272      alglib.spline1dfitpenalized(x, y, numKnots, lambda, out _, out s, out _);
    244273
    245274      return new Spline1dModel(s.innerobj, problemData.TargetVariable, inputVar);
     275    }
     276
     277    private IRegressionModel Regress2dSpline(IRegressionProblemData problemData, IEnumerable<string> inputVars, double[] target, double lambda) {
     278      var x = problemData.Dataset.ToArray(inputVars, problemData.TrainingIndices);
     279      var xy = x.HorzCat(target);
     280
     281      alglib.spline2dbuilder builder;
     282      int d = 1; // scalar output
     283      int numKnots = (int)Math.Min(50, 3 * Math.Sqrt(x.Length)); // heuristic for number of knots  (Elements of Statistical Learning)
     284
     285      alglib.spline2dbuildercreate(d, out builder);
     286      alglib.spline2dbuildersetpoints(builder, xy, xy.GetLength(0));
     287      alglib.spline2dbuildersetgrid(builder, numKnots, numKnots);
     288      alglib.spline2dbuildersetalgoblocklls(builder, Math.Exp(lambda));
     289
     290      //
     291      // Now we are ready to fit and evaluate our results
     292      //
     293      alglib.spline2dinterpolant s;
     294      alglib.spline2dfit(builder, out s, out _);
     295      return new Spline2dModel(s, problemData.TargetVariable, inputVars.ElementAt(0), inputVars.ElementAt(1));
     296    }
     297
     298    private IRegressionModel Regress3dSpline(IRegressionProblemData problemData, IEnumerable<string> inputVars, double[] target, double lambda) {
     299      var x = problemData.Dataset.ToArray(inputVars, problemData.TrainingIndices);
     300      var xy = x.HorzCat(target);
     301
     302      var rbase = 100.0;
     303      var nlayers = 3;
     304
     305      alglib.rbfmodel model;
     306      alglib.rbfcreate(nx: 3, ny: 1, out model);
     307      alglib.rbfsetpoints(model, xy);
     308      alglib.rbfsetalgohierarchical(model, rbase, nlayers, Math.Exp(lambda));
     309      alglib.rbfbuildmodel(model, out _);
     310
     311      return new AlglibRbfModel(model, problemData.TargetVariable, inputVars.ElementAt(0), inputVars.ElementAt(1), inputVars.ElementAt(2));
    246312    }
    247313
  • branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4/GAM/Spline1dModel.cs

    r17889 r17933  
    2929
    3030namespace HeuristicLab.Algorithms.DataAnalysis {
     31
    3132  [Item("Spline model (1d)",
    3233    "Univariate spline model (wrapper for alglib.spline1dmodel)")]
     
    4041      set {
    4142        if (value.Length > 1) throw new ArgumentException("A one-dimensional spline model supports only one input variable.");
    42         inputVariable = value[0];
     43        x1 = value[0];
    4344      }
    4445    }
    4546
    4647    [Storable]
    47     private string inputVariable;
    48     public override IEnumerable<string> VariablesUsedForPrediction => new[] { inputVariable };
     48    private string x1;
     49    public override IEnumerable<string> VariablesUsedForPrediction => new[] { x1 };
    4950
    5051    [StorableConstructor]
     
    5455
    5556    private Spline1dModel(Spline1dModel orig, Cloner cloner) : base(orig, cloner) {
    56       this.inputVariable = orig.inputVariable;
     57      this.x1 = orig.x1;
    5758      this.interpolant = (alglib.spline1d.spline1dinterpolant)orig.interpolant.make_copy();
    5859    }
     
    6061      : base(targetVar, $"Spline model ({inputVar})") {
    6162      this.interpolant = (alglib.spline1d.spline1dinterpolant)interpolant.make_copy();
    62       this.inputVariable = inputVar;     
     63      this.x1 = inputVar;
    6364    }
    6465
     
    6768
    6869    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    69       var solution =  new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    70       solution.Name = $"Regression Spline ({inputVariable})";
     70      var solution = new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
     71      solution.Name = $"Regression Spline ({x1})";
    7172
    7273      return solution;
    7374    }
    7475
    75     public double GetEstimatedValue(double x) => alglib.spline1d.spline1dcalc(interpolant, x);
     76    public double GetEstimatedValue(double x) => alglib.spline1d.spline1dcalc(interpolant, x, null);
    7677
    7778    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    78       return dataset.GetDoubleValues(inputVariable, rows).Select(GetEstimatedValue);
     79      return dataset.GetDoubleValues(x1, rows).Select(GetEstimatedValue);
    7980    }
    8081
  • branches/3116_GAM_Interactions/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r17931 r17933  
    142142    <Compile Include="FixedDataAnalysisAlgorithm.cs" />
    143143    <Compile Include="GAM\GeneralizedAdditiveModelAlgorithm.cs" />
     144    <Compile Include="GAM\AlglibRbfModel.cs" />
     145    <Compile Include="GAM\Spline2dModel.cs" />
    144146    <Compile Include="GAM\Spline1dModel.cs" />
    145147    <Compile Include="GaussianProcess\CovarianceFunctions\CovarianceSpectralMixture.cs" />
Note: See TracChangeset for help on using the changeset viewer.