Free cookie consent management tool by TermsFeed Policy Generator

Changeset 15433


Ignore:
Timestamp:
10/27/17 11:10:25 (7 years ago)
Author:
gkronber
Message:

#2789 added the possibility to include interaction terms in GAM

Location:
branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/GAM.cs

    r15369 r15433  
    4747    private GAM(bool deserializing) : base(deserializing) { }
    4848    [StorableHook(HookType.AfterDeserialization)]
    49     private void AfterDeserialization() {
     49    private void AfterDeserialization() {     
    5050    }
    5151
     
    6262      Parameters.Add(new ValueParameter<DoubleValue>("Lambda", "Regularization for smoothing splines", new DoubleValue(1.0)));
    6363      Parameters.Add(new ValueParameter<IntValue>("Max iterations", "", new IntValue(100)));
    64     }
    65 
     64      Parameters.Add(new ValueParameter<IntValue>("Max interactions", "", new IntValue(1)));
     65    }   
    6666
    6767    protected override void Run(CancellationToken cancellationToken) {
     
    6969      double lambda = ((IValueParameter<DoubleValue>)Parameters["Lambda"]).Value.Value;
    7070      int maxIters = ((IValueParameter<IntValue>)Parameters["Max iterations"]).Value.Value;
     71      int maxInteractions = ((IValueParameter<IntValue>)Parameters["Max interactions"]).Value.Value;
     72      if (maxInteractions < 1 || maxInteractions > 5) throw new ArgumentException("Max interactions is outside the valid range [1 .. 5]");
    7173
    7274      // calculates a GAM model using a linear representation + independent non-linear functions of each variable
     
    7779      var avgY = y.Average();
    7880      var inputVars = Problem.ProblemData.AllowedInputVariables.ToArray();
    79       IRegressionModel[] f = new IRegressionModel[inputVars.Length * 2]; // linear(x) + nonlinear(x)
     81      var nTerms = inputVars.Length; // LR
     82      for(int i=1;i<=maxInteractions;i++) {
     83        nTerms += inputVars.Combinations(i).Count();
     84      }
     85      IRegressionModel[] f = new IRegressionModel[nTerms];
    8086      for(int i=0;i<f.Length;i++) {
    8187        f[i] = new ConstantModel(0.0, problemData.TargetVariable);
     
    9096      Results.Add(new Result("RMSE", rmseTable));
    9197      rmseRow.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TrainingIndices).StandardDeviation()); // -1 index to use all predictors
    92       rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation());
     98      rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation());
     99
     100      // for analytics
     101      double[] rss = new double[f.Length];
     102      string[] terms = new string[f.Length];
     103      Results.Add(new Result("RSS Values", typeof(DoubleMatrix)));
    93104
    94105      // until convergence
     
    99110        foreach (var inputVar in inputVars) {
    100111          var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices);
     112          rss[j] = res.Variance();
     113          terms[j] = inputVar;
    101114          f[j] = RegressLR(problemData, inputVar, res);
    102115          j++;
    103116        }
    104         foreach (var inputVar in inputVars) {
    105           var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices);
    106           f[j] = RegressSpline(problemData, inputVar, res, lambda);
    107           j++;
     117
     118        for(int interaction = 1; interaction <= maxInteractions;interaction++) {
     119          var selectedVars = HeuristicLab.Common.EnumerableExtensions.Combinations(inputVars, interaction);
     120
     121          foreach (var element in selectedVars) {
     122            var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices);
     123            rss[j] = res.Variance();
     124            terms[j] = string.Format("f({0})", string.Join(",", element));
     125            f[j] = RegressSpline(problemData, element.ToArray(), res, lambda);
     126            j++;
     127          }
    108128        }
    109129
    110130        rmseRow.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TrainingIndices).StandardDeviation()); // -1 index to use all predictors
    111131        rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation());
     132
     133        // calculate table with residual contributions of each term
     134        var rssTable = new DoubleMatrix(rss.Length, 1, new string[] { "RSS" }, terms);
     135        for (int i = 0; i < rss.Length; i++) rssTable[i, 0] = rss[i];
     136        Results["RSS Values"].Value = rssTable;
    112137
    113138        if (cancellationToken.IsCancellationRequested) break;
     
    117142      model.AverageModelEstimates = false;
    118143      Results.Add(new Result("Ensemble solution", model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone())));
     144
    119145    }
    120146
     
    158184      } else return new ConstantModel(target.Average(), problemData.TargetVariable);
    159185    }
     186    private IRegressionModel RegressSpline(IRegressionProblemData problemData, string[] inputVars, double[] target, double lambda) {
     187      if (inputVars.All(problemData.Dataset.VariableHasType<double>)) {
     188        var product = problemData.Dataset.GetDoubleValues(inputVars.First(), problemData.TrainingIndices).ToArray();
     189        for(int i = 1;i<inputVars.Length;i++) {
     190          product = product.Zip(problemData.Dataset.GetDoubleValues(inputVars[i], problemData.TrainingIndices), (pi, vi) => pi * vi).ToArray();
     191        }
     192        // Umständlich!
     193        return Splines.CalculatePenalizedRegressionSpline(
     194          product,
     195          (double[])target.Clone(), lambda,
     196          problemData.TargetVariable, inputVars
     197          );
     198      } else return new ConstantModel(target.Average(), problemData.TargetVariable);
     199    }
     200
    160201    private IRegressionModel RegressRF(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {
    161202      if (problemData.Dataset.VariableHasType<double>(inputVar)) {
  • branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/Splines.cs

    r15369 r15433  
    763763
    764764    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    765       foreach (var x in dataset.GetDoubleValues(VariablesUsedForPrediction.First(), rows)) {
     765      var product = dataset.GetDoubleValues(VariablesUsedForPrediction.First(), rows).ToArray();
     766      foreach(var factor in VariablesUsedForPrediction.Skip(1)) {
     767        product = product.Zip(dataset.GetDoubleValues(factor, rows), (pi, fi) => pi * fi).ToArray();
     768      }
     769
     770      foreach (var x in product) {
    766771        yield return alglib.spline1dcalc(interpolant, x);
    767772      }
Note: See TracChangeset for help on using the changeset viewer.