Changeset 14345


Ignore:
Timestamp:
10/21/16 17:59:36 (12 months ago)
Author:
gkronber
Message:

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

Location:
trunk/sources
Files:
10 added
17 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/HeuristicLab.Algorithms.DataAnalysis.Views-3.4.csproj

    r14125 r14345  
    125125  </ItemGroup>
    126126  <ItemGroup>
     127    <Compile Include="RandomForestClassificationSolutionView.cs">
     128      <SubType>UserControl</SubType>
     129    </Compile>
     130    <Compile Include="RandomForestClassificationSolutionView.Designer.cs">
     131      <DependentUpon>RandomForestClassificationSolutionView.cs</DependentUpon>
     132    </Compile>
     133    <Compile Include="RandomForestModelView.cs">
     134      <SubType>UserControl</SubType>
     135    </Compile>
     136    <Compile Include="RandomForestModelView.Designer.cs">
     137      <DependentUpon>RandomForestModelView.cs</DependentUpon>
     138    </Compile>
     139    <Compile Include="RandomForestRegressionSolutionView.cs">
     140      <SubType>UserControl</SubType>
     141    </Compile>
     142    <Compile Include="RandomForestRegressionSolutionView.Designer.cs">
     143      <DependentUpon>RandomForestRegressionSolutionView.cs</DependentUpon>
     144    </Compile>
     145    <Compile Include="GradientBoostedTreesModelView.cs">
     146      <SubType>UserControl</SubType>
     147    </Compile>
     148    <Compile Include="GradientBoostedTreesModelView.Designer.cs">
     149      <DependentUpon>GradientBoostedTreesModelView.cs</DependentUpon>
     150    </Compile>
    127151    <Compile Include="MeanProdView.cs">
    128152      <SubType>UserControl</SubType>
     
    192216    <Compile Include="SupportVectorMachineModelView.Designer.cs">
    193217      <DependentUpon>SupportVectorMachineModelView.cs</DependentUpon>
     218    </Compile>
     219    <Compile Include="GradientBoostedTreesSolutionView.cs">
     220      <SubType>UserControl</SubType>
     221    </Compile>
     222    <Compile Include="GradientBoostedTreesSolutionView.Designer.cs">
     223      <DependentUpon>GradientBoostedTreesSolutionView.cs</DependentUpon>
    194224    </Compile>
    195225  </ItemGroup>
     
    245275      <Private>False</Private>
    246276    </ProjectReference>
     277    <ProjectReference Include="..\..\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding\3.4\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4.csproj">
     278      <Project>{06D4A186-9319-48A0-BADE-A2058D462EEA}</Project>
     279      <Name>HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4</Name>
     280    </ProjectReference>
    247281    <ProjectReference Include="..\..\HeuristicLab.MainForm.WindowsForms\3.3\HeuristicLab.MainForm.WindowsForms-3.3.csproj">
    248282      <Project>{AB687BBE-1BFE-476B-906D-44237135431D}</Project>
     
    269303      <Name>HeuristicLab.PluginInfrastructure-3.3</Name>
    270304      <Private>False</Private>
     305    </ProjectReference>
     306    <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic.Classification\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic.Classification-3.4.csproj">
     307      <Project>{05BAE4E1-A9FA-4644-AA77-42558720159E}</Project>
     308      <Name>HeuristicLab.Problems.DataAnalysis.Symbolic.Classification-3.4</Name>
     309    </ProjectReference>
     310    <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic.Regression\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.csproj">
     311      <Project>{5AC82412-911B-4FA2-A013-EDC5E3F3FCC2}</Project>
     312      <Name>HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4</Name>
     313    </ProjectReference>
     314    <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic.Views\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic.Views-3.4.csproj">
     315      <Project>{7a2531ce-3f7c-4f13-bcca-ed6dc27a7086}</Project>
     316      <Name>HeuristicLab.Problems.DataAnalysis.Symbolic.Views-3.4</Name>
     317    </ProjectReference>
     318    <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Symbolic\3.4\HeuristicLab.Problems.DataAnalysis.Symbolic-3.4.csproj">
     319      <Project>{3d28463f-ec96-4d82-afee-38be91a0ca00}</Project>
     320      <Name>HeuristicLab.Problems.DataAnalysis.Symbolic-3.4</Name>
    271321    </ProjectReference>
    272322    <ProjectReference Include="..\..\HeuristicLab.Problems.DataAnalysis.Views\3.4\HeuristicLab.Problems.DataAnalysis.Views-3.4.csproj">
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/KMeansClusteringModelView.cs

    r14185 r14345  
    1919 */
    2020#endregion
    21 using System;
    2221using System.Linq;
    23 using System.IO;
    24 using System.Windows.Forms;
    2522using HeuristicLab.MainForm;
    2623using HeuristicLab.MainForm.WindowsForms;
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/Plugin.cs.frame

    r14195 r14345  
    3737  [PluginDependency("HeuristicLab.Data", "3.3")]
    3838  [PluginDependency("HeuristicLab.Data.Views", "3.3")]
     39  [PluginDependency("HeuristicLab.Encodings.SymbolicExpressionTreeEncoding", "3.4")]
    3940  [PluginDependency("HeuristicLab.LibSVM", "3.12")]
    4041  [PluginDependency("HeuristicLab.MainForm", "3.3")]
     
    4445  [PluginDependency("HeuristicLab.Problems.DataAnalysis", "3.4")]
    4546  [PluginDependency("HeuristicLab.Problems.DataAnalysis.Views", "3.4")]
     47  [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic", "3.4")]
     48  [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic.Views", "3.4")]
     49  [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic.Regression", "3.4")]
     50  [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic.Classification", "3.4")]
    4651  public class HeuristicLabAlgorithmsDataAnalysisViewsPlugin : PluginBase {
    4752  }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r14185 r14345  
    269269        } else {
    270270          // otherwise we produce a regression solution
    271           Results.Add(new Result("Solution", new RegressionSolution(model, problemData)));
     271          Results.Add(new Result("Solution", new GradientBoostedTreesSolution(model, problemData)));
    272272        }
    273273      }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesSolution.cs

    r14185 r14345  
    2020#endregion
    2121
    22 using System.Collections.Generic;
    23 using System.Linq;
    2422using HeuristicLab.Common;
    2523using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r14185 r14345  
    2828using HeuristicLab.Common;
    2929using HeuristicLab.Core;
     30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    3031using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3132using HeuristicLab.Problems.DataAnalysis;
     33using HeuristicLab.Problems.DataAnalysis.Symbolic;
     34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    3235
    3336namespace HeuristicLab.Algorithms.DataAnalysis {
     
    210213    }
    211214
     215    /// <summary>
     216    /// Transforms the tree model to a symbolic regression solution
     217    /// </summary>
     218    /// <param name="problemData"></param>
     219    /// <returns>A new symbolic regression solution which matches the tree model</returns>
     220    public ISymbolicRegressionSolution CreateSymbolicRegressionSolution(IRegressionProblemData problemData) {
     221      var rootSy = new ProgramRootSymbol();
     222      var startSy = new StartSymbol();
     223      var varCondSy = new VariableCondition() { IgnoreSlope = true };
     224      var constSy = new Constant();
     225
     226      var startNode = startSy.CreateTreeNode();
     227      startNode.AddSubtree(CreateSymbolicRegressionTreeRecursive(tree, 0, varCondSy, constSy));
     228      var rootNode = rootSy.CreateTreeNode();
     229      rootNode.AddSubtree(startNode);
     230      var model = new SymbolicRegressionModel(TargetVariable, new SymbolicExpressionTree(rootNode), new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
     231      return model.CreateRegressionSolution(problemData);
     232    }
     233
     234    private ISymbolicExpressionTreeNode CreateSymbolicRegressionTreeRecursive(TreeNode[] treeNodes, int nodeIdx, VariableCondition varCondSy, Constant constSy) {
     235      var curNode = treeNodes[nodeIdx];
     236      if (curNode.VarName == TreeNode.NO_VARIABLE) {
     237        var node = (ConstantTreeNode)constSy.CreateTreeNode();
     238        node.Value = curNode.Val;
     239        return node;
     240      } else {
     241        var node = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
     242        node.VariableName = curNode.VarName;
     243        node.Threshold = curNode.Val;
     244
     245        var left = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.LeftIdx, varCondSy, constSy);
     246        var right = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.RightIdx, varCondSy, constSy);
     247        node.AddSubtree(left);
     248        node.AddSubtree(right);
     249        return node;
     250      }
     251    }
     252
     253
    212254    private string TreeToString(int idx, string part) {
    213255      var n = tree[idx];
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/IRandomForestClassificationSolution.cs

    r14185 r14345  
    3030  public interface IRandomForestClassificationSolution : IClassificationSolution {
    3131    new IRandomForestModel Model { get; }
     32    int NumberOfTrees { get; }
    3233  }
    3334}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/IRandomForestModel.cs

    r14185 r14345  
    2020#endregion
    2121
    22 using HeuristicLab.Optimization;
     22using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2323using HeuristicLab.Problems.DataAnalysis;
    24 using HeuristicLab.Core;
    25 using System.Collections.Generic;
     24
    2625
    2726namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3029  /// </summary>
    3130  public interface IRandomForestModel : IConfidenceRegressionModel, IClassificationModel {
     31    int NumberOfTrees { get; }
     32    ISymbolicExpressionTree ExtractTree(int treeIdx); // returns a specific tree from the random forest as a ISymbolicRegressionModel
    3233  }
    3334}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/IRandomForestRegressionSolution.cs

    r14185 r14345  
    2020#endregion
    2121
    22 using HeuristicLab.Optimization;
    2322using HeuristicLab.Problems.DataAnalysis;
    24 using HeuristicLab.Core;
     23using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    2524
    2625namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3029  public interface IRandomForestRegressionSolution : IConfidenceRegressionSolution {
    3130    new IRandomForestModel Model { get; }
     31    int NumberOfTrees { get; }
    3232  }
    3333}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs

    r14185 r14345  
    3838    }
    3939
     40    public int NumberOfTrees {
     41      get { return Model.NumberOfTrees; }
     42    }
     43
    4044    [StorableConstructor]
    4145    private RandomForestClassificationSolution(bool deserializing) : base(deserializing) { }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r14230 r14345  
    2525using HeuristicLab.Common;
    2626using HeuristicLab.Core;
     27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2728using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2829using HeuristicLab.Problems.DataAnalysis;
     30using HeuristicLab.Problems.DataAnalysis.Symbolic;
    2931
    3032namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4951    }
    5052
     53    public int NumberOfTrees {
     54      get { return nTrees; }
     55    }
    5156
    5257    // instead of storing the data of the model itself
     
    6469    [Storable]
    6570    private double m;
    66 
    6771
    6872    [StorableConstructor]
     
    197201    }
    198202
     203    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
     204      // hoping that the internal representation of alglib is stable
     205
     206      // TREE FORMAT
     207      // W[Offs]      -   size of sub-array (for the tree)
     208      //     node info:
     209      // W[K+0]       -   variable number        (-1 for leaf mode)
     210      // W[K+1]       -   threshold              (class/value for leaf node)
     211      // W[K+2]       -   ">=" branch index      (absent for leaf node)
     212
     213      // skip irrelevant trees
     214      int offset = 0;
     215      for (int i = 0; i < treeIdx - 1; i++) {
     216        offset = offset + (int)Math.Round(randomForest.innerobj.trees[offset]);
     217      }
     218
     219      var constSy = new Constant();
     220      var varCondSy = new VariableCondition() { IgnoreSlope = true };
     221
     222      var node = CreateRegressionTreeRec(randomForest.innerobj.trees, offset, offset + 1, constSy, varCondSy);
     223
     224      var startNode = new StartSymbol().CreateTreeNode();
     225      startNode.AddSubtree(node);
     226      var root = new ProgramRootSymbol().CreateTreeNode();
     227      root.AddSubtree(startNode);
     228      return new SymbolicExpressionTree(root);
     229    }
     230
     231    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {
     232
     233      // alglib source for evaluation of one tree (dfprocessinternal)
     234      // offs = 0
     235      //
     236      // Set pointer to the root
     237      //
     238      // k = offs + 1;
     239      //
     240      // //
     241      // // Navigate through the tree
     242      // //
     243      // while (true) {
     244      //   if ((double)(df.trees[k]) == (double)(-1)) {
     245      //     if (df.nclasses == 1) {
     246      //       y[0] = y[0] + df.trees[k + 1];
     247      //     } else {
     248      //       idx = (int)Math.Round(df.trees[k + 1]);
     249      //       y[idx] = y[idx] + 1;
     250      //     }
     251      //     break;
     252      //   }
     253      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
     254      //     k = k + innernodewidth;
     255      //   } else {
     256      //     k = offs + (int)Math.Round(df.trees[k + 2]);
     257      //   }
     258      // }
     259
     260      if ((double)(trees[k]) == (double)(-1)) {
     261        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
     262        constNode.Value = trees[k + 1];
     263        return constNode;
     264      } else {
     265        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
     266        condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])];
     267        condNode.Threshold = trees[k + 1];
     268        condNode.Slope = double.PositiveInfinity;
     269
     270        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
     271        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);
     272
     273        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
     274        condNode.AddSubtree(right);
     275        return condNode;
     276      }
     277    }
     278
    199279
    200280    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs

    r14185 r14345  
    2020#endregion
    2121
     22using System;
    2223using HeuristicLab.Common;
    2324using HeuristicLab.Core;
    2425using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2526using HeuristicLab.Problems.DataAnalysis;
     27using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    2628
    2729namespace HeuristicLab.Algorithms.DataAnalysis {
     
    3638      get { return (IRandomForestModel)base.Model; }
    3739      set { base.Model = value; }
     40    }
     41
     42    public int NumberOfTrees {
     43      get { return Model.NumberOfTrees; }
    3844    }
    3945
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionCompiledTreeInterpreter.cs

    r14185 r14345  
    614614        case OpCodes.VariableCondition: {
    615615            var variableConditionTreeNode = (VariableConditionTreeNode)node;
     616            if (variableConditionTreeNode.Symbol.IgnoreSlope) throw new NotSupportedException("Strict variable conditionals are not supported");
    616617            var variableName = variableConditionTreeNode.VariableName;
    617618            var indexExpr = Expression.Constant(variableIndices[variableName]);
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeInterpreter.cs

    r14185 r14345  
    471471            if (row < 0 || row >= dataset.Rows) return double.NaN;
    472472            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
    473             double variableValue = ((IList<double>)currentInstr.data)[row];
    474             double x = variableValue - variableConditionTreeNode.Threshold;
    475             double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
    476 
    477             double trueBranch = Evaluate(dataset, ref row, state);
    478             double falseBranch = Evaluate(dataset, ref row, state);
    479 
    480             return trueBranch * p + falseBranch * (1 - p);
     473            if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
     474              double variableValue = ((IList<double>)currentInstr.data)[row];
     475              double x = variableValue - variableConditionTreeNode.Threshold;
     476              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
     477
     478              double trueBranch = Evaluate(dataset, ref row, state);
     479              double falseBranch = Evaluate(dataset, ref row, state);
     480
     481              return trueBranch * p + falseBranch * (1 - p);
     482            } else {
     483              // strict threshold
     484              double variableValue = ((IList<double>)currentInstr.data)[row];
     485              if (variableValue <= variableConditionTreeNode.Threshold) {
     486                var left = Evaluate(dataset, ref row, state);
     487                state.SkipInstructions();
     488                return left;
     489              } else {
     490                state.SkipInstructions();
     491                return Evaluate(dataset, ref row, state);
     492              }
     493            }
    481494          }
    482495        default:
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs

    r14282 r14345  
    126126    private readonly object syncRoot = new object();
    127127    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
     128      if (!rows.Any()) return Enumerable.Empty<double>();
    128129      if (CheckExpressionsWithIntervalArithmetic)
    129130        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
     
    159160          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
    160161          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
    161           double variableValue = ((IList<double>)instr.data)[row];
    162           double x = variableValue - variableConditionTreeNode.Threshold;
    163           double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
    164 
    165           double trueBranch = code[instr.childIndex].value;
    166           double falseBranch = code[instr.childIndex + 1].value;
    167 
    168           instr.value = trueBranch * p + falseBranch * (1 - p);
     162          if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
     163            double variableValue = ((IList<double>)instr.data)[row];
     164            double x = variableValue - variableConditionTreeNode.Threshold;
     165            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
     166
     167            double trueBranch = code[instr.childIndex].value;
     168            double falseBranch = code[instr.childIndex + 1].value;
     169
     170            instr.value = trueBranch * p + falseBranch * (1 - p);
     171          } else {
     172            double variableValue = ((IList<double>)instr.data)[row];
     173            if (variableValue <= variableConditionTreeNode.Threshold) {
     174              instr.value = code[instr.childIndex].value;
     175            } else {
     176              instr.value = code[instr.childIndex + 1].value;
     177            }
     178          }
    169179        } else if (instr.opCode == OpCodes.Add) {
    170180          double s = code[instr.childIndex].value;
     
    411421              for (int j = 1; j != seq.Length; ++j)
    412422                seq[j].skip = true;
    413             }
    414             break;
     423              break;
     424            }
    415425        }
    416426        #endregion
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Symbols/VariableCondition.cs

    r14185 r14345  
    149149        }
    150150      }
     151    }
     152
     153    /// <summary>
     154    /// Flag to indicate if the interpreter should ignore the slope parameter (introduced for representation of expression trees)
     155    /// </summary>
     156    [Storable]
     157    private bool ignoreSlope;
     158    public bool IgnoreSlope {
     159      get { return ignoreSlope; }
     160      set { ignoreSlope = value; }
    151161    }
    152162
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Symbols/VariableConditionTreeNode.cs

    r14185 r14345  
    9696
    9797    public override string ToString() {
    98       if (slope.IsAlmost(0.0))
     98      if (slope.IsAlmost(0.0) || Symbol.IgnoreSlope) {
     99        return variableName + " < " + threshold.ToString("E4");
     100      } else {
    99101        return variableName + " > " + threshold.ToString("E4") + Environment.NewLine +
    100           "slope: " + slope.ToString("E4");
    101       else
    102         return variableName + " > " + threshold.ToString("E4");
     102               "slope: " + slope.ToString("E4");
     103      }
    103104    }
    104105  }
Note: See TracChangeset for help on using the changeset viewer.