Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/29/16 10:36:52 (8 years ago)
Author:
pfleck
Message:

#2597

  • Merged recent trunk changes.
  • Adapted VariablesUsedForPrediction property for RegressionSolutionTargetResponseGradientView.
  • Fixed a reference (.dll to project ref).
Location:
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis

  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r12509 r13948  
    3434  [StorableClass]
    3535  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    36   public sealed class RandomForestModel : NamedItem, IRandomForestModel {
     36  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    3737    // not persisted
    3838    private alglib.decisionforest randomForest;
     
    4444      }
    4545    }
     46
     47    public override IEnumerable<string> VariablesUsedForPrediction {
     48      get { return originalTrainingData.AllowedInputVariables; }
     49    }
     50
    4651
    4752    // instead of storing the data of the model itself
     
    9196
    9297    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    93     private RandomForestModel(alglib.decisionforest randomForest,
     98    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
    9499      int seed, IDataAnalysisProblemData originalTrainingData,
    95100      int nTrees, double r, double m, double[] classValues = null)
    96       : base() {
     101      : base(targetVariable) {
    97102      this.name = ItemName;
    98103      this.description = ItemDescription;
     
    147152    }
    148153
    149     public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
     154    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    150155      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
    151156      AssertInputMatrix(inputData);
     
    174179    }
    175180
    176     public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    177       return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
    178     }
    179     IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
    180       return CreateRegressionSolution(problemData);
    181     }
    182     public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    183       return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
    184     }
    185     IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
    186       return CreateClassificationSolution(problemData);
     181
     182    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     183      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
     184    }
     185    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
     186      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    187187    }
    188188
     
    205205      outOfBagRmsError = rep.oobrmserror;
    206206
    207       return new RandomForestModel(dForest, seed, problemData, nTrees, r, m);
     207      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
    208208    }
    209209
     
    242242      outOfBagRelClassificationError = rep.oobrelclserror;
    243243
    244       return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues);
     244      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    245245    }
    246246
Note: See TracChangeset for help on using the changeset viewer.