Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/21/16 17:59:36 (8 years ago)
Author:
gkronber
Message:

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
9 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r14185 r14345  
    269269        } else {
    270270          // otherwise we produce a regression solution
    271           Results.Add(new Result("Solution", new RegressionSolution(model, problemData)));
     271          Results.Add(new Result("Solution", new GradientBoostedTreesSolution(model, problemData)));
    272272        }
    273273      }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesSolution.cs

    r14185 r14345  
    2020#endregion
    2121
    22 using System.Collections.Generic;
    23 using System.Linq;
    2422using HeuristicLab.Common;
    2523using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r14185 r14345  
    2828using HeuristicLab.Common;
    2929using HeuristicLab.Core;
     30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    3031using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3132using HeuristicLab.Problems.DataAnalysis;
     33using HeuristicLab.Problems.DataAnalysis.Symbolic;
     34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    3235
    3336namespace HeuristicLab.Algorithms.DataAnalysis {
     
    210213    }
    211214
     215    /// <summary>
     216    /// Transforms the tree model to a symbolic regression solution. The node array of this model is converted recursively, starting at index 0, and the result is wrapped in a start node and a program-root node.
     217    /// </summary>
     218    /// <param name="problemData">The regression problem data that is handed to CreateRegressionSolution of the newly created symbolic regression model</param>
     219    /// <returns>A new symbolic regression solution which matches the tree model</returns>
     220    public ISymbolicRegressionSolution CreateSymbolicRegressionSolution(IRegressionProblemData problemData) {
     221      var rootSy = new ProgramRootSymbol();
     222      var startSy = new StartSymbol();
     223      var varCondSy = new VariableCondition() { IgnoreSlope = true }; // NOTE(review): presumably IgnoreSlope makes the condition a hard threshold split - confirm against VariableCondition
     224      var constSy = new Constant();
     225
              // the same two symbol instances are reused for every node created by the recursion
     226      var startNode = startSy.CreateTreeNode();
     227      startNode.AddSubtree(CreateSymbolicRegressionTreeRecursive(tree, 0, varCondSy, constSy)); // conversion starts at node index 0
     228      var rootNode = rootSy.CreateTreeNode();
     229      rootNode.AddSubtree(startNode);
              // resulting shape: program root -> start -> converted decision tree
     230      var model = new SymbolicRegressionModel(TargetVariable, new SymbolicExpressionTree(rootNode), new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
     231      return model.CreateRegressionSolution(problemData);
     232    }
     233
     234    private ISymbolicExpressionTreeNode CreateSymbolicRegressionTreeRecursive(TreeNode[] treeNodes, int nodeIdx, VariableCondition varCondSy, Constant constSy) {
              // Recursively converts the node at treeNodes[nodeIdx] (and everything reachable from it via LeftIdx / RightIdx)
              // into symbolic expression tree nodes. varCondSy and constSy are the symbols used to create the new nodes.
              // Returns the root node of the converted subtree.
     235      var curNode = treeNodes[nodeIdx];
     236      if (curNode.VarName == TreeNode.NO_VARIABLE) {
                // node without a variable: emitted as a constant node carrying the node's Val
     237        var node = (ConstantTreeNode)constSy.CreateTreeNode();
     238        node.Value = curNode.Val;
     239        return node;
     240      } else {
                // node with a variable: emitted as a variable-condition node; here Val is used as the threshold
     241        var node = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
     242        node.VariableName = curNode.VarName;
     243        node.Threshold = curNode.Val;
     244
     245        var left = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.LeftIdx, varCondSy, constSy);
     246        var right = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.RightIdx, varCondSy, constSy);
                // subtree order matters: the left branch is added first, then the right branch
     247        node.AddSubtree(left);
     248        node.AddSubtree(right);
     249        return node;
     250      }
     251    }
     252
     253
    212254    private string TreeToString(int idx, string part) {
    213255      var n = tree[idx];
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/IRandomForestClassificationSolution.cs

    r14185 r14345  
    3030  public interface IRandomForestClassificationSolution : IClassificationSolution {
    3131    new IRandomForestModel Model { get; } // hides the inherited Model member and exposes it typed as IRandomForestModel
     32    int NumberOfTrees { get; } // number of trees of the random forest solution
    3233  }
    3334}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/IRandomForestModel.cs

    r14185 r14345  
    2020#endregion
    2121
    22 using HeuristicLab.Optimization;
     22using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2323using HeuristicLab.Problems.DataAnalysis;
    24 using HeuristicLab.Core;
    25 using System.Collections.Generic;
     24
    2625
    2726namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3029  /// </summary>
    3130  public interface IRandomForestModel : IConfidenceRegressionModel, IClassificationModel {
     31    int NumberOfTrees { get; } // number of trees in the random forest
     32    ISymbolicExpressionTree ExtractTree(int treeIdx); // returns a specific tree from the random forest as an ISymbolicExpressionTree
    3233  }
    3334}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/IRandomForestRegressionSolution.cs

    r14185 r14345  
    2020#endregion
    2121
    22 using HeuristicLab.Optimization;
    2322using HeuristicLab.Problems.DataAnalysis;
    24 using HeuristicLab.Core;
     23using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    2524
    2625namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3029  public interface IRandomForestRegressionSolution : IConfidenceRegressionSolution {
    3130    new IRandomForestModel Model { get; } // hides the inherited Model member and exposes it typed as IRandomForestModel
     31    int NumberOfTrees { get; } // number of trees of the random forest solution
    3232  }
    3333}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs

    r14185 r14345  
    3838    }
    3939
     40    public int NumberOfTrees {
     41      get { return Model.NumberOfTrees; } // forwarded from the underlying random forest model
     42    }
     43
    4044    [StorableConstructor]
    4145    private RandomForestClassificationSolution(bool deserializing) : base(deserializing) { } // only forwards the flag to the base constructor; NOTE(review): presumably invoked by the persistence framework - confirm
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r14230 r14345  
    2525using HeuristicLab.Common;
    2626using HeuristicLab.Core;
     27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2728using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2829using HeuristicLab.Problems.DataAnalysis;
     30using HeuristicLab.Problems.DataAnalysis.Symbolic;
    2931
    3032namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4951    }
    5052
     53    public int NumberOfTrees {
     54      get { return nTrees; } // exposes the nTrees field of this model
     55    }
    5156
    5257    // instead of storing the data of the model itself
     
    6469    [Storable]
    6570    private double m;
    66 
    6771
    6872    [StorableConstructor]
     
    197201    }
    198202
     203    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
              // Builds a symbolic expression tree (program root -> start -> decision tree) for one tree of the forest
              // by decoding the flat array randomForest.innerobj.trees.
     204      // hoping that the internal representation of alglib is stable
     205
     206      // TREE FORMAT
     207      // W[Offs]      -   size of sub-array (for the tree)
     208      //     node info:
     209      // W[K+0]       -   variable number        (-1 for leaf node)
     210      // W[K+1]       -   threshold              (class/value for leaf node)
     211      // W[K+2]       -   ">=" branch index      (absent for leaf node)
     212
     213      // skip irrelevant trees
              // each step jumps over one sub-array using the size stored in its first element
              // NOTE(review): the loop runs treeIdx - 1 times, so treeIdx = 0 and treeIdx = 1 both select the first tree.
              // This is only correct if callers pass a 1-based index; for a 0-based index the bound would have to be i < treeIdx - confirm against callers.
              // NOTE(review): treeIdx is not range-checked against NumberOfTrees in this method.
     214      int offset = 0;
     215      for (int i = 0; i < treeIdx - 1; i++) {
     216        offset = offset + (int)Math.Round(randomForest.innerobj.trees[offset]);
     217      }
     218
     219      var constSy = new Constant();
     220      var varCondSy = new VariableCondition() { IgnoreSlope = true };
     221
              // trees[offset] holds the sub-array size, so the first node of the tree is read from offset + 1
     222      var node = CreateRegressionTreeRec(randomForest.innerobj.trees, offset, offset + 1, constSy, varCondSy);
     223
     224      var startNode = new StartSymbol().CreateTreeNode();
     225      startNode.AddSubtree(node);
     226      var root = new ProgramRootSymbol().CreateTreeNode();
     227      root.AddSubtree(startNode);
     228      return new SymbolicExpressionTree(root);
     229    }
     230
     231    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {
              // Recursively converts the node stored at trees[k] into symbolic expression tree nodes.
              //   trees     - flat array holding the encoded trees
              //   offset    - start index of the current tree's sub-array; branch indices in trees[k + 2] are relative to it
              //   k         - index of the node to convert
              //   constSy   - symbol used to create leaf (constant) nodes
              //   varCondSy - symbol used to create inner (variable condition) nodes
              // Returns the root node of the converted subtree.
     232
     233      // alglib source for evaluation of one tree (dfprocessinternal)
     234      // offs = 0
     235      //
     236      // Set pointer to the root
     237      //
     238      // k = offs + 1;
     239      //
     240      // //
     241      // // Navigate through the tree
     242      // //
     243      // while (true) {
     244      //   if ((double)(df.trees[k]) == (double)(-1)) {
     245      //     if (df.nclasses == 1) {
     246      //       y[0] = y[0] + df.trees[k + 1];
     247      //     } else {
     248      //       idx = (int)Math.Round(df.trees[k + 1]);
     249      //       y[idx] = y[idx] + 1;
     250      //     }
     251      //     break;
     252      //   }
     253      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
     254      //     k = k + innernodewidth;
     255      //   } else {
     256      //     k = offs + (int)Math.Round(df.trees[k + 2]);
     257      //   }
     258      // }
     259
     260      if ((double)(trees[k]) == (double)(-1)) {
                // leaf node: trees[k + 1] is copied into a constant node
                // NOTE(review): per the alglib code quoted above, trees[k + 1] is a class index when nclasses != 1; this method does not distinguish that case
     261        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
     262        constNode.Value = trees[k + 1];
     263        return constNode;
     264      } else {
                // inner node: trees[k] is the variable number, trees[k + 1] the threshold
                // NOTE(review): assumes the variable number indexes AllowedInputVariables in the same order used for training - confirm
     265        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
     266        condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])];
     267        condNode.Threshold = trees[k + 1];
     268        condNode.Slope = double.PositiveInfinity; // NOTE(review): callers in this changeset create varCondSy with IgnoreSlope = true; whether Slope is still used then is not visible here - confirm
     269
                // the "<" branch starts directly after this node (3 entries per inner node, mirroring k + innernodewidth in the quoted alglib code);
                // the ">=" branch index is stored in trees[k + 2], relative to offset
     270        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
     271        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);
     272
     273        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
     274        condNode.AddSubtree(right);
     275        return condNode;
     276      }
     277    }
     278
    199279
    200280    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs

    r14185 r14345  
    2020#endregion
    2121
     22using System;
    2223using HeuristicLab.Common;
    2324using HeuristicLab.Core;
    2425using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2526using HeuristicLab.Problems.DataAnalysis;
     27using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    2628
    2729namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3638      get { return (IRandomForestModel)base.Model; }
    3739      set { base.Model = value; }
     40    }
     41
     42    public int NumberOfTrees {
     43      get { return Model.NumberOfTrees; } // forwarded from the underlying random forest model
    3844    }
    3945
Note: See TracChangeset for help on using the changeset viewer.