Changeset 12620


Ignore:
Timestamp:
07/06/15 20:38:56 (4 years ago)
Author:
gkronber
Message:

#2261: corrected check if a split is useful, added a unit test class and added an elaborate comment on split quality calculation

Location:
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
1 added
3 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r12619 r12620  
    186186
    187187      // init
    188       var problemData = Problem.ProblemData;
     188      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    189189      var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>()
    190190        .Single(l => l.ToString() == LossFunctionParameter.Value.Value);
     
    245245        } else {
    246246          // otherwise we produce a regression solution
    247           Results.Add(new Result("Solution", new RegressionSolution(state.GetModel(), (IRegressionProblemData)problemData.Clone())));
     247          Results.Add(new Result("Solution", new RegressionSolution(state.GetModel(), problemData)));
    248248        }
    249249      }
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12619 r12620  
    2222
    2323using System;
     24using System.Collections;
    2425using System.Collections.Generic;
    2526using System.Diagnostics;
     
    5354    private readonly int[][] sortedIdx; // random selection from sortedIdxAll (for r < 1.0)
    5455
    55 
     56    private int calls = 0;
    5657
    5758    // helper arrays which are allocated to maximal necessary size only once in the ctor
     
    6465
    6566    private class Partition {
    66       public int TreeNodeIdx { get; set; }
     67      public int ParentNodeIdx { get; set; }
    6768      public int Depth { get; set; }
    6869      public int StartIdx { get; set; }
     
    7071      public bool Left { get; set; }
    7172    }
    72     private readonly IList<Partition> queue;
     73    private readonly SortedList<double, Partition> queue;
    7374
    7475    // prepare and allocate buffer variables in ctor
     
    9495      outx = new double[rows];
    9596      outSortedIdx = new int[rows];
    96       queue = new List<Partition>();
     97      queue = new SortedList<double, Partition>();
    9798
    9899      x = new double[nCols][];
     
    179180
    180181      // start and end idx are inclusive
    181       queue.Add(new Partition() { TreeNodeIdx = -1, Depth = maxDepth, StartIdx = 0, EndIndex = effectiveRows - 1 });
     182      queue.Add(calls++, new Partition() { ParentNodeIdx = -1, Depth = maxDepth, StartIdx = 0, EndIndex = effectiveRows - 1 });
    182183      CreateRegressionTreeForIdx(lineSearch);
    183184
     
    187188    private void CreateRegressionTreeForIdx(LineSearchFunc lineSearch) {
    188189      while (queue.Any()) {
    189         var f = queue[0]; // actually a stack
     190        var f = queue.First().Value; // actually a stack
    190191        queue.RemoveAt(0);
    191192
     
    204205        if (depth <= 1 || endIdx - startIdx == 0 || !FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName)) {
    205206          CreateLeafNode(startIdx, endIdx, lineSearch);
    206           if (f.TreeNodeIdx >= 0) if (f.Left) {
    207               tree[f.TreeNodeIdx].leftIdx = curTreeNodeIdx;
     207          if (f.ParentNodeIdx >= 0) if (f.Left) {
     208              tree[f.ParentNodeIdx].leftIdx = curTreeNodeIdx;
    208209            } else {
    209               tree[f.TreeNodeIdx].rightIdx = curTreeNodeIdx;
     210              tree[f.ParentNodeIdx].rightIdx = curTreeNodeIdx;
    210211            }
    211212          curTreeNodeIdx++;
     
    215216
    216217          // connect to parent tree
    217           if (f.TreeNodeIdx >= 0) if (f.Left) {
    218               tree[f.TreeNodeIdx].leftIdx = curTreeNodeIdx;
     218          if (f.ParentNodeIdx >= 0) if (f.Left) {
     219              tree[f.ParentNodeIdx].leftIdx = curTreeNodeIdx;
    219220            } else {
    220               tree[f.TreeNodeIdx].rightIdx = curTreeNodeIdx;
     221              tree[f.ParentNodeIdx].rightIdx = curTreeNodeIdx;
    221222            }
    222223
     
    224225          Debug.Assert(startIdx <= splitIdx);
    225226
    226           queue.Insert(0, new Partition() { TreeNodeIdx = curTreeNodeIdx, Left = false, Depth = depth - 1, StartIdx = splitIdx + 1, EndIndex = endIdx });
    227           queue.Insert(0, new Partition() { TreeNodeIdx = curTreeNodeIdx, Left = true, Depth = depth - 1, StartIdx = startIdx, EndIndex = splitIdx }); // left part before right part (stack organization)
     227          queue.Add(calls++, new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = true, Depth = depth - 1, StartIdx = startIdx, EndIndex = splitIdx }); // left part before right part (stack organization)
     228          queue.Add(calls++, new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = false, Depth = depth - 1, StartIdx = splitIdx + 1, EndIndex = endIdx });
    228229          curTreeNodeIdx++;
    229230
     
    318319      }
    319320
    320       double bestImprovement = 0.0;
     321      double bestImprovement = 1.0 / rows * sumY * sumY;
    321322      double bestThreshold = double.PositiveInfinity;
    322323      bestVar = RegressionTreeModel.TreeNode.NO_VARIABLE;
     
    384385      double nr = rows;
    385386
    386       bestImprovement = 0.0;
     387      bestImprovement = 1.0 / rows * sumY * sumY;
    387388      bestThreshold = double.NegativeInfinity;
    388389      // for all thresholds
     
    398399
    399400        if (x[i] < x[i + 1]) { // don't try to split when two elements are equal
     401
     402          // goal is to find the split leading to the minimal total variance of the left and right parts
     403          // without partitioning the variance is var(y) = E(y²) - E(y)² 
     404          //    = 1/n * sum(y²) - (1/n * sum(y))²
     405          //      -------------
     406          // if we split into right and left part the overall variance is the weighted combination nl/n * var(y_l) + nr/n * var(y_r) 
     407          //    = nl/n * (1/nl * sum(y_l²) - (1/nl * sum(y_l))²) + nr/n * (1/nr * sum(y_r²) - (1/nr * sum(y_r))²)
     408          //    = 1/n * sum(y_l²) - 1/nl * 1/n * sum(y_l)² + 1/n * sum(y_r²) - 1/nr * 1/n * sum(y_r)²
     409          //    = 1/n * (sum(y_l²) + sum(y_r²)) - 1/n * (sum(y_l)² / nl + sum(y_r)² / nr)
     410          //    = 1/n * sum(y²) - 1/n * (sum(y_l)² / nl + sum(y_r)² / nr)
     411          //      -------------
     412          //       not changed by split (and the same for total variance without partitioning)
     413          //
     414          //   therefore we need to find the maximum value (sum(y_l)² / nl + sum(y_r)² / nr) (ignoring the factor 1/n)
     415          //   and this value must be larger than 1/n * sum(y)² to be an improvement over no split
     416
    400417          double curQuality = sl * sl / nl + sr * sr / nr;
    401           // curQuality = nl*nr / (nl+nr) * Sqr(sl / nl - sr / nr) // greedy function approximation page 12 eqn (35)
    402418
    403419          if (curQuality > bestImprovement) {
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r12588 r12620  
    199199      <Private>False</Private>
    200200    </Reference>
     201    <Reference Include="Microsoft.VisualStudio.QualityTools.UnitTestFramework, Version=10.1.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL" />
    201202    <Reference Include="System" />
    202203    <Reference Include="System.Core">
     
    289290    <Compile Include="GradientBoostedTrees\RegressionTreeBuilder.cs" />
    290291    <Compile Include="GradientBoostedTrees\RegressionTreeModel.cs" />
     292    <Compile Include="GradientBoostedTrees\Test.cs" />
    291293    <Compile Include="Interfaces\IGaussianProcessClassificationModelCreator.cs" />
    292294    <Compile Include="Interfaces\IGaussianProcessRegressionModelCreator.cs" />
Note: See TracChangeset for help on using the changeset viewer.