Changeset 12607 for branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs
- Timestamp:
- 07/06/15 15:20:28 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs
r12590 r12607 28 28 29 29 namespace HeuristicLab.Algorithms.DataAnalysis { 30 // Greedy Function Approximation: A Gradient Boosting Machine (page 9) 30 31 public class LogisticRegressionLoss : ILossFunction { 31 32 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) { … … 36 37 double s = 0; 37 38 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 38 // assert target == 0 or target == 1 39 if (!targetEnum.Current.IsAlmost(0.0) && !targetEnum.Current.IsAlmost(1.0)) 40 throw new NotSupportedException("labels must be 0 or 1 for logistic regression loss"); 41 double f = Math.Max(-7, Math.Min(7, predEnum.Current)); // threshold for exponent 42 var probPos = Math.Exp(2 * f) / (1 + Math.Exp(2 * f)); 43 s += weightEnum.Current * (-targetEnum.Current * Math.Log(probPos) - (1 - targetEnum.Current) * Math.Log(1 - probPos)); 39 Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss"); 40 41 var y = targetEnum.Current * 2 - 1; // y in {-1,1} 42 s += weightEnum.Current * Math.Log(1 + Math.Exp(-2 * y * predEnum.Current)); 44 43 } 45 44 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 46 throw new ArgumentException("target, pred and weight have differ inglengths");45 throw new ArgumentException("target, pred and weight have different lengths"); 47 46 48 47 return s; … … 55 54 56 55 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 57 // assert target == 0 or target == 1 58 if (!targetEnum.Current.IsAlmost(0.0) && !targetEnum.Current.IsAlmost(1.0)) 59 throw new NotSupportedException("labels must be 0 or 1 for logistic regression loss"); 60 double f = Math.Max(-7, Math.Min(7, predEnum.Current)); // threshold for exponent 61 var probPos = Math.Exp(2 * f) / (1 + Math.Exp(2 * f)); 62 yield return weightEnum.Current * (targetEnum.Current - probPos) / (probPos * probPos - probPos); 56 Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss"); 57 var y = targetEnum.Current * 2 - 1; // y in {-1,1} 58 59 yield return weightEnum.Current * 2 * y / (1 + Math.Exp(2 * y * predEnum.Current)); 60 63 61 } 64 62 if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext()) 65 throw new ArgumentException("target, pred and weight have differ inglengths");63 throw new ArgumentException("target, pred and weight have different lengths"); 66 64 } 67 65 … … 75 73 76 74 if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length) 77 throw new ArgumentException("target, pred and weight have differ inglengths");75 throw new ArgumentException("target, pred and weight have different lengths"); 78 76 79 // line search for abs error77 // "Simple Newton-Raphson step" of eqn. 23 80 78 LineSearchFunc lineSearch = (idx, startIdx, endIdx) => { 81 79 double sumY = 0.0; 82 80 double sumDiff = 0.0; 83 81 for (int i = startIdx; i <= endIdx; i++) { 84 var yi = (targetArr[idx[i]] - predArr[idx[i]]); 85 var wi = weightArr[idx[i]]; 82 var row = idx[i]; 83 var y = targetArr[row] * 2 - 1; // y in {-1,1} 84 var pseudoResponse = weightArr[row] * 2 * y / (1 + Math.Exp(2 * y * predArr[row])); 86 85 87 sumY += wi * yi; 88 sumDiff += wi * Math.Abs(yi) * (1 - Math.Abs(yi)); 89 86 sumY += pseudoResponse; 87 sumDiff += Math.Abs(pseudoResponse) * (2 - Math.Abs(pseudoResponse)); 90 88 } 91 89 // prevent divByZero 92 90 sumDiff = Math.Max(1E-12, sumDiff); 93 return 0.5 *sumY / sumDiff;91 return sumY / sumDiff; 94 92 }; 95 93 return lineSearch;
Note: See TracChangeset
for help on using the changeset viewer.