
Timestamp: 07/04/15 16:03:36
Author: gkronber
Message: #2261: preparations for trunk integration (adapt to current trunk version, add license headers, add comments, improve code quality)

File: 1 edited

Legend: unchanged lines have no prefix; added lines are prefixed with "+"; removed lines with "-"
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

--- r12372 (old)
+++ r12590 (new)
@@ -1,12 +1,39 @@
-using System;
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ * and the BEACON Center for the Study of Evolution in Action.
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using System;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Diagnostics.Contracts;
 using System.Linq;
-using HeuristicLab.Common;
 using HeuristicLab.Core;
 using HeuristicLab.Problems.DataAnalysis;
 
-namespace GradientBoostedTrees {
+namespace HeuristicLab.Algorithms.DataAnalysis {
+  // This class implements a greedy decision tree learner which selects splits with the maximum reduction in the sum of squared errors.
+  // The tree builder also tracks variable relevance metrics based on the splits and the improvement after each split.
+  // The implementation is tuned for gradient boosting, where multiple trees have to be calculated for the same training data,
+  // each time with a different target vector. Index vectors that allow iterating over the input variables in sorted order are
+  // pre-calculated so that the optimal threshold for a split can be found in O(n) for each input variable.
+  // After each split, the row indexes are partitioned into a left and a right part.
   public class RegressionTreeBuilder {
     private readonly IRandom random;
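Note on the O(n) split search mentioned in the new class comment: once the rows of one input variable are available in pre-sorted order, the SSE-optimal threshold falls out of a single left-to-right pass that maintains running target sums on both sides. The following is a minimal, self-contained sketch of that idea, not the HeuristicLab implementation; the class and method names are illustrative.

    using System;
    using System.Linq;

    static class SplitSearchSketch {
      // Assumes y.Length == n and sortedIdx is a permutation of 0..n-1 that
      // orders the rows by ascending value of x (the pre-calculated index vector).
      public static double FindBestThreshold(double[] x, double[] y, int[] sortedIdx) {
        int n = sortedIdx.Length;
        double sumLeft = 0.0, sumRight = y.Sum();
        double bestGain = double.NegativeInfinity, bestThreshold = double.NaN;
        for (int i = 0; i < n - 1; i++) {
          sumLeft += y[sortedIdx[i]];
          sumRight -= y[sortedIdx[i]];
          if (x[sortedIdx[i]] == x[sortedIdx[i + 1]]) continue; // no split between equal values
          int nLeft = i + 1, nRight = n - nLeft;
          // minimizing the children's total SSE is equivalent to maximizing
          // sumLeft^2/nLeft + sumRight^2/nRight (from the SSE decomposition)
          double gain = sumLeft * sumLeft / nLeft + sumRight * sumRight / nRight;
          if (gain > bestGain) {
            bestGain = gain;
            bestThreshold = (x[sortedIdx[i]] + x[sortedIdx[i + 1]]) / 2.0;
          }
        }
        return bestThreshold;
      }
    }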
     
@@ -14,6 +41,6 @@
 
     private readonly int nCols;
-    private readonly double[][] x; // all training data (original order from problemData)
-    private double[] y; // training labels (original order from problemData)
+    private readonly double[][] x; // all training data (original order from problemData); x is constant
+    private double[] y; // training labels (original order from problemData); y can be changed
 
     private Dictionary<string, double> sumImprovements; // for variable relevance calculation
     
@@ -76,4 +103,8 @@
     }
 
+    // the simple API produces a single regression tree optimizing the sum of squared errors;
+    // it can be used if only a single regression tree is needed
+    // (for a set of trees use the method CreateRegressionTreeForGradientBoosting below)
+    //
     // r and m work in the same way as for the alglib random forest
     // r is the fraction of rows to use for training
     
@@ -92,5 +123,5 @@
     }
 
-    // specific interface that allows to specify the target labels and the training rows which is necessary when this functionality is called by the gradient boosting routine
+    // specific interface that allows specifying the target labels and the training rows, which is necessary for gradient boosted trees
     public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxDepth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
       Contract.Assert(maxDepth > 0);
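This changeset never shows the LineSearchFunc signature; assuming it maps a row-index range to the optimal leaf value (consistent with the comments added further below), a hypothetical caller in a boosting loop could look like this. The names residuals, target, pred, nRows, rowIdx, and treeBuilder are illustrative, not part of the actual API.

    // Hypothetical usage sketch: fit the tree to squared-error pseudo-residuals.
    double[] residuals = new double[nRows];
    for (int i = 0; i < nRows; i++)
      residuals[i] = target[i] - pred[i]; // pseudo-residuals for squared error

    LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
      double sum = 0.0;
      for (int k = startIdx; k <= endIdx; k++) sum += residuals[idx[k]];
      return sum / (endIdx - startIdx + 1); // optimal leaf value = mean residual
    };

    var model = treeBuilder.CreateRegressionTreeForGradientBoosting(
      residuals, 3, rowIdx, lineSearch, 0.5, 0.5);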
     
@@ -111,12 +142,14 @@
       HeuristicLab.Random.ListExtensions.ShuffleInPlace(allowedVariables, random);
 
+      // randomly select only a part of the rows and columns
       effectiveRows = (int)Math.Ceiling(nRows * r);
       effectiveVars = (int)Math.Ceiling(nCols * m);
 
+      // the which array is used for partitioning the row indexes
       Array.Clear(which, 0, which.Length);
 
       // mark selected rows
       for (int row = 0; row < effectiveRows; row++) {
-        which[idx[row]] = 1;
+        which[idx[row]] = 1; // the which vector is used as a temporary marker here
         internalIdx[row] = idx[row];
       }
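Worked example for the subsampling above: with nRows = 1000 and r = 0.5, effectiveRows = ceil(1000 * 0.5) = 500; with nCols = 10 and m = 0.3, effectiveVars = ceil(10 * 0.3) = 3. Because of Math.Ceiling, at least one row and one variable are always selected for any r, m > 0.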
     
@@ -126,5 +159,5 @@
         for (int row = 0; row < nRows; row++) {
           if (which[sortedIdxAll[col][row]] > 0) {
-            Trace.Assert(i < effectiveRows);
+            Debug.Assert(i < effectiveRows);
             sortedIdx[col][i] = sortedIdxAll[col][row];
             i++;
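This loop relies on the fact that filtering a sorted sequence preserves its order: for example, if sortedIdxAll[col] = [3, 0, 2, 1] and the selected rows are {0, 2}, then sortedIdx[col] becomes [0, 2], still sorted by the value of the input variable. This is what lets each tree reuse the pre-calculated sort in O(n) per column instead of re-sorting the sampled rows.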
     
@@ -135,6 +168,5 @@
       // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes)
       int numNodes = (int)Math.Pow(2, maxDepth) - 1;
-      //this.tree = new RegressionTreeModel.TreeNode[numNodes];
-      this.tree = Enumerable.Range(0, numNodes).Select(_ => new RegressionTreeModel.TreeNode()).ToArray();
+      this.tree = new RegressionTreeModel.TreeNode[numNodes];
       this.curTreeNodeIdx = 0;
 
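For example, maxDepth = 3 gives 2^3 - 1 = 7 nodes (1 root + 2 + 4), i.e., room for a complete binary tree. Pre-allocating the full array and, presumably, handing out slots via the curTreeNodeIdx counter (reset to 0 above) avoids resizing while the tree is built.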
     
@@ -144,4 +176,7 @@
     }
 
+    // recursive routine for building the tree for the row indexes stored in internalIdx between startIdx and endIdx
+    // the lineSearch function calculates the optimal prediction value for tree leaf nodes
+    // (in the case of squared errors it is the average of the target values for the rows represented by the node)
     // startIdx and endIdx are inclusive
     private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
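The parenthetical claim (the optimal leaf value for squared errors is the average) follows directly from minimizing over a constant c:

    d/dc sum_i (y_i - c)^2 = -2 sum_i (y_i - c) = 0  =>  c = (1/n) sum_i y_i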
     
@@ -214,6 +249,6 @@
             else if (which[internalIdx[j]] > 0) j--;
             else {
-              Trace.Assert(which[internalIdx[i]] > 0);
-              Trace.Assert(which[internalIdx[j]] < 0);
+              Debug.Assert(which[internalIdx[i]] > 0);
+              Debug.Assert(which[internalIdx[j]] < 0);
               // swap
               int tmp = internalIdx[i];
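The surrounding partition loop is only partially visible in this hunk. A plausible reconstruction (assuming, consistent with the asserts above, that which < 0 marks rows for the left child and which > 0 for the right) is a two-pointer swap scheme; this sketch is not verbatim HeuristicLab code and needs using System.Diagnostics for Debug.Assert:

    int i = startIdx;
    int j = endIdx;
    while (i <= j) {
      if (which[internalIdx[i]] < 0) i++;       // row already belongs to the left part
      else if (which[internalIdx[j]] > 0) j--;  // row already belongs to the right part
      else {
        Debug.Assert(which[internalIdx[i]] > 0); // i points at a right-marked row
        Debug.Assert(which[internalIdx[j]] < 0); // j points at a left-marked row
        // swap so both rows end up on their correct side
        int tmp = internalIdx[i];
        internalIdx[i] = internalIdx[j];
        internalIdx[j] = tmp;
        i++;
        j--;
      }
    }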
     
@@ -283,9 +318,4 @@
 
       threshold = bestThreshold;
-
-      // Contract.Assert(bestImprovement > 0);
-      // Contract.Assert(bestImprovement < double.PositiveInfinity);
-      // Contract.Assert(bestVar != string.Empty);
-      // Contract.Assert(allowedVariables.Contains(bestVar));
     }
 
     
@@ -351,4 +381,5 @@
 
     public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
+      // values are scaled: the most important variable has relevance = 100
       double scaling = 100 / sumImprovements.Max(t => t.Value);
       return
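Worked example for the scaling (variable names illustrative): if sumImprovements holds { "x1": 4.0, "x2": 1.0 }, then scaling = 100 / 4.0 = 25, so GetVariableRelevance reports x1 = 100 and x2 = 25. The most relevant variable always gets 100 and the others are proportional to their accumulated SSE improvement.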