source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelFull.cs @ 17209

Last change on this file since 17209 was 17209, checked in by gkronber, 3 months ago

#2994: merged r17132:17198 from trunk to branch

File size: 10.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
32namespace HeuristicLab.Algorithms.DataAnalysis {
33  [StorableType("9C797DF0-1169-4381-A732-6DAB90802839")]
34  [Item("RandomForestModelFull", "Represents a random forest for regression and classification.")]
35  public sealed class RandomForestModelFull : ClassificationModel, IRandomForestModel {
36
37    public override IEnumerable<string> VariablesUsedForPrediction {
38      get { return inputVariables; }
39    }
40
41    [Storable]
42    private double[] classValues;
43
44    [Storable]
45    private string[] inputVariables;
46
47    public int NumberOfTrees {
48      get { return RandomForestNTrees; }
49    }
50
51    // not persisted
52    private alglib.decisionforest randomForest;
53
54    [Storable]
55    private int RandomForestBufSize {
56      get { return randomForest.innerobj.bufsize; }
57      set { randomForest.innerobj.bufsize = value; }
58    }
59    [Storable]
60    private int RandomForestNClasses {
61      get { return randomForest.innerobj.nclasses; }
62      set { randomForest.innerobj.nclasses = value; }
63    }
64    [Storable]
65    private int RandomForestNTrees {
66      get { return randomForest.innerobj.ntrees; }
67      set { randomForest.innerobj.ntrees = value; }
68    }
69    [Storable]
70    private int RandomForestNVars {
71      get { return randomForest.innerobj.nvars; }
72      set { randomForest.innerobj.nvars = value; }
73    }
74    [Storable]
75    private double[] RandomForestTrees {
76      get { return randomForest.innerobj.trees; }
77      set { randomForest.innerobj.trees = value; }
78    }
79
80    [StorableConstructor]
81    private RandomForestModelFull(StorableConstructorFlag _) : base(_) {
82      randomForest = new alglib.decisionforest();
83    }
84
85    private RandomForestModelFull(RandomForestModelFull original, Cloner cloner) : base(original, cloner) {
86      randomForest = new alglib.decisionforest();
87      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
88      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
89      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
90      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
91      randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();
92
93      // following fields are immutable so we don't need to clone them
94      inputVariables = original.inputVariables;
95      classValues = original.classValues;
96    }
97    public override IDeepCloneable Clone(Cloner cloner) {
98      return new RandomForestModelFull(this, cloner);
99    }
100
101    public RandomForestModelFull(alglib.decisionforest decisionForest, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<double> classValues = null) : base(targetVariable) {
102      this.name = ItemName;
103      this.description = ItemDescription;
104
105      randomForest = decisionForest;
106
107      this.inputVariables = inputVariables.ToArray();
108
109      //classValues are only use for classification models
110      if (classValues == null) this.classValues = new double[0];
111      else this.classValues = classValues.ToArray();
112    }
113
114
115    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
116      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
117    }
118    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
119      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
120    }
121
122    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
123      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
124    }
125
126    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
127      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");
128
129      var regressionProblemData = problemData as IRegressionProblemData;
130      if (regressionProblemData != null)
131        return IsProblemDataCompatible(regressionProblemData, out errorMessage);
132
133      var classificationProblemData = problemData as IClassificationProblemData;
134      if (classificationProblemData != null)
135        return IsProblemDataCompatible(classificationProblemData, out errorMessage);
136
137      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
138    }
139
140    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
141      double[,] inputData = dataset.ToArray(inputVariables, rows);
142      RandomForestUtil.AssertInputMatrix(inputData);
143
144      int n = inputData.GetLength(0);
145      int columns = inputData.GetLength(1);
146      double[] x = new double[columns];
147      double[] y = new double[1];
148
149      for (int row = 0; row < n; row++) {
150        for (int column = 0; column < columns; column++) {
151          x[column] = inputData[row, column];
152        }
153        alglib.dfprocess(randomForest, x, ref y);
154        yield return y[0];
155      }
156    }
157    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
158      double[,] inputData = dataset.ToArray(inputVariables, rows);
159      RandomForestUtil.AssertInputMatrix(inputData);
160
161      int n = inputData.GetLength(0);
162      int columns = inputData.GetLength(1);
163      double[] x = new double[columns];
164      double[] ys = new double[this.randomForest.innerobj.ntrees];
165
166      for (int row = 0; row < n; row++) {
167        for (int column = 0; column < columns; column++) {
168          x[column] = inputData[row, column];
169        }
170        alglib.dforest.dfprocessraw(randomForest.innerobj, x, ref ys);
171        yield return ys.VariancePop();
172      }
173    }
174
175    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
176      double[,] inputData = dataset.ToArray(inputVariables, rows);
177      RandomForestUtil.AssertInputMatrix(inputData);
178
179      int n = inputData.GetLength(0);
180      int columns = inputData.GetLength(1);
181      double[] x = new double[columns];
182      double[] y = new double[randomForest.innerobj.nclasses];
183
184      for (int row = 0; row < n; row++) {
185        for (int column = 0; column < columns; column++) {
186          x[column] = inputData[row, column];
187        }
188        alglib.dfprocess(randomForest, x, ref y);
189        // find class for with the largest probability value
190        int maxProbClassIndex = 0;
191        double maxProb = y[0];
192        for (int i = 1; i < y.Length; i++) {
193          if (maxProb < y[i]) {
194            maxProb = y[i];
195            maxProbClassIndex = i;
196          }
197        }
198        yield return classValues[maxProbClassIndex];
199      }
200    }
201
202    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
203      var rf = randomForest;
204      // hoping that the internal representation of alglib is stable
205
206      // TREE FORMAT
207      // W[Offs]      -   size of sub-array (for the tree)
208      //     node info:
209      // W[K+0]       -   variable number        (-1 for leaf mode)
210      // W[K+1]       -   threshold              (class/value for leaf node)
211      // W[K+2]       -   ">=" branch index      (absent for leaf node)
212
213      // skip irrelevant trees
214      int offset = 0;
215      for (int i = 0; i < treeIdx - 1; i++) {
216        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
217      }
218
219      var constSy = new Constant();
220      var varCondSy = new VariableCondition() { IgnoreSlope = true };
221
222      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, constSy, varCondSy);
223
224      var startNode = new StartSymbol().CreateTreeNode();
225      startNode.AddSubtree(node);
226      var root = new ProgramRootSymbol().CreateTreeNode();
227      root.AddSubtree(startNode);
228      return new SymbolicExpressionTree(root);
229    }
230
231    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {
232
233      // alglib source for evaluation of one tree (dfprocessinternal)
234      // offs = 0
235      //
236      // Set pointer to the root
237      //
238      // k = offs + 1;
239      //
240      // //
241      // // Navigate through the tree
242      // //
243      // while (true) {
244      //   if ((double)(df.trees[k]) == (double)(-1)) {
245      //     if (df.nclasses == 1) {
246      //       y[0] = y[0] + df.trees[k + 1];
247      //     } else {
248      //       idx = (int)Math.Round(df.trees[k + 1]);
249      //       y[idx] = y[idx] + 1;
250      //     }
251      //     break;
252      //   }
253      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
254      //     k = k + innernodewidth;
255      //   } else {
256      //     k = offs + (int)Math.Round(df.trees[k + 2]);
257      //   }
258      // }
259
260      if ((double)(trees[k]) == (double)(-1)) {
261        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
262        constNode.Value = trees[k + 1];
263        return constNode;
264      } else {
265        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
266        condNode.VariableName = inputVariables[(int)Math.Round(trees[k])];
267        condNode.Threshold = trees[k + 1];
268        condNode.Slope = double.PositiveInfinity;
269
270        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
271        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);
272
273        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
274        condNode.AddSubtree(right);
275        return condNode;
276      }
277    }
278  }
279}
Note: See TracBrowser for help on using the repository browser.