Changeset 12597 for branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions
- Timestamp:
- 07/06/15 13:02:19 (9 years ago)
- Location:
- branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs
r12590 r12597 52 52 53 53 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 54 // weight * sign(res) 54 55 var res = targetEnum.Current - predEnum.Current; 55 56 if (res > 0) yield return weightEnum.Current; … … 61 62 } 62 63 64 // return median of residuals 63 65 public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) { 64 66 var targetArr = target.ToArray(); -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs
r12590 r12597 29 29 namespace HeuristicLab.Algorithms.DataAnalysis { 30 30 // relative error loss is a special case of weighted absolute error loss 31 // absolute loss is weighted by (1/target) 31 32 public class RelativeErrorLoss : ILossFunction { 32 33 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) { … … 52 53 53 54 while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) { 55 // weight * sign(res) * abs(1 / target) 54 56 var res = targetEnum.Current - predEnum.Current; 55 57 if (res > 0) yield return weightEnum.Current * 1.0 / Math.Abs(targetEnum.Current); … … 71 73 72 74 // line search for relative error 73 // TODO: check and improve?75 // weighted median (weight = 1/target) 74 76 LineSearchFunc lineSearch = (idx, startIdx, endIdx) => { 75 77 // weighted median calculation 76 78 int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive 77 if (nRows == 1) return targetArr[idx[startIdx]] - predArr[idx[startIdx]]; 79 if (nRows == 1) return targetArr[idx[startIdx]] - predArr[idx[startIdx]]; // res 78 80 else if (nRows == 2) { 81 // weighted average of two residuals 79 82 var w0 = weightArr[idx[startIdx]] * Math.Abs(1.0 / targetArr[idx[startIdx]]); 80 83 var w1 = weightArr[idx[endIdx]] * Math.Abs(1.0 / targetArr[idx[endIdx]]); 81 if (w0 > w1) { 82 return targetArr[idx[startIdx]] - predArr[idx[startIdx]]; 83 } else if (w0 < w1) { 84 return targetArr[idx[endIdx]] - predArr[idx[endIdx]]; 85 } else { 86 // same weight 87 return ((targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / 2; 88 } 84 return (w0 * (targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + w1 * (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / (w0 + w1); 89 85 } else { 90 86 var ts = from offset in Enumerable.Range(0, nRows) 91 87 let i = startIdx + offset 92 select new { res = targetArr[idx[i]] - predArr[idx[i]], weight = weightArr[idx[i]] * Math.Abs(1.0 / targetArr[idx[i]]) }; 88 let row = idx[i] 89 select new { res = targetArr[row] - predArr[row], weight = weightArr[row] * Math.Abs(1.0 / targetArr[row]) }; 93 90 ts = ts.OrderBy(t => t.res); 94 91 var totalWeight = ts.Sum(t => t.weight);
Note: See TracChangeset
for help on using the changeset viewer.