Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/09/15 16:11:52 (9 years ago)
Author:
gkronber
Message:

#2261: killed all weights

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12661 r12696  
    5555      internal double[] pred;
    5656      internal double[] predTest;
    57       internal double[] w;
    5857      internal double[] y;
    5958      internal int[] activeIdx;
     
    7776
    7877        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
    79         // weights are all 1 for now (HL doesn't support weights yet)
    80         w = Enumerable.Repeat(1.0, nRows).ToArray();
    8178
    8279        treeBuilder = new RegressionTreeBuilder(problemData, random);
     
    8582
    8683        var zeros = Enumerable.Repeat(0.0, nRows);
    87         var ones = Enumerable.Repeat(1.0, nRows);
    88         double f0 = lossFunction.GetLineSearchFunc(y, zeros, ones)(activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
     84        double f0 = lossFunction.GetLineSearchFunc(y, zeros)(activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
    8985        pred = Enumerable.Repeat(f0, nRows).ToArray();
    9086        predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray();
     
    107103      public double GetTrainLoss() {
    108104        int nRows = y.Length;
    109         return lossFunction.GetLoss(y, pred, w) / nRows;
     105        return lossFunction.GetLoss(y, pred) / nRows;
    110106      }
    111107      public double GetTestLoss() {
    112108        var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices);
    113         var wTest = problemData.TestIndices.Select(_ => 1.0); // ones
    114109        var nRows = problemData.TestIndices.Count();
    115         return lossFunction.GetLoss(yTest, predTest, wTest) / nRows;
     110        return lossFunction.GetLoss(yTest, predTest) / nRows;
    116111      }
    117112    }
     
    160155      var yPred = gbmState.pred;
    161156      var yPredTest = gbmState.predTest;
    162       var w = gbmState.w;
    163157      var treeBuilder = gbmState.treeBuilder;
    164158      var y = gbmState.y;
     
    168162      // copy output of gradient function to pre-allocated rim array (pseudo-residuals)
    169163      int rimIdx = 0;
    170       foreach (var g in lossFunction.GetLossGradient(y, yPred, w)) {
     164      foreach (var g in lossFunction.GetLossGradient(y, yPred)) {
    171165        pseudoRes[rimIdx++] = g;
    172166      }
    173167
    174       var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, maxSize, activeIdx, lossFunction.GetLineSearchFunc(y, yPred, w), r, m);
     168      var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, maxSize, activeIdx, lossFunction.GetLineSearchFunc(y, yPred), r, m);
    175169
    176170      int i = 0;
Note: See TracChangeset for help on using the changeset viewer.