Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/17/16 15:45:04 (9 years ago)
Author:
mkommend
Message:

#1795: Added linear scaling of solutions while producing a model ensemble for GBM.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs

    r13898 r13917  
    332332          // just produce an ensemble solution for now (TODO: correct scaling or linear regression for ensemble model weights)
    333333
    334           var ensembleModel = new RegressionEnsembleModel(models) { AverageModelEstimates = false };
    335           var ensembleSolution = ensembleModel.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
     334          var ensembleSolution = CreateEnsembleSolution(models, (IRegressionProblemData)problemData.Clone());
    336335          Results.Add(new Result("EnsembleSolution", ensembleSolution));
    337336        }
     
    341340        alg.Prepare(true);
    342341      }
     342    }
     343
     344    private static IRegressionEnsembleSolution CreateEnsembleSolution(List<IRegressionModel> models,
     345      IRegressionProblemData problemData) {
     346      var rows = problemData.TrainingPartition.Size;
     347      var features = models.Count;
     348      double[,] inputMatrix = new double[rows, features + 1];
     349      //add model estimates
     350      for (int m = 0; m < models.Count; m++) {
     351        var model = models[m];
     352        var estimates = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
     353        int estimatesCounter = 0;
     354        foreach (var estimate in estimates) {
     355          inputMatrix[estimatesCounter, m] = estimate;
     356          estimatesCounter++;
     357        }
     358      }
     359
     360      //add target
     361      var targets = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     362      int targetCounter = 0;
     363      foreach (var target in targets) {
     364        inputMatrix[targetCounter, models.Count] = target;
     365        targetCounter++;
     366      }
     367
     368      alglib.linearmodel lm = new alglib.linearmodel();
     369      alglib.lrreport ar = new alglib.lrreport();
     370      double[] coefficients;
     371      int retVal = 1;
     372      alglib.lrbuildz(inputMatrix, rows, features, out retVal, out lm, out ar);
     373      if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution");
     374
     375      alglib.lrunpack(lm, out coefficients, out features);
     376
     377      var ensembleModel = new RegressionEnsembleModel(models, coefficients.Take(models.Count)) { AverageModelEstimates = false };
     378      var ensembleSolution = ensembleModel.CreateRegressionSolution(problemData);
     379      return ensembleSolution;
    343380    }
    344381
Note: See TracChangeset for help on using the changeset viewer.