Changeset 12710 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis
- Timestamp:
- 07/10/15 13:42:37 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
r12700 r12710 48 48 internal double r { get; private set; } 49 49 internal double m { get; private set; } 50 internal int[] trainingRows { get; private set; } 51 internal int[] testRows { get; private set; } 50 52 internal RegressionTreeBuilder treeBuilder { get; private set; } 51 53 … … 71 73 random = new MersenneTwister(randSeed); 72 74 this.problemData = problemData; 75 this.trainingRows = problemData.TrainingIndices.ToArray(); 76 this.testRows = problemData.TestIndices.ToArray(); 73 77 this.lossFunction = lossFunction; 74 78 75 int nRows = problemData.TrainingIndices.Count();79 int nRows = trainingRows.Length; 76 80 77 y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();81 y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingRows).ToArray(); 78 82 79 83 treeBuilder = new RegressionTreeBuilder(problemData, random); … … 84 88 double f0 = lossFunction.LineSearch(y, zeros, activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors) 85 89 pred = Enumerable.Repeat(f0, nRows).ToArray(); 86 predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray();90 predTest = Enumerable.Repeat(f0, testRows.Length).ToArray(); 87 91 pseudoRes = new double[nRows]; 88 92 … … 106 110 } 107 111 public double GetTestLoss() { 108 var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices);109 var nRows = problemData.TestIndices.Count();112 var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, testRows); 113 var nRows = testRows.Length; 110 114 return lossFunction.GetLoss(yTest, predTest) / nRows; 111 115 } … … 160 164 var activeIdx = gbmState.activeIdx; 161 165 var pseudoRes = gbmState.pseudoRes; 166 var trainingRows = gbmState.trainingRows; 167 var testRows = gbmState.testRows; 162 168 163 169 // copy output of gradient function to pre-allocated rim array (pseudo-residual per row and model) … … 170 176 171 177 int i = 0; 172 foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices)) {178 foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, trainingRows)) { 173 179 yPred[i] = yPred[i] + nu * pred; 174 180 i++; … … 176 182 // update predictions for validation set 177 183 i = 0; 178 foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TestIndices)) {184 foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, testRows)) { 179 185 yPredTest[i] = yPredTest[i] + nu * pred; 180 186 i++;
Note: See TracChangeset
for help on using the changeset viewer.