Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs @ 6406

Last change on this file since 6406 was 6241, checked in by gkronber, 14 years ago

#1473: implemented random forest wrapper for classification.

File size: 7.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Optimization;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
34using HeuristicLab.Parameters;
35
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest classification data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// The forest itself is built by ALGLIB (<c>alglib.dfbuildrandomdecisionforest</c>); this class
  /// exposes the two tuning parameters, prepares the training matrix and publishes the results.
  /// </remarks>
  [Item("Random Forest Classification", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class RandomForestClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string RandomForestClassificationModelResultName = "Random forest classification solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";

    #region parameter properties
    /// <summary>
    /// Parameter holding the number of trees in the forest.
    /// </summary>
    public IValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    /// <summary>
    /// Parameter holding the ratio of the training set that is used to construct each individual tree.
    /// </summary>
    public IValueParameter<DoubleValue> RParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Gets or sets the number of trees in the forest (value of <see cref="NumberOfTreesParameter"/>).
    /// </summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>
    /// Gets or sets the ratio of the training set used per tree (value of <see cref="RParameter"/>).
    /// </summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestClassification(bool deserializing) : base(deserializing) { }
    private RandomForestClassification(RandomForestClassification original, Cloner cloner)
      : base(original, cloner) {
    }
    /// <summary>
    /// Creates the algorithm with 50 trees, R = 0.3 and a new <see cref="ClassificationProblem"/>.
    /// </summary>
    public RandomForestClassification()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      Problem = new ClassificationProblem();
    }
    // Intentionally empty hook that runs after deserialization.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    /// <summary>
    /// Creates a deep clone of this algorithm using the given cloner.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestClassification(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Builds the random forest for the current problem data and adds the solution together with
    /// its training and out-of-bag error measures to <c>Results</c>.
    /// </summary>
    protected override void Run() {
      double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;
      var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
      Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest classification solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest classification solution on the training set.", new PercentValue(relClassificationError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest classification solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error of the random forest classification solution.", new PercentValue(outOfBagRelClassificationError)));
    }

    /// <summary>
    /// Trains a random forest classifier on the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Classification problem data providing the dataset, target variable, allowed input variables and training rows.</param>
    /// <param name="nTrees">The number of trees in the forest.</param>
    /// <param name="r">The ratio of the training set that is used in the construction of individual trees (0&lt;r&lt;=1).</param>
    /// <param name="rmsError">Receives the root mean square error from the ALGLIB report (<c>rmserror</c>).</param>
    /// <param name="relClassificationError">Receives the relative classification error from the ALGLIB report (<c>relclserror</c>).</param>
    /// <param name="outOfBagRmsError">Receives the out-of-bag root mean square error from the ALGLIB report (<c>oobrmserror</c>).</param>
    /// <param name="outOfBagRelClassificationError">Receives the out-of-bag relative classification error from the ALGLIB report (<c>oobrelclserror</c>).</param>
    /// <returns>A classification solution wrapping the trained forest.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is <c>null</c>.</exception>
    /// <exception cref="NotSupportedException">The prepared training matrix contains NaN or infinity values.</exception>
    /// <exception cref="ArgumentException">ALGLIB returned a status code other than 1.</exception>
    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r,
      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
      if (problemData == null) throw new ArgumentNullException("problemData");

      Dataset dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
      IEnumerable<int> rows = problemData.TrainingIndizes;

      // The target variable is appended as the last column of the training matrix.
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");

      int nRows = inputMatrix.GetLength(0);
      int nCols = inputMatrix.GetLength(1);
      int targetColumn = nCols - 1;

      // Distinct class values in ascending order; GetVariableValues is called without a row
      // restriction, so every value found in the training rows has an entry in the map below.
      double[] classValues = dataset.GetVariableValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndizes = new Dictionary<double, double>(nClasses);
      for (int i = 0; i < nClasses; i++) {
        classIndizes[classValues[i]] = i;
      }
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, targetColumn] = classIndizes[inputMatrix[row, targetColumn]];
      }

      // execute random forest algorithm (nCols - 1 input variables, the last column holds the class index)
      alglib.decisionforest dforest;
      alglib.dfreport rep;
      int info;
      alglib.dfbuildrandomdecisionforest(inputMatrix, nRows, nCols - 1, nClasses, nTrees, r, out info, out dforest, out rep);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;
      return new RandomForestClassificationSolution(problemData, new RandomForestModel(dforest, targetVariable, allowedInputVariables, classValues));
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.