Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/21/16 17:59:36 (8 years ago)
Author:
gkronber
Message:

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs

    r14185 r14345  
    3838    }
    3939
     40    public int NumberOfTrees {
     41      get { return Model.NumberOfTrees; }
     42    }
     43
    4044    [StorableConstructor]
    4145    private RandomForestClassificationSolution(bool deserializing) : base(deserializing) { }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r14230 r14345  
    2525using HeuristicLab.Common;
    2626using HeuristicLab.Core;
     27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2728using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2829using HeuristicLab.Problems.DataAnalysis;
     30using HeuristicLab.Problems.DataAnalysis.Symbolic;
    2931
    3032namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4951    }
    5052
     53    public int NumberOfTrees {
     54      get { return nTrees; }
     55    }
    5156
    5257    // instead of storing the data of the model itself
     
    6469    [Storable]
    6570    private double m;
    66 
    6771
    6872    [StorableConstructor]
     
    197201    }
    198202
     203    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
     204      // hoping that the internal representation of alglib is stable
     205
     206      // TREE FORMAT
     207      // W[Offs]      -   size of sub-array (for the tree)
     208      //     node info:
     209      // W[K+0]       -   variable number        (-1 for leaf mode)
     210      // W[K+1]       -   threshold              (class/value for leaf node)
     211      // W[K+2]       -   ">=" branch index      (absent for leaf node)
     212
     213      // skip irrelevant trees
     214      int offset = 0;
     215      for (int i = 0; i < treeIdx - 1; i++) {
     216        offset = offset + (int)Math.Round(randomForest.innerobj.trees[offset]);
     217      }
     218
     219      var constSy = new Constant();
     220      var varCondSy = new VariableCondition() { IgnoreSlope = true };
     221
     222      var node = CreateRegressionTreeRec(randomForest.innerobj.trees, offset, offset + 1, constSy, varCondSy);
     223
     224      var startNode = new StartSymbol().CreateTreeNode();
     225      startNode.AddSubtree(node);
     226      var root = new ProgramRootSymbol().CreateTreeNode();
     227      root.AddSubtree(startNode);
     228      return new SymbolicExpressionTree(root);
     229    }
     230
     231    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {
     232
     233      // alglib source for evaluation of one tree (dfprocessinternal)
     234      // offs = 0
     235      //
     236      // Set pointer to the root
     237      //
     238      // k = offs + 1;
     239      //
     240      // //
     241      // // Navigate through the tree
     242      // //
     243      // while (true) {
     244      //   if ((double)(df.trees[k]) == (double)(-1)) {
     245      //     if (df.nclasses == 1) {
     246      //       y[0] = y[0] + df.trees[k + 1];
     247      //     } else {
     248      //       idx = (int)Math.Round(df.trees[k + 1]);
     249      //       y[idx] = y[idx] + 1;
     250      //     }
     251      //     break;
     252      //   }
     253      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
     254      //     k = k + innernodewidth;
     255      //   } else {
     256      //     k = offs + (int)Math.Round(df.trees[k + 2]);
     257      //   }
     258      // }
     259
     260      if ((double)(trees[k]) == (double)(-1)) {
     261        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
     262        constNode.Value = trees[k + 1];
     263        return constNode;
     264      } else {
     265        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
     266        condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])];
     267        condNode.Threshold = trees[k + 1];
     268        condNode.Slope = double.PositiveInfinity;
     269
     270        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
     271        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);
     272
     273        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
     274        condNode.AddSubtree(right);
     275        return condNode;
     276      }
     277    }
     278
    199279
    200280    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs

    r14185 r14345  
    2020#endregion
    2121
     22using System;
    2223using HeuristicLab.Common;
    2324using HeuristicLab.Core;
    2425using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2526using HeuristicLab.Problems.DataAnalysis;
     27using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    2628
    2729namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3638      get { return (IRandomForestModel)base.Model; }
    3739      set { base.Model = value; }
     40    }
     41
     42    public int NumberOfTrees {
     43      get { return Model.NumberOfTrees; }
    3844    }
    3945
Note: See TracChangeset for help on using the changeset viewer.