Free cookie consent management tool by TermsFeed Policy Generator

Changeset 12710


Ignore:
Timestamp:
07/10/15 13:42:37 (9 years ago)
Author:
gkronber
Message:

#2261: cached training and test rows in GBT for another speedup of ~1.5 (+renamed test class)

Location:
trunk/sources
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12700 r12710  
    4848      internal double r { get; private set; }
    4949      internal double m { get; private set; }
     50      internal int[] trainingRows { get; private set; }
     51      internal int[] testRows { get; private set; }
    5052      internal RegressionTreeBuilder treeBuilder { get; private set; }
    5153
     
    7173        random = new MersenneTwister(randSeed);
    7274        this.problemData = problemData;
     75        this.trainingRows = problemData.TrainingIndices.ToArray();
     76        this.testRows = problemData.TestIndices.ToArray();
    7377        this.lossFunction = lossFunction;
    7478
    75         int nRows = problemData.TrainingIndices.Count();
     79        int nRows = trainingRows.Length;
    7680
    77         y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
     81        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingRows).ToArray();
    7882
    7983        treeBuilder = new RegressionTreeBuilder(problemData, random);
     
    8488        double f0 = lossFunction.LineSearch(y, zeros, activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
    8589        pred = Enumerable.Repeat(f0, nRows).ToArray();
    86         predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray();
     90        predTest = Enumerable.Repeat(f0, testRows.Length).ToArray();
    8791        pseudoRes = new double[nRows];
    8892
     
    106110      }
    107111      public double GetTestLoss() {
    108         var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices);
    109         var nRows = problemData.TestIndices.Count();
     112        var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, testRows);
     113        var nRows = testRows.Length;
    110114        return lossFunction.GetLoss(yTest, predTest) / nRows;
    111115      }
     
    160164      var activeIdx = gbmState.activeIdx;
    161165      var pseudoRes = gbmState.pseudoRes;
     166      var trainingRows = gbmState.trainingRows;
     167      var testRows = gbmState.testRows;
    162168
    163169      // copy output of gradient function to pre-allocated rim array (pseudo-residual per row and model)
     
    170176
    171177      int i = 0;
    172       foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices)) {
     178      foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, trainingRows)) {
    173179        yPred[i] = yPred[i] + nu * pred;
    174180        i++;
     
    176182      // update predictions for validation set
    177183      i = 0;
    178       foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TestIndices)) {
     184      foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, testRows)) {
    179185        yPredTest[i] = yPredTest[i] + nu * pred;
    180186        i++;
  • trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs

    r12700 r12710  
    99namespace HeuristicLab.Algorithms.DataAnalysis {
    1010  [TestClass()]
    11   public class Test {
     11  public class GradientBoostingTest {
    1212    [TestMethod]
    1313    [TestCategory("Algorithms.DataAnalysis")]
Note: See TracChangeset for help on using the changeset viewer.