Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/08/16 14:40:02 (9 years ago)
Author:
gkronber
Message:

#2434: merged trunk changes r12934:14026 into the branch

Location:
branches/crossvalidation-2434
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • branches/crossvalidation-2434

  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis

  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r12504 r14029  
    3232  /// Random forest classification data analysis algorithm.
    3333  /// </summary>
    34   [Item("Random Forest Classification", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]
     34  [Item("Random Forest Classification (RF)", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]
    3535  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 120)]
    3636  [StorableClass]
     
    4242    private const string SeedParameterName = "Seed";
    4343    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     44    private const string CreateSolutionParameterName = "CreateSolution";
    4445
    4546    #region parameter properties
     
    5859    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    5960      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     63      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    6064    }
    6165    #endregion
     
    8185      set { SetSeedRandomlyParameter.Value.Value = value; }
    8286    }
     87    public bool CreateSolution {
     88      get { return CreateSolutionParameter.Value.Value; }
     89      set { CreateSolutionParameter.Value.Value = value; }
     90    }
    8391    #endregion
    8492
     
    96104      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    97105      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     106      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     107      Parameters[CreateSolutionParameterName].Hidden = true;
     108
    98109      Problem = new ClassificationProblem();
    99110    }
     
    101112    [StorableHook(HookType.AfterDeserialization)]
    102113    private void AfterDeserialization() {
     114      // BackwardsCompatibility3.3
     115      #region Backwards compatible code, remove with 3.4
    103116      if (!Parameters.ContainsKey(MParameterName))
    104117        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     
    107120      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    108121        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     122      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     123        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     124        Parameters[CreateSolutionParameterName].Hidden = true;
     125      }
     126      #endregion
    109127    }
    110128
     
    118136      if (SetSeedRandomly) Seed = new System.Random().Next();
    119137
    120       var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    121       Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
     138      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    122139      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    123140      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
    124141      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
    125142      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
     143
     144      if (CreateSolution) {
     145        var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone());
     146        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
     147      }
    126148    }
    127149
    128     public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     150    // keep for compatibility with old API
     151    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    129152      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    130       var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    131       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
     153      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     154      return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
     155    }
     156
     157    public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     158      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
     159      return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    132160    }
    133161    #endregion
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs

    r12012 r14029  
    4343      : base(original, cloner) {
    4444    }
    45     public RandomForestClassificationSolution(IClassificationProblemData problemData, IRandomForestModel randomForestModel)
     45    public RandomForestClassificationSolution(IRandomForestModel randomForestModel, IClassificationProblemData problemData)
    4646      : base(randomForestModel, problemData) {
    4747    }
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r12509 r14029  
    3434  [StorableClass]
    3535  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    36   public sealed class RandomForestModel : NamedItem, IRandomForestModel {
     36  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    3737    // not persisted
    3838    private alglib.decisionforest randomForest;
     
    4444      }
    4545    }
     46
     47    public override IEnumerable<string> VariablesUsedForPrediction {
     48      get { return originalTrainingData.AllowedInputVariables; }
     49    }
     50
    4651
    4752    // instead of storing the data of the model itself
     
    9196
    9297    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    93     private RandomForestModel(alglib.decisionforest randomForest,
     98    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
    9499      int seed, IDataAnalysisProblemData originalTrainingData,
    95100      int nTrees, double r, double m, double[] classValues = null)
    96       : base() {
     101      : base(targetVariable) {
    97102      this.name = ItemName;
    98103      this.description = ItemDescription;
     
    147152    }
    148153
    149     public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
     154    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    150155      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
    151156      AssertInputMatrix(inputData);
     
    174179    }
    175180
    176     public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    177       return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
    178     }
    179     IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
    180       return CreateRegressionSolution(problemData);
    181     }
    182     public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    183       return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
    184     }
    185     IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
    186       return CreateClassificationSolution(problemData);
     181
     182    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     183      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
     184    }
     185    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
     186      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    187187    }
    188188
     
    205205      outOfBagRmsError = rep.oobrmserror;
    206206
    207       return new RandomForestModel(dForest, seed, problemData, nTrees, r, m);
     207      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
    208208    }
    209209
     
    242242      outOfBagRelClassificationError = rep.oobrelclserror;
    243243
    244       return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues);
     244      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    245245    }
    246246
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r12504 r14029  
    3232  /// Random forest regression data analysis algorithm.
    3333  /// </summary>
    34   [Item("Random Forest Regression", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
     34  [Item("Random Forest Regression (RF)", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
    3535  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
    3636  [StorableClass]
     
    4242    private const string SeedParameterName = "Seed";
    4343    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     44    private const string CreateSolutionParameterName = "CreateSolution";
    4445
    4546    #region parameter properties
     
    5859    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    5960      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     63      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    6064    }
    6165    #endregion
     
    8185      set { SetSeedRandomlyParameter.Value.Value = value; }
    8286    }
     87    public bool CreateSolution {
     88      get { return CreateSolutionParameter.Value.Value; }
     89      set { CreateSolutionParameter.Value.Value = value; }
     90    }
    8391    #endregion
    8492    [StorableConstructor]
     
    95103      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    96104      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     105      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     106      Parameters[CreateSolutionParameterName].Hidden = true;
     107
    97108      Problem = new RegressionProblem();
    98109    }
     
    100111    [StorableHook(HookType.AfterDeserialization)]
    101112    private void AfterDeserialization() {
     113      // BackwardsCompatibility3.3
     114      #region Backwards compatible code, remove with 3.4
    102115      if (!Parameters.ContainsKey(MParameterName))
    103116        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     
    106119      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    107120        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     121      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     122        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     123        Parameters[CreateSolutionParameterName].Hidden = true;
     124      }
     125      #endregion
    108126    }
    109127
     
    116134      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
    117135      if (SetSeedRandomly) Seed = new System.Random().Next();
     136      var model = CreateRandomForestRegressionModel(Problem.ProblemData, NumberOfTrees, R, M, Seed,
     137        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    118138
    119       var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    120       Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    121139      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    122140      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
    123141      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
    124142      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));
     143
     144      if (CreateSolution) {
     145        var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone());
     146        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
     147      }
    125148    }
    126149
    127     public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
     150    // keep for compatibility with old API
     151    public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    128152      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    129       var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    130       return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
     153      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
     154        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     155      return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    131156    }
     157
     158    public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
     159      double r, double m, int seed,
     160      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     161      return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     162    }
     163
    132164    #endregion
    133165  }
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs

    r12012 r14029  
    4343      : base(original, cloner) {
    4444    }
    45     public RandomForestRegressionSolution(IRegressionProblemData problemData, IRandomForestModel randomForestModel)
     45    public RandomForestRegressionSolution(IRandomForestModel randomForestModel, IRegressionProblemData problemData)
    4646      : base(randomForestModel, problemData) {
    4747    }
Note: See TracChangeset for help on using the changeset viewer.