Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MPI/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 6345

Last change on this file since 6345 was 6241, checked in by gkronber, 14 years ago

#1473: implemented random forest wrapper for classification.

File size: 6.3 KB
Line 
1#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.IO;
25using System.Linq;
26using System.Text;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31using SVM;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification.
  /// </summary>
  /// <remarks>
  /// Wraps an <c>alglib.decisionforest</c> together with the name of the target variable, the
  /// names of the input variables and, optionally, the class values used for classification.
  /// The forest itself is not a storable field; its inner state is persisted through the
  /// private storable properties in the persistence region.
  /// </remarks>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : NamedItem, IRandomForestModel {

    private alglib.decisionforest randomForest;
    /// <summary>
    /// Gets or sets the wrapped forest. Assigning a different instance raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">Thrown when the assigned value is null.</exception>
    public alglib.decisionforest RandomForest {
      get { return randomForest; }
      set {
        // reject null up front so the model can never end up without a forest
        if (value == null) throw new ArgumentNullException("value");
        if (value != randomForest) {
          randomForest = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    [Storable]
    private string targetVariable;
    [Storable]
    private string[] allowedInputVariables;
    // null when the model was created without class values (see the public constructor)
    [Storable]
    private double[] classValues;

    /// <summary>
    /// Constructor used for deserialization.
    /// </summary>
    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // the storable properties in the persistence region write into randomForest.innerobj,
      // so an (empty) forest instance has to exist before they are assigned
      if (deserializing)
        randomForest = new alglib.decisionforest();
    }

    /// <summary>
    /// Cloning constructor; copies the forest state and the arrays of <paramref name="original"/>
    /// so that the clone does not share them with the original.
    /// </summary>
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();
      targetVariable = original.targetVariable;
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }

    /// <summary>
    /// Creates a new model around the given forest.
    /// </summary>
    /// <param name="randomForest">The forest to wrap; must not be null.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Names of the input variables; the sequence is copied into an array.</param>
    /// <param name="classValues">Optional class values (copied); required for <see cref="GetEstimatedClassValues"/>.</param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="randomForest"/> or <paramref name="allowedInputVariables"/> is null.
    /// </exception>
    public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
      : base() {
      // same contract as the RandomForest setter: a model without a forest is not allowed
      if (randomForest == null) throw new ArgumentNullException("randomForest");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");
      this.name = ItemName;
      this.description = ItemDescription;
      this.randomForest = randomForest;
      this.targetVariable = targetVariable;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      if (classValues != null)
        this.classValues = (double[])classValues.Clone();
    }

    /// <summary>
    /// Creates a deep clone of this model via the cloning constructor.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    /// <summary>
    /// Lazily yields one estimated value per requested row.
    /// </summary>
    /// <param name="dataset">The dataset to read the input variables from.</param>
    /// <param name="rows">The rows to estimate.</param>
    /// <returns>The first output element of <c>alglib.dfprocess</c> for each row of the input matrix.</returns>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      // x and y are reused as scratch buffers for every row
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        CopyRow(inputData, row, x);
        alglib.dfprocess(randomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Lazily yields one estimated class value per requested row: the class value whose
    /// forest output is the largest.
    /// </summary>
    /// <param name="dataset">The dataset to read the input variables from.</param>
    /// <param name="rows">The rows to classify.</param>
    /// <exception cref="InvalidOperationException">
    /// Thrown immediately (not on enumeration) when the model was created without class values.
    /// </exception>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      // validated here, outside of the iterator, so the caller gets a meaningful error at the
      // call site instead of a NullReferenceException in the middle of the enumeration
      if (classValues == null)
        throw new InvalidOperationException("The random forest model does not contain class values and cannot be used for classification.");
      return EstimateClassValues(dataset, rows);
    }

    // Iterator part of GetEstimatedClassValues; expects classValues to be non-null.
    private IEnumerable<double> EstimateClassValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      // x and y are reused as scratch buffers for every row
      double[] x = new double[columns];
      double[] y = new double[randomForest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        CopyRow(inputData, row, x);
        alglib.dfprocess(randomForest, x, ref y);
        // find the class with the largest probability value (the first one wins on ties)
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    // Copies row 'row' of the matrix into 'target' (one element per matrix column).
    private static void CopyRow(double[,] source, int row, double[] target) {
      int columns = source.GetLength(1);
      for (int column = 0; column < columns; column++) {
        target[column] = source[row, column];
      }
    }

    #region events
    /// <summary>
    /// Raised when a different forest is assigned through <see cref="RandomForest"/>.
    /// </summary>
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      // copy to a local so the null check and the invocation see the same delegate
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion

    #region persistence
    // The properties below expose the inner state of the forest to the persistence framework.
    // Each one reads from / writes to randomForest.innerobj directly.
    [Storable]
    private int RandomForestBufSize {
      get {
        return randomForest.innerobj.bufsize;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        return randomForest.innerobj.nclasses;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        return randomForest.innerobj.ntrees;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        return randomForest.innerobj.nvars;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        return randomForest.innerobj.trees;
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.