Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelAlglib_3_7.cs @ 18190

Last change on this file since 18190 was 18132, checked in by gkronber, 3 years ago

#3140: merged r18091:18131 from branch to trunk

File size: 11.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22extern alias alglib_3_7;
23using System;
24using System.Collections.Generic;
25using System.Linq;
26using HEAL.Attic;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32
namespace HeuristicLab.Algorithms.DataAnalysis {

  /// <summary>
  /// Random forest for regression and classification backed by the legacy alglib 3.7
  /// <c>decisionforest</c> implementation. Retained only so that previously persisted
  /// runs can still be loaded; use <c>RandomForestModelFull</c> for new models.
  /// </summary>
  [StorableType("9C797DF0-1169-4381-A732-6DAB90802839")]
  [Item("RandomForestModel (alglib 3.7)", "Represents a random forest for regression and classification.")]
  [Obsolete("This version uses alglib version 3.7. Use RandomForestModelFull instead.")]
  public sealed class RandomForestModelAlglib_3_7 : ClassificationModel, IRandomForestModel {

    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return inputVariables; }
    }

    // class values are only used for classification; empty for regression models
    [Storable]
    private double[] classValues;

    // input variable names in the column order expected by the forest
    [Storable]
    private string[] inputVariables;

    public int NumberOfTrees {
      get { return RandomForestNTrees; }
    }

    // not persisted directly: HEAL.Attic cannot serialize the alglib type itself;
    // the forest is reconstructed from the storable wrapper properties below
    private alglib_3_7::alglib.decisionforest randomForest;

    // The following properties expose the alglib internals for persistence.
    // Their names are part of the storable format and must not be renamed.
    [Storable]
    private int RandomForestBufSize {
      get { return randomForest.innerobj.bufsize; }
      set { randomForest.innerobj.bufsize = value; }
    }
    [Storable]
    private int RandomForestNClasses {
      get { return randomForest.innerobj.nclasses; }
      set { randomForest.innerobj.nclasses = value; }
    }
    [Storable]
    private int RandomForestNTrees {
      get { return randomForest.innerobj.ntrees; }
      set { randomForest.innerobj.ntrees = value; }
    }
    [Storable]
    private int RandomForestNVars {
      get { return randomForest.innerobj.nvars; }
      set { randomForest.innerobj.nvars = value; }
    }
    [Storable]
    private double[] RandomForestTrees {
      get { return randomForest.innerobj.trees; }
      set { randomForest.innerobj.trees = value; }
    }

    [StorableConstructor]
    private RandomForestModelAlglib_3_7(StorableConstructorFlag _) : base(_) {
      // create an empty forest; the storable properties above fill it during deserialization
      randomForest = new alglib_3_7::alglib.decisionforest();
    }

    private RandomForestModelAlglib_3_7(RandomForestModelAlglib_3_7 original, Cloner cloner) : base(original, cloner) {
      // deep-copy the alglib forest field by field (the trees array is the only mutable part)
      randomForest = new alglib_3_7::alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();

      // following fields are immutable so we don't need to clone them
      inputVariables = original.inputVariables;
      classValues = original.classValues;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModelAlglib_3_7(this, cloner);
    }

    /// <summary>Wraps a trained alglib 3.7 decision forest (the forest is referenced, not copied).</summary>
    /// <param name="decisionForest">Trained alglib 3.7 forest.</param>
    /// <param name="targetVariable">Name of the predicted variable.</param>
    /// <param name="inputVariables">Input variable names in the column order used for training.</param>
    /// <param name="classValues">Class values for classification models; null for regression.</param>
    public RandomForestModelAlglib_3_7(alglib_3_7::alglib.decisionforest decisionForest, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<double> classValues = null) : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;

      randomForest = decisionForest;

      this.inputVariables = inputVariables.ToArray();

      //classValues are only use for classification models
      if (classValues == null) this.classValues = new double[0];
      else this.classValues = classValues.ToArray();
    }


    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    /// <summary>Predicts regression values for the given rows (aggregated forest output).</summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib_3_7::alglib.dfprocess(randomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>Returns the population variance of the per-tree predictions for each row.</summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[this.randomForest.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // dfprocessraw yields one prediction per tree instead of the aggregated output
        alglib_3_7::alglib.dforest.dfprocessraw(randomForest.innerobj, x, ref ys);
        yield return ys.VariancePop();
      }
    }

    /// <summary>Predicts class values for the given rows by majority vote of the forest.</summary>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[randomForest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib_3_7::alglib.dfprocess(randomForest, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Converts one tree of the forest into a symbolic expression tree.
    /// NOTE(review): <paramref name="treeIdx"/> is effectively 1-based — values 0 and 1 both
    /// return the first tree because the skip loop runs treeIdx-1 times; confirm against callers
    /// before changing.
    /// </summary>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      var rf = randomForest;
      // hoping that the internal representation of alglib is stable

      // TREE FORMAT
      // W[Offs]      -   size of sub-array (for the tree)
      //     node info:
      // W[K+0]       -   variable number        (-1 for leaf mode)
      // W[K+1]       -   threshold              (class/value for leaf node)
      // W[K+2]       -   ">=" branch index      (absent for leaf node)

      // skip irrelevant trees
      int offset = 0;
      for (int i = 0; i < treeIdx - 1; i++) {
        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
      }

      var numSy = new Number();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };

      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, numSy, varCondSy);

      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree(node);
      var root = new ProgramRootSymbol().CreateTreeNode();
      root.AddSubtree(startNode);
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Recursively translates the flat alglib tree encoding (starting at index <paramref name="k"/>,
    /// with <paramref name="offset"/> being the tree's base index for branch targets) into
    /// VariableCondition/Number tree nodes.
    /// </summary>
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Number numSy, VariableCondition varCondSy) {

      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      if (trees[k] == -1) {
        // leaf node: trees[k + 1] holds the predicted value
        var numNode = (NumberTreeNode)numSy.CreateTreeNode();
        numNode.Value = trees[k + 1];
        return numNode;
      } else {
        // inner node: variable index, threshold, and ">=" branch target
        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        condNode.VariableName = inputVariables[(int)Math.Round(trees[k])];
        condNode.Threshold = trees[k + 1];
        condNode.Slope = double.PositiveInfinity;

        var left = CreateRegressionTreeRec(trees, offset, k + 3, numSy, varCondSy);
        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), numSy, varCondSy);

        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
        condNode.AddSubtree(right);
        return condNode;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.