Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionModel.cs @ 6240

Last change on this file since 6240 was 6240, checked in by gkronber, 12 years ago

#1473: Implemented wrapper for ALGLIB random forest for regression.

File size: 5.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.IO;
25using System.Linq;
26using System.Text;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31using SVM;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest regression model.
  /// </summary>
  /// <remarks>
  /// Wraps an ALGLIB <c>decisionforest</c>. The forest field itself is not marked [Storable];
  /// its inner fields (bufsize, nclasses, ntrees, nvars, trees) are persisted through the
  /// private [Storable] properties in the persistence region instead.
  /// </remarks>
  [StorableClass]
  [Item("RandomForestRegressionModel", "Represents a random forest regression model.")]
  public sealed class RandomForestRegressionModel : NamedItem, IRandomForestRegressionModel {

    private alglib.decisionforest randomForest;
    /// <summary>
    /// Gets or sets the ALGLIB decision forest used to compute the estimated values.
    /// Assigning a different forest raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">Thrown when the assigned value is null.</exception>
    public alglib.decisionforest RandomForest {
      get { return randomForest; }
      set {
        // Validate before the identity check so that a null assignment is always rejected,
        // even when the current forest happens to be null as well.
        if (value == null) throw new ArgumentNullException("value");
        if (value != randomForest) {
          randomForest = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    [Storable]
    private string targetVariable;
    [Storable]
    private string[] allowedInputVariables;

    /// <summary>
    /// Constructor used by the persistence framework.
    /// </summary>
    /// <param name="deserializing">True when the instance is being deserialized.</param>
    [StorableConstructor]
    private RandomForestRegressionModel(bool deserializing)
      : base(deserializing) {
      // The [Storable] forest properties below write into randomForest.innerobj,
      // so an empty forest must exist before they are assigned during deserialization.
      if (deserializing)
        randomForest = new alglib.decisionforest();
    }

    /// <summary>
    /// Cloning constructor; creates a deep copy of <paramref name="original"/>.
    /// </summary>
    /// <param name="original">The model to copy.</param>
    /// <param name="cloner">The cloner used for the base-class state.</param>
    private RandomForestRegressionModel(RandomForestRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = CopyForest(original.randomForest);
      targetVariable = original.targetVariable;
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    }

    /// <summary>
    /// Creates a random forest regression model from a trained ALGLIB decision forest.
    /// </summary>
    /// <param name="randomForest">The trained decision forest; must not be null.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">
    /// The input variable names, in the column order expected by the forest; must not be null.
    /// The sequence is materialized into an array.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="randomForest"/> or <paramref name="allowedInputVariables"/> is null.
    /// </exception>
    public RandomForestRegressionModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables)
      : base() {
      // A null forest would only surface later as a NullReferenceException in the
      // persistence properties, the cloning constructor, or GetEstimatedValues.
      if (randomForest == null) throw new ArgumentNullException("randomForest");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");
      this.name = ItemName;
      this.description = ItemDescription;
      this.randomForest = randomForest;
      this.targetVariable = targetVariable;
      this.allowedInputVariables = allowedInputVariables.ToArray();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestRegressionModel(this, cloner);
    }

    // Creates an independent copy of an ALGLIB decision forest. The tree array is cloned
    // so that the copy shares no mutable state with the source.
    private static alglib.decisionforest CopyForest(alglib.decisionforest source) {
      var copy = new alglib.decisionforest();
      copy.innerobj.bufsize = source.innerobj.bufsize;
      copy.innerobj.nclasses = source.innerobj.nclasses;
      copy.innerobj.ntrees = source.innerobj.ntrees;
      copy.innerobj.nvars = source.innerobj.nvars;
      copy.innerobj.trees = (double[])source.innerobj.trees.Clone();
      return copy;
    }

    /// <summary>
    /// Computes the forest's estimate for each of the given rows.
    /// </summary>
    /// <remarks>
    /// This is an iterator: the input matrix is prepared and the forest is evaluated only
    /// when the result is enumerated, and again on each enumeration.
    /// </remarks>
    /// <param name="dataset">The dataset that provides the input variable values.</param>
    /// <param name="rows">The row indices to evaluate.</param>
    /// <returns>One estimated value per row, in the order of the prepared input matrix.</returns>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      // Buffers are reused across rows; dfprocess receives y by ref and may replace it.
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(randomForest, x, ref y);
        yield return y[0];
      }
    }

    #region events
    /// <summary>
    /// Raised when a different forest is assigned to <see cref="RandomForest"/>.
    /// </summary>
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      // Copy to a local so the null check and the invocation see the same delegate.
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion

    #region persistence
    // These private properties proxy the fields of the ALGLIB forest's inner object so
    // the persistence framework can store and restore them; the forest type itself
    // carries no [Storable] attributes.
    [Storable]
    private int RandomForestBufSize {
      get {
        return randomForest.innerobj.bufsize;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        return randomForest.innerobj.nclasses;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        return randomForest.innerobj.ntrees;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        return randomForest.innerobj.nvars;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        return randomForest.innerobj.trees;
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.