Changeset 12607 for branches/GBT-trunkintegration
- Timestamp:
- 07/06/15 15:20:28 (9 years ago)
- Location:
- branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Files:
-
- 6 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
r12597 r12607 84 84 activeIdx = Enumerable.Range(0, nRows).ToArray(); 85 85 86 // prepare arrays (allocate only once) 87 double f0 = y.Average(); // default prediction (constant) 86 var zeros = Enumerable.Repeat(0.0, nRows); 87 var ones = Enumerable.Repeat(1.0, nRows); 88 double f0 = lossFunction.GetLineSearchFunc(y, zeros, ones)(activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors) 88 89 pred = Enumerable.Repeat(f0, nRows).ToArray(); 89 90 predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray(); -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs
r12597 r12607 38 38 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 39 39 double res = targetEnum.Current - predEnum.Current; 40 s += weightEnum.Current * Math.Abs(res); 40 s += weightEnum.Current * Math.Abs(res); // w * |res| 41 41 } 42 42 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) … … 52 52 53 53 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 54 // weight * sign(res)54 // dL(y, f(x)) / df(x) = weight * sign(res) 55 55 var res = targetEnum.Current - predEnum.Current; 56 56 if (res > 0) yield return weightEnum.Current; … … 67 67 var predArr = pred.ToArray(); 68 68 var weightArr = weight.ToArray(); 69 // weights are not supported yet70 // w hen weights are supported we need to calculate a weighted median69 // the optimal constant value that should be added to the predictions is the median of the residuals 70 // weights are not supported yet (need to calculate a weighted median) 71 71 Debug.Assert(weightArr.All(w => w.IsAlmost(1.0))); 72 72 … … 80 80 int nRows = endIdx - startIdx + 1; 81 81 var res = new double[nRows]; 82 for (int offset = 0; offset < nRows; offset++) { 83 var i = startIdx + offset; 82 for (int i = startIdx; i <= endIdx; i++) { 84 83 var row = idx[i]; 85 res[ offset] = targetArr[row] - predArr[row];84 res[i - startIdx] = targetArr[row] - predArr[row]; 86 85 } 87 86 return res.Median(); -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs
r12590 r12607 24 24 25 25 namespace HeuristicLab.Algorithms.DataAnalysis { 26 // returns the optimal value for the partition of rows stored in idx[startIdx] .. idx[endIdx] inclusive 26 27 public delegate double LineSearchFunc(int[] idx, int startIdx, int endIdx); 27 28 29 // represents an interface for loss functions used by gradient boosting 30 // target represents the target vector (original targets from the problem data, never changed) 31 // pred represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step) 32 // weight represents a weight vector for rows (this is not supported yet -> all weights are 1) 28 33 public interface ILossFunction { 29 // returns the weighted loss 34 // returns the weighted loss of the current prediction vector 30 35 double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight); 31 36 … … 37 42 } 38 43 } 44 45 -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs
r12590 r12607 28 28 29 29 namespace HeuristicLab.Algorithms.DataAnalysis { 30 // Greedy Function Approximation: A Gradient Boosting Machine (page 9) 30 31 public class LogisticRegressionLoss : ILossFunction { 31 32 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) { … … 36 37 double s = 0; 37 38 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 38 // assert target == 0 or target == 1 39 if (!targetEnum.Current.IsAlmost(0.0) && !targetEnum.Current.IsAlmost(1.0)) 40 throw new NotSupportedException("labels must be 0 or 1 for logistic regression loss"); 41 double f = Math.Max(-7, Math.Min(7, predEnum.Current)); // threshold for exponent 42 var probPos = Math.Exp(2 * f) / (1 + Math.Exp(2 * f)); 43 s += weightEnum.Current * (-targetEnum.Current * Math.Log(probPos) - (1 - targetEnum.Current) * Math.Log(1 - probPos)); 39 Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss"); 40 41 var y = targetEnum.Current * 2 - 1; // y in {-1,1} 42 s += weightEnum.Current * Math.Log(1 + Math.Exp(-2 * y * predEnum.Current)); 44 43 } 45 44 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 46 throw new ArgumentException("target, pred and weight have differ inglengths");45 throw new ArgumentException("target, pred and weight have different lengths"); 47 46 48 47 return s; … … 55 54 56 55 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 57 // assert target == 0 or target == 1 58 if (!targetEnum.Current.IsAlmost(0.0) && !targetEnum.Current.IsAlmost(1.0)) 59 throw new NotSupportedException("labels must be 0 or 1 for logistic regression loss"); 60 double f = Math.Max(-7, Math.Min(7, predEnum.Current)); // threshold for exponent 61 var probPos = Math.Exp(2 * f) / (1 + Math.Exp(2 * f)); 62 yield return weightEnum.Current * (targetEnum.Current - probPos) / (probPos * probPos - probPos); 56 Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss"); 57 var y = targetEnum.Current * 2 - 1; // y in {-1,1} 58 59 yield return weightEnum.Current * 2 * y / (1 + Math.Exp(2 * y * predEnum.Current)); 60 63 61 } 64 62 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 65 throw new ArgumentException("target, pred and weight have differ inglengths");63 throw new ArgumentException("target, pred and weight have different lengths"); 66 64 } 67 65 … … 75 73 76 74 if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length) 77 throw new ArgumentException("target, pred and weight have differ inglengths");75 throw new ArgumentException("target, pred and weight have different lengths"); 78 76 79 // line search for abs error77 // "Simple Newton-Raphson step" of eqn. 23 80 78 LineSearchFunc lineSearch = (idx, startIdx, endIdx) => { 81 79 double sumY = 0.0; 82 80 double sumDiff = 0.0; 83 81 for (int i = startIdx; i <= endIdx; i++) { 84 var yi = (targetArr[idx[i]] - predArr[idx[i]]); 85 var wi = weightArr[idx[i]]; 82 var row = idx[i]; 83 var y = targetArr[row] * 2 - 1; // y in {-1,1} 84 var pseudoResponse = weightArr[row] * 2 * y / (1 + Math.Exp(2 * y * predArr[row])); 86 85 87 sumY += wi * yi; 88 sumDiff += wi * Math.Abs(yi) * (1 - Math.Abs(yi)); 89 86 sumY += pseudoResponse; 87 sumDiff += Math.Abs(pseudoResponse) * (2 - Math.Abs(pseudoResponse)); 90 88 } 91 89 // prevent divByZero 92 90 sumDiff = Math.Max(1E-12, sumDiff); 93 return 0.5 *sumY / sumDiff;91 return sumY / sumDiff; 94 92 }; 95 93 return lineSearch; -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs
r12597 r12607 28 28 29 29 namespace HeuristicLab.Algorithms.DataAnalysis { 30 // relative error loss is a special case of weighted absolute error loss 31 // absolute loss is weighted by (1/target) 30 // relative error loss is a special case of weighted absolute error loss with weights = (1/target) 32 31 public class RelativeErrorLoss : ILossFunction { 33 32 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) { … … 42 41 } 43 42 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 44 throw new ArgumentException("target, pred and weight have differ inglengths");43 throw new ArgumentException("target, pred and weight have different lengths"); 45 44 46 45 return s; … … 60 59 } 61 60 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 62 throw new ArgumentException("target, pred and weight have differ inglengths");61 throw new ArgumentException("target, pred and weight have different lengths"); 63 62 } 64 63 … … 70 69 71 70 if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length) 72 throw new ArgumentException("target, pred and weight have differ inglengths");71 throw new ArgumentException("target, pred and weight have different lengths"); 73 72 74 73 // line search for relative error … … 82 81 var w0 = weightArr[idx[startIdx]] * Math.Abs(1.0 / targetArr[idx[startIdx]]); 83 82 var w1 = weightArr[idx[endIdx]] * Math.Abs(1.0 / targetArr[idx[endIdx]]); 84 return (w0 * (targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + w1 * (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / (w0 + w1); 83 if (w0 > w1) { 84 return targetArr[idx[startIdx]] - predArr[idx[startIdx]]; 85 } else if (w0 < w1) { 86 return targetArr[idx[endIdx]] - predArr[idx[endIdx]]; 87 } else { 88 // same weight -> return average of both residuals 89 return ((targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / 2; 90 } 85 91 } else { 86 var ts = from offset in Enumerable.Range(0, nRows) 87 let i = startIdx + offset 88 let row = idx[i] 89 select new { res = targetArr[row] - predArr[row], weight = weightArr[row] * Math.Abs(1.0 / targetArr[row]) }; 90 ts = ts.OrderBy(t => t.res); 91 var totalWeight = ts.Sum(t => t.weight); 92 var tsEnumerator = ts.GetEnumerator(); 93 tsEnumerator.MoveNext(); 92 // create an array of key-value pairs to be sorted (instead of using Array.Sort(res, weights)) 93 var res_w = new KeyValuePair<double, double>[nRows]; 94 var totalWeight = 0.0; 95 for (int i = startIdx; i <= endIdx; i++) { 96 int row = idx[i]; 97 var res = targetArr[row] - predArr[row]; 98 var w = weightArr[row] * Math.Abs(1.0 / targetArr[row]); 99 res_w[i - startIdx] = new KeyValuePair<double, double>(res, w); 100 totalWeight += w; 101 } 102 res_w.StableSort((a, b) => Math.Sign(a.Key - b.Key)); 94 103 95 double aggWeight = tsEnumerator.Current.weight; // weight of first96 97 while ( aggWeight <totalWeight / 2) {98 tsEnumerator.MoveNext();99 aggWeight += tsEnumerator.Current.weight;104 int k = 0; 105 double sum = totalWeight - res_w[k].Value; // total - first weight 106 while (sum > totalWeight / 2) { 107 k++; 108 sum -= res_w[k].Value; 100 109 } 101 return tsEnumerator.Current.res;110 return res_w[k].Key; 102 111 } 103 112 }; -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs
r12590 r12607 35 35 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 36 36 double res = targetEnum.Current - predEnum.Current; 37 s += weightEnum.Current * res * res; 37 s += weightEnum.Current * res * res; // w * (res)^2 38 38 } 39 39 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 40 throw new ArgumentException("target, pred and weight have differ inglengths");40 throw new ArgumentException("target, pred and weight have different lengths"); 41 41 42 42 return s; … … 49 49 50 50 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 51 yield return weightEnum.Current * 2.0 * (targetEnum.Current - predEnum.Current); 51 yield return weightEnum.Current * 2.0 * (targetEnum.Current - predEnum.Current); // dL(y, f(x)) / df(x) = w * 2 * res 52 52 } 53 53 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 54 throw new ArgumentException("target, pred and weight have differ inglengths");54 throw new ArgumentException("target, pred and weight have different lengths"); 55 55 } 56 56 … … 60 60 var weightArr = weight.ToArray(); 61 61 if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length) 62 throw new ArgumentException("target, pred and weight have differ inglengths");62 throw new ArgumentException("target, pred and weight have different lengths"); 63 63 64 // line search for squared error loss => return the average value 64 // line search for squared error loss 65 // for a given partition of rows the optimal constant that should be added to the current prediction values is the average of the residuals 65 66 LineSearchFunc lineSearch = (idx, startIdx, endIdx) => { 66 67 double s = 0.0;
Note: See TracChangeset
for help on using the changeset viewer.