Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelFull.cs @ 18190

Last change on this file since 18190 was 18132, checked in by gkronber, 3 years ago

#3140: merged r18091:18131 from branch to trunk

File size: 10.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest model for regression and classification that wraps a complete alglib decision forest.
  /// The forest itself is persisted through its alglib string serialization (see <c>RandomForestSerialized</c>).
  /// </summary>
  [StorableType("55412E08-DAD4-4C2E-9181-C142E7EA9474")]
  [Item("RandomForestModelFull", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModelFull : ClassificationModel, IRandomForestModel {

    /// <summary>The input variables of the model, in the column order used to build the forest inputs.</summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return inputVariables; }
    }

    // Class values in class-index order; an empty array for regression models (see constructor).
    [Storable]
    private double[] classValues;

    /// <summary>Number of class values stored in this model (0 when no class values are present).</summary>
    public int NumClasses => classValues == null ? 0 : classValues.Length;

    // Input variable names; their order defines the columns passed to alglib.
    [Storable]
    private string[] inputVariables;

    /// <summary>Number of trees as passed to the constructor.</summary>
    [Storable]
    public int NumberOfTrees {
      get; private set;
    }

    // not persisted directly; persisted through RandomForestSerialized
    private alglib.decisionforest randomForest;

    /// <summary>
    /// Persistence surrogate for <see cref="randomForest"/>. Returns null when no forest is present,
    /// which the setter (and the cloning constructor) already treat as "no forest".
    /// </summary>
    [Storable]
    private string RandomForestSerialized {
      get {
        // Without this guard dfserialize would be called with a null forest when persisting such a model.
        if (randomForest == null) return null;
        alglib.dfserialize(randomForest, out var serialized);
        return serialized;
      }
      set { if (value != null) alglib.dfunserialize(value, out randomForest); }
    }

    [StorableConstructor]
    private RandomForestModelFull(StorableConstructorFlag _) : base(_) { }

    private RandomForestModelFull(RandomForestModelFull original, Cloner cloner) : base(original, cloner) {
      if (original.randomForest != null)
        randomForest = (alglib.decisionforest)original.randomForest.make_copy();
      NumberOfTrees = original.NumberOfTrees;

      // following fields are immutable so we don't need to clone them
      inputVariables = original.inputVariables;
      classValues = original.classValues;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModelFull(this, cloner);
    }

    /// <summary>
    /// Creates a model from an already trained alglib forest (a copy of the forest is stored).
    /// </summary>
    /// <param name="decisionForest">The trained forest; must not be null.</param>
    /// <param name="nTrees">Number of trees, stored as <see cref="NumberOfTrees"/>.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="inputVariables">Input variable names in the column order used for training.</param>
    /// <param name="classValues">Class values for classification models; null for regression models.</param>
    /// <exception cref="ArgumentNullException">When <paramref name="decisionForest"/> or <paramref name="inputVariables"/> is null.</exception>
    public RandomForestModelFull(alglib.decisionforest decisionForest, int nTrees, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<double> classValues = null) : base(targetVariable) {
      if (decisionForest == null) throw new ArgumentNullException(nameof(decisionForest));
      if (inputVariables == null) throw new ArgumentNullException(nameof(inputVariables));

      this.name = ItemName;
      this.description = ItemDescription;

      randomForest = (alglib.decisionforest)decisionForest.make_copy();
      NumberOfTrees = nTrees;

      this.inputVariables = inputVariables.ToArray();

      // classValues are only used for classification models
      this.classValues = classValues == null ? new double[0] : classValues.ToArray();
    }

    /// <summary>Creates a regression solution on a copy of <paramref name="problemData"/>.</summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }

    /// <summary>Creates a classification solution on a copy of <paramref name="problemData"/>.</summary>
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    /// <summary>Checks compatibility with regression problem data (delegates to <c>RegressionModel.IsProblemDataCompatible</c>).</summary>
    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Dispatches the compatibility check to the regression or classification overload depending on the
    /// runtime type of <paramref name="problemData"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">When <paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException">When the problem data is neither regression nor classification data.</exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException(nameof(problemData), "The provided problemData is null.");

      if (problemData is IRegressionProblemData regressionProblemData)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      if (problemData is IClassificationProblemData classificationProblemData)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", nameof(problemData));
    }

    /// <summary>
    /// Lazily yields the forest output (first output component) for each requested row.
    /// Input extraction and validation happen when enumeration starts.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      // each enumeration gets its own buffer, so dftsprocess needs no lock here
      alglib.dfcreatebuffer(randomForest, out var buf);
      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dftsprocess(randomForest, buf, x, ref y); // thread-safe process (as long as separate buffers are used)
        yield return y[0];
      }
    }

    /// <summary>
    /// Lazily yields, for each requested row, the population variance of the raw per-tree outputs.
    /// </summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[this.randomForest.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // dfprocessraw is called without a per-caller buffer (unlike dftsprocess above); presumably it uses
        // state shared with the forest object, hence the lock on the forest.
        lock (randomForest)
          alglib.dforest.dfprocessraw(randomForest.innerobj, x, ref ys, null);
        yield return ys.VariancePop();
      }
    }

    /// <summary>
    /// Lazily yields, for each requested row, the class value whose forest output (probability) is largest.
    /// Ties are resolved in favor of the lowest class index.
    /// </summary>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[NumClasses];

      alglib.dfcreatebuffer(randomForest, out var buf);
      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dftsprocess(randomForest, buf, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Converts one tree of the forest into a symbolic expression tree built from variable-condition
    /// and number nodes.
    /// </summary>
    /// <param name="treeIdx">
    /// Index of the tree. The skip loop advances <c>treeIdx - 1</c> trees, so 0 and 1 both select the first
    /// tree and <c>ntrees</c> selects the last one.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException">When <paramref name="treeIdx"/> exceeds the number of trees.</exception>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      var rf = randomForest;
      // hoping that the internal representation of alglib is stable

      // Indices above ntrees previously walked past the end of the tree array; fail with a clear message instead.
      // (Negative indices keep their previous behavior of selecting the first tree.)
      int nTrees = rf.innerobj.ntrees;
      if (treeIdx > nTrees)
        throw new ArgumentOutOfRangeException(nameof(treeIdx), treeIdx, "The tree index must not exceed the number of trees in the forest.");

      // TREE FORMAT
      // W[Offs]      -   size of sub-array (for the tree)
      //     node info:
      // W[K+0]       -   variable number        (-1 for leaf node)
      // W[K+1]       -   threshold              (class/value for leaf node)
      // W[K+2]       -   ">=" branch index      (absent for leaf node)

      // skip irrelevant trees
      // NOTE(review): skipping treeIdx - 1 trees means 0-based callers get tree (treeIdx - 1) for treeIdx >= 1;
      // confirm the indexing convention of the callers before changing this.
      int offset = 0;
      for (int i = 0; i < treeIdx - 1; i++) {
        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
      }

      var numSy = new Number();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };

      // the root node of a tree starts right after the size entry at W[offset]
      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, numSy, varCondSy);

      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree(node);
      var root = new ProgramRootSymbol().CreateTreeNode();
      root.AddSubtree(startNode);
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Recursively converts the alglib tree node at position <paramref name="k"/> (within the tree starting at
    /// <paramref name="offset"/>) into a symbolic expression tree node.
    /// Leaf nodes become number nodes holding <c>trees[k + 1]</c> (the regression value, or the class index
    /// for classification forests); inner nodes become variable-condition nodes with an infinite slope.
    /// </summary>
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Number numSy, VariableCondition varCondSy) {

      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      // -1 is alglib's exact sentinel for a leaf, so exact comparison is intended
      if (trees[k] == -1) {
        var numNode = (NumberTreeNode)numSy.CreateTreeNode();
        numNode.Value = trees[k + 1];
        return numNode;
      } else {
        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        condNode.VariableName = inputVariables[(int)Math.Round(trees[k])];
        condNode.Threshold = trees[k + 1];
        condNode.Slope = double.PositiveInfinity;

        // "<" branch directly follows the inner node (inner node width 3); ">=" branch is stored relative to offset
        var left = CreateRegressionTreeRec(trees, offset, k + 3, numSy, varCondSy);
        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), numSy, varCondSy);

        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
        condNode.AddSubtree(right);
        return condNode;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.