Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/28/16 13:33:17 (8 years ago)
Author:
mkommend
Message:

#2604:

  • Base classes for data analysis, classification, and regression models
  • Added target variable to classification and regression models
  • Switched parameter order in data analysis solutions (model, problemdata)
Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r13238 r13941  
    143143
    144144      if (CreateSolution) {
    145         var solution = new RandomForestClassificationSolution((IClassificationProblemData)Problem.ProblemData.Clone(), model);
     145        var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone());
    146146        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
    147147      }
    148148    }
    149    
     149
    150150    // keep for compatibility with old API
    151151    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    152152      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    153153      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    154       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
     154      return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    155155    }
    156156
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs

    r12012 r13941  
    4343      : base(original, cloner) {
    4444    }
    45     public RandomForestClassificationSolution(IClassificationProblemData problemData, IRandomForestModel randomForestModel)
     45    public RandomForestClassificationSolution(IRandomForestModel randomForestModel, IClassificationProblemData problemData)
    4646      : base(randomForestModel, problemData) {
    4747    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r13921 r13941  
    3434  [StorableClass]
    3535  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    36   public sealed class RandomForestModel : NamedItem, IRandomForestModel {
     36  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    3737    // not persisted
    3838    private alglib.decisionforest randomForest;
     
    4545    }
    4646
    47     public IEnumerable<string> VariablesUsedForPrediction {
     47    public override IEnumerable<string> VariablesUsedForPrediction {
    4848      get { return originalTrainingData.AllowedInputVariables; }
    4949    }
    5050
    51     public string TargetVariable {
    52       get {
    53         var regressionProblemData = originalTrainingData as IRegressionProblemData;
    54         var classificationProblemData = originalTrainingData as IClassificationProblemData;
    55         if (classificationProblemData != null)
    56           return classificationProblemData.TargetVariable;
    57         if (regressionProblemData != null)
    58           return regressionProblemData.TargetVariable;
    59         throw new InvalidOperationException("Getting the target variable requires either a regression or a classification problem data.");
    60       }
    61     }
    6251
    6352    // instead of storing the data of the model itself
     
    10796
    10897    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    109     private RandomForestModel(alglib.decisionforest randomForest,
     98    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
    11099      int seed, IDataAnalysisProblemData originalTrainingData,
    111100      int nTrees, double r, double m, double[] classValues = null)
    112       : base() {
     101      : base(targetVariable) {
    113102      this.name = ItemName;
    114103      this.description = ItemDescription;
     
    163152    }
    164153
    165     public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
     154    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    166155      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
    167156      AssertInputMatrix(inputData);
     
    190179    }
    191180
    192     public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    193       return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
    194     }
    195     IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
    196       return CreateRegressionSolution(problemData);
    197     }
    198     public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    199       return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
    200     }
    201     IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
    202       return CreateClassificationSolution(problemData);
     181
     182    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     183      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
     184    }
     185    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
     186      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    203187    }
    204188
     
    221205      outOfBagRmsError = rep.oobrmserror;
    222206
    223       return new RandomForestModel(dForest, seed, problemData, nTrees, r, m);
     207      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
    224208    }
    225209
     
    258242      outOfBagRelClassificationError = rep.oobrelclserror;
    259243
    260       return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues);
     244      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    261245    }
    262246
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r13238 r13941  
    143143
    144144      if (CreateSolution) {
    145         var solution = new RandomForestRegressionSolution((IRegressionProblemData)Problem.ProblemData.Clone(), model);
     145        var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone());
    146146        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    147147      }
     
    153153      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
    154154        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    155       return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
     155      return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    156156    }
    157157
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs

    r12012 r13941  
    4343      : base(original, cloner) {
    4444    }
    45     public RandomForestRegressionSolution(IRegressionProblemData problemData, IRandomForestModel randomForestModel)
     45    public RandomForestRegressionSolution(IRandomForestModel randomForestModel, IRegressionProblemData problemData)
    4646      : base(randomForestModel, problemData) {
    4747    }
Note: See TracChangeset for help on using the changeset viewer.