Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/20/11 18:54:39 (12 years ago)
Author:
gkronber
Message:

#1473: implemented random forest wrapper for classification.

File:
1 moved

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r6240 r6241  
    3333namespace HeuristicLab.Algorithms.DataAnalysis {
    3434  /// <summary>
    35   /// Represents a random forest regression model.
     35  /// Represents a random forest model for regression and classification.
    3636  /// </summary>
    3737  [StorableClass]
    38   [Item("RandomForestRegressionModel", "Represents a random forest regression model.")]
    39   public sealed class RandomForestRegressionModel : NamedItem, IRandomForestRegressionModel {
     38  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
     39  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    4040
    4141    private alglib.decisionforest randomForest;
    42     /// <summary>
    43     /// Gets or sets the SVM model.
    44     /// </summary>
    4542    public alglib.decisionforest RandomForest {
    4643      get { return randomForest; }
     
    5855    [Storable]
    5956    private string[] allowedInputVariables;
    60 
     57    [Storable]
     58    private double[] classValues;
    6159    [StorableConstructor]
    62     private RandomForestRegressionModel(bool deserializing)
     60    private RandomForestModel(bool deserializing)
    6361      : base(deserializing) {
    6462      if (deserializing)
    6563        randomForest = new alglib.decisionforest();
    6664    }
    67     private RandomForestRegressionModel(RandomForestRegressionModel original, Cloner cloner)
     65    private RandomForestModel(RandomForestModel original, Cloner cloner)
    6866      : base(original, cloner) {
    6967      randomForest = new alglib.decisionforest();
     
    7573      targetVariable = original.targetVariable;
    7674      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
     75      if (original.classValues != null)
     76        this.classValues = (double[])original.classValues.Clone();
    7777    }
    78     public RandomForestRegressionModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables)
     78    public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
    7979      : base() {
    8080      this.name = ItemName;
     
    8383      this.targetVariable = targetVariable;
    8484      this.allowedInputVariables = allowedInputVariables.ToArray();
     85      if (classValues != null)
     86        this.classValues = (double[])classValues.Clone();
    8587    }
    8688
    8789    public override IDeepCloneable Clone(Cloner cloner) {
    88       return new RandomForestRegressionModel(this, cloner);
     90      return new RandomForestModel(this, cloner);
    8991    }
    9092
     
    103105        alglib.dfprocess(randomForest, x, ref y);
    104106        yield return y[0];
     107      }
     108    }
     109
     110    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     111      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     112
     113      int n = inputData.GetLength(0);
     114      int columns = inputData.GetLength(1);
     115      double[] x = new double[columns];
     116      double[] y = new double[randomForest.innerobj.nclasses];
     117
     118      for (int row = 0; row < n; row++) {
     119        for (int column = 0; column < columns; column++) {
     120          x[column] = inputData[row, column];
     121        }
     122        alglib.dfprocess(randomForest, x, ref y);
     123        // find the class with the largest probability value
     124        int maxProbClassIndex = 0;
     125        double maxProb = y[0];
     126        for (int i = 1; i < y.Length; i++) {
     127          if (maxProb < y[i]) {
     128            maxProb = y[i];
     129            maxProbClassIndex = i;
     130          }
     131        }
     132        yield return classValues[maxProbClassIndex];
    105133      }
    106134    }
Note: See TracChangeset for help on using the changeset viewer.