Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelFull.cs @ 17045

Last change on this file since 17045 was 17045, checked in by mkommend, 5 years ago

#2952: Intermediate commit of refactoring RF models that is not yet finished.

File size: 10.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest model that wraps an <c>alglib.decisionforest</c> and stores the
  /// forest's internal state (buffer size, number of classes/trees/variables and the
  /// flat tree array) through the storable properties declared below.
  /// The model can produce regression estimates, per-row variances and class values.
  /// </summary>
  [StorableType("9C797DF0-1169-4381-A732-6DAB90802839")]
  [Item("RandomForestModelFull", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModelFull : ClassificationModel, IRandomForestModel {

    /// <summary>
    /// Names of the input variables, in the column order used when building the
    /// input matrix for the forest.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return inputVariables; }
    }

    // Class values returned by GetEstimatedClassValues; indexed by the position of the
    // largest entry in the forest output. Empty when no class values were supplied.
    [Storable]
    private double[] classValues;

    // Input variable names; the variable number stored in a tree node indexes into this array.
    [Storable]
    private string[] inputVariables;

    /// <summary>Number of trees as reported by the wrapped alglib forest.</summary>
    public int NumberOfTrees {
      get { return RandomForestNTrees; }
    }

    // not persisted directly; its state is stored via the RandomForest* properties below
    private alglib.decisionforest randomForest;

    // The following storable properties read from / write to the inner alglib object,
    // so the forest itself never has to be serialized as an object.
    [Storable]
    private int RandomForestBufSize {
      get { return randomForest.innerobj.bufsize; }
      set { randomForest.innerobj.bufsize = value; }
    }
    [Storable]
    private int RandomForestNClasses {
      get { return randomForest.innerobj.nclasses; }
      set { randomForest.innerobj.nclasses = value; }
    }
    [Storable]
    private int RandomForestNTrees {
      get { return randomForest.innerobj.ntrees; }
      set { randomForest.innerobj.ntrees = value; }
    }
    [Storable]
    private int RandomForestNVars {
      get { return randomForest.innerobj.nvars; }
      set { randomForest.innerobj.nvars = value; }
    }
    [Storable]
    private double[] RandomForestTrees {
      get { return randomForest.innerobj.trees; }
      set { randomForest.innerobj.trees = value; }
    }

    /// <summary>
    /// Storable constructor. Creates an empty alglib forest so that the
    /// RandomForest* property setters have an inner object to write to
    /// (presumably invoked by the persistence framework before the setters are called).
    /// </summary>
    [StorableConstructor]
    private RandomForestModelFull(StorableConstructorFlag _) : base(_) {
      randomForest = new alglib.decisionforest();
    }

    /// <summary>
    /// Cloning constructor. Builds a new alglib forest and copies the scalar state
    /// and a copy of the tree array from <paramref name="original"/>.
    /// </summary>
    private RandomForestModelFull(RandomForestModelFull original, Cloner cloner) : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();

      // the following arrays are never modified by this class after construction,
      // so the clone shares them with the original instead of copying
      inputVariables = original.inputVariables;
      classValues = original.classValues;
    }

    /// <summary>Creates a clone of this model via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModelFull(this, cloner);
    }

    /// <summary>
    /// Creates a model around an existing alglib decision forest.
    /// </summary>
    /// <param name="decisionForest">The forest to wrap; the reference is stored as is (not copied).</param>
    /// <param name="targetVariable">Name of the target variable, passed to the base class.</param>
    /// <param name="inputVariables">Input variable names; copied into an array.</param>
    /// <param name="classValues">
    /// Class values, only used for classification models. May be omitted (<c>null</c>),
    /// in which case an empty array is stored.
    /// </param>
    public RandomForestModelFull(alglib.decisionforest decisionForest, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<double> classValues = null) : base(targetVariable) {
      randomForest = decisionForest;

      this.inputVariables = inputVariables.ToArray();

      // classValues are only used for classification models; the parameter defaults to null,
      // so it must not be dereferenced unconditionally
      this.classValues = classValues == null ? new double[0] : classValues.ToArray();
    }


    /// <summary>
    /// Creates a regression solution for this model using a new
    /// <c>RegressionProblemData</c> constructed from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }

    /// <summary>
    /// Creates a classification solution for this model using a new
    /// <c>ClassificationProblemData</c> constructed from <paramref name="problemData"/>.
    /// </summary>
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    /// <summary>
    /// Checks compatibility with regression problem data by delegating to
    /// <c>RegressionModel.IsProblemDataCompatible</c>.
    /// </summary>
    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Checks compatibility with arbitrary problem data by dispatching on its runtime type:
    /// regression problem data is tried first, then classification problem data.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="problemData"/> is neither regression nor classification problem data.
    /// </exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    /// <summary>
    /// Yields one estimated value per requested row: the first element of the output
    /// written by <c>alglib.dfprocess</c> for that row.
    /// </summary>
    /// <remarks>
    /// This is an iterator method, so nothing (including the input matrix check) runs
    /// until the result is enumerated.
    /// </remarks>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      // x and y are reused as scratch buffers for every row
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(randomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Yields one variance per requested row: <c>VariancePop()</c> of the buffer filled by
    /// <c>alglib.dforest.dfprocessraw</c>. The buffer has one slot per tree
    /// (presumably one raw prediction per tree; confirm against the alglib build in use).
    /// </summary>
    /// <remarks>Iterator method; evaluation is deferred until enumeration.</remarks>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[this.randomForest.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dforest.dfprocessraw(randomForest.innerobj, x, ref ys);
        yield return ys.VariancePop();
      }
    }

    /// <summary>
    /// Yields one class value per requested row: the entry of the stored class values
    /// whose index has the largest value in the forest output (first index wins on ties).
    /// </summary>
    /// <remarks>Iterator method; evaluation is deferred until enumeration.</remarks>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[randomForest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(randomForest, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Converts one tree of the forest into a symbolic expression tree made of
    /// variable-condition nodes (inner nodes) and constant nodes (leaves),
    /// wrapped in a start node and a program root node.
    /// </summary>
    /// <param name="treeIdx">
    /// Index of the tree to extract. The method skips <c>treeIdx - 1</c> trees, so the
    /// values 0 and 1 both select the first tree.
    /// </param>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      var rf = randomForest;
      // hoping that the internal representation of alglib is stable

      // TREE FORMAT
      // W[Offs]      -   size of sub-array (for the tree)
      //     node info:
      // W[K+0]       -   variable number        (-1 for leaf mode)
      // W[K+1]       -   threshold              (class/value for leaf node)
      // W[K+2]       -   ">=" branch index      (absent for leaf node)

      // skip irrelevant trees: the first element of each tree's sub-array is its size
      // NOTE(review): the bound treeIdx - 1 looks like a 1-based index; if callers pass
      // 0-based indices this is an off-by-one -- confirm against callers.
      int offset = 0;
      for (int i = 0; i < treeIdx - 1; i++) {
        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
      }

      var constSy = new Constant();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };

      // the root node of a tree is stored directly after the size entry
      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, constSy, varCondSy);

      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree(node);
      var root = new ProgramRootSymbol().CreateTreeNode();
      root.AddSubtree(startNode);
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Recursively converts the node stored at position <paramref name="k"/> of the flat
    /// tree array into a symbolic expression tree node.
    /// </summary>
    /// <param name="trees">Flat tree array of the alglib forest.</param>
    /// <param name="offset">Start position of the current tree within <paramref name="trees"/>.</param>
    /// <param name="k">Position of the node to convert.</param>
    /// <param name="constSy">Symbol used to create leaf (constant) nodes.</param>
    /// <param name="varCondSy">Symbol used to create inner (variable condition) nodes.</param>
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {

      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      if ((double)(trees[k]) == (double)(-1)) {
        // leaf node: trees[k + 1] holds the value (or class)
        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
        constNode.Value = trees[k + 1];
        return constNode;
      } else {
        // inner node: trees[k] is the variable number, trees[k + 1] the threshold
        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        condNode.VariableName = inputVariables[(int)Math.Round(trees[k])];
        condNode.Threshold = trees[k + 1];
        condNode.Slope = double.PositiveInfinity;

        // the "<" branch follows directly after this node (node width 3),
        // the ">=" branch position is stored relative to the tree offset
        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);

        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
        condNode.AddSubtree(right);
        return condNode;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.