Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/20/11 18:54:39 (14 years ago)
Author:
gkronber
Message:

#1473: implemented random forest wrapper for classification.

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files:
2 added
2 edited
1 moved

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r6240 r6241  
    3333namespace HeuristicLab.Algorithms.DataAnalysis {
    3434  /// <summary>
    35   /// Represents a random forest regression model.
     35  /// Represents a random forest model for regression and classification
    3636  /// </summary>
    3737  [StorableClass]
    38   [Item("RandomForestRegressionModel", "Represents a random forest regression model.")]
    39   public sealed class RandomForestRegressionModel : NamedItem, IRandomForestRegressionModel {
     38  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
     39  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    4040
    4141    private alglib.decisionforest randomForest;
    42     /// <summary>
    43     /// Gets or sets the SVM model.
    44     /// </summary>
    4542    public alglib.decisionforest RandomForest {
    4643      get { return randomForest; }
     
    5855    [Storable]
    5956    private string[] allowedInputVariables;
    60 
     57    [Storable]
     58    private double[] classValues;
    6159    [StorableConstructor]
    62     private RandomForestRegressionModel(bool deserializing)
     60    private RandomForestModel(bool deserializing)
    6361      : base(deserializing) {
    6462      if (deserializing)
    6563        randomForest = new alglib.decisionforest();
    6664    }
    67     private RandomForestRegressionModel(RandomForestRegressionModel original, Cloner cloner)
     65    private RandomForestModel(RandomForestModel original, Cloner cloner)
    6866      : base(original, cloner) {
    6967      randomForest = new alglib.decisionforest();
     
    7573      targetVariable = original.targetVariable;
    7674      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
     75      if (original.classValues != null)
     76        this.classValues = (double[])original.classValues.Clone();
    7777    }
    78     public RandomForestRegressionModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables)
     78    public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
    7979      : base() {
    8080      this.name = ItemName;
     
    8383      this.targetVariable = targetVariable;
    8484      this.allowedInputVariables = allowedInputVariables.ToArray();
     85      if (classValues != null)
     86        this.classValues = (double[])classValues.Clone();
    8587    }
    8688
    8789    public override IDeepCloneable Clone(Cloner cloner) {
    88       return new RandomForestRegressionModel(this, cloner);
     90      return new RandomForestModel(this, cloner);
    8991    }
    9092
     
    103105        alglib.dfprocess(randomForest, x, ref y);
    104106        yield return y[0];
     107      }
     108    }
     109
     110    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     111      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     112
     113      int n = inputData.GetLength(0);
     114      int columns = inputData.GetLength(1);
     115      double[] x = new double[columns];
     116      double[] y = new double[randomForest.innerobj.nclasses];
     117
     118      for (int row = 0; row < n; row++) {
     119        for (int column = 0; column < columns; column++) {
     120          x[column] = inputData[row, column];
     121        }
     122        alglib.dfprocess(randomForest, x, ref y);
     123        // find class for with the largest probability value
     124        int maxProbClassIndex = 0;
     125        double maxProb = y[0];
     126        for (int i = 1; i < y.Length; i++) {
     127          if (maxProb < y[i]) {
     128            maxProb = y[i];
     129            maxProbClassIndex = i;
     130          }
     131        }
     132        yield return classValues[maxProbClassIndex];
    105133      }
    106134    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r6240 r6241  
    8787      Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    8888      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    89       Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new DoubleValue(avgRelError)));
    90       Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(outOfBagRmsError)));
    91       Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution on the training set.", new DoubleValue(outOfBagAvgRelError)));
     89      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
     90      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
     91      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));
    9292    }
    9393
     
    116116      outOfBagRmsError = rep.oobrmserror;
    117117
    118       return new RandomForestRegressionSolution(problemData, new RandomForestRegressionModel(dforest, targetVariable, allowedInputVariables));
     118      return new RandomForestRegressionSolution(problemData, new RandomForestModel(dforest, targetVariable, allowedInputVariables));
    119119    }
    120120    #endregion
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs

    r6240 r6241  
    3737  public sealed class RandomForestRegressionSolution : RegressionSolution, IRandomForestRegressionSolution {
    3838
    39     public new IRandomForestRegressionModel Model {
    40       get { return (IRandomForestRegressionModel)base.Model; }
     39    public new IRandomForestModel Model {
     40      get { return (IRandomForestModel)base.Model; }
    4141      set { base.Model = value; }
    4242    }
     
    4747      : base(original, cloner) {
    4848    }
    49     public RandomForestRegressionSolution(IRegressionProblemData problemData, IRandomForestRegressionModel randomForestModel)
     49    public RandomForestRegressionSolution(IRegressionProblemData problemData, IRandomForestModel randomForestModel)
    5050      : base(randomForestModel, problemData) {
    5151    }
Note: See TracChangeset for help on using the changeset viewer.