Changeset 14029 for branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp:
- 07/08/16 14:40:02 (8 years ago)
- Location:
- branches/crossvalidation-2434
- Files:
-
- 8 edited
- 2 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs
r12509 r14029 32 32 [StorableClass] 33 33 [Item("ClassificationEnsembleModel", "A classification model that contains an ensemble of multiple classification models")] 34 public class ClassificationEnsembleModel : NamedItem, IClassificationEnsembleModel { 34 public class ClassificationEnsembleModel : ClassificationModel, IClassificationEnsembleModel { 35 public override IEnumerable<string> VariablesUsedForPrediction { 36 get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 37 } 35 38 36 39 [Storable] … … 49 52 public ClassificationEnsembleModel() : this(Enumerable.Empty<IClassificationModel>()) { } 50 53 public ClassificationEnsembleModel(IEnumerable<IClassificationModel> models) 51 : base( ) {54 : base(string.Empty) { 52 55 this.name = ItemName; 53 56 this.description = ItemDescription; 54 57 this.models = new List<IClassificationModel>(models); 58 59 if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable; 55 60 } 56 61 … … 59 64 } 60 65 61 #region IClassificationEnsembleModel Members62 66 public void Add(IClassificationModel model) { 67 if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable; 63 68 models.Add(model); 64 69 } 65 70 public void Remove(IClassificationModel model) { 66 71 models.Remove(model); 72 if (!models.Any()) TargetVariable = string.Empty; 67 73 } 68 74 … … 78 84 } 79 85 80 #endregion81 86 82 #region IClassificationModel Members 83 84 public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 87 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 85 88 foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) { 86 89 // return the class which is most often occuring … … 94 97 } 95 98 96 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {99 public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 97 100 return new ClassificationEnsembleSolution(models, new ClassificationEnsembleProblemData(problemData)); 98 101 } 99 #endregion 102 103 100 104 } 101 105 } -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationPerformanceMeasures.cs
r12012 r14029 37 37 protected const string TrainingFalsePositiveRateResultName = "False positive rate (training)"; 38 38 protected const string TrainingFalseDiscoveryRateResultName = "False discovery rate (training)"; 39 protected const string TrainingF1ScoreResultName = "F1 score (training)"; 40 protected const string TrainingMatthewsCorrelationResultName = "Matthews Correlation (training)"; 39 41 protected const string TestTruePositiveRateResultName = "True positive rate (test)"; 40 42 protected const string TestTrueNegativeRateResultName = "True negative rate (test)"; … … 43 45 protected const string TestFalsePositiveRateResultName = "False positive rate (test)"; 44 46 protected const string TestFalseDiscoveryRateResultName = "False discovery rate (test)"; 47 protected const string TestF1ScoreResultName = "F1 score (test)"; 48 protected const string TestMatthewsCorrelationResultName = "Matthews Correlation (test)"; 45 49 #endregion 46 50 … … 89 93 set { ((DoubleValue)this[TrainingFalseDiscoveryRateResultName].Value).Value = value; } 90 94 } 95 public double TrainingF1Score { 96 get { return ((DoubleValue)this[TrainingF1ScoreResultName].Value).Value; } 97 set { ((DoubleValue)this[TrainingF1ScoreResultName].Value).Value = value; } 98 } 99 public double TrainingMatthewsCorrelation { 100 get { return ((DoubleValue)this[TrainingMatthewsCorrelationResultName].Value).Value; } 101 set { ((DoubleValue)this[TrainingMatthewsCorrelationResultName].Value).Value = value; } 102 } 91 103 public double TestTruePositiveRate { 92 104 get { return ((DoubleValue)this[TestTruePositiveRateResultName].Value).Value; } … … 112 124 get { return ((DoubleValue)this[TestFalseDiscoveryRateResultName].Value).Value; } 113 125 set { ((DoubleValue)this[TestFalseDiscoveryRateResultName].Value).Value = value; } 126 } 127 public double TestF1Score { 128 get { return ((DoubleValue)this[TestF1ScoreResultName].Value).Value; } 129 set { ((DoubleValue)this[TestF1ScoreResultName].Value).Value = value; } 130 } 131 public double TestMatthewsCorrelation { 132 get { return ((DoubleValue)this[TestMatthewsCorrelationResultName].Value).Value; } 133 set { ((DoubleValue)this[TestMatthewsCorrelationResultName].Value).Value = value; } 114 134 } 115 135 #endregion … … 123 143 Add(new Result(TrainingFalsePositiveRateResultName, "The false positive rate is the complement of the true negative rate of the model on the training partition.", new PercentValue())); 124 144 Add(new Result(TrainingFalseDiscoveryRateResultName, "The false discovery rate is the complement of the positive predictive value of the model on the training partition.", new PercentValue())); 145 Add(new Result(TrainingF1ScoreResultName, "The F1 score of the model on the training partition.", new DoubleValue())); 146 Add(new Result(TrainingMatthewsCorrelationResultName, "The Matthews correlation value of the model on the training partition.", new DoubleValue())); 125 147 Add(new Result(TestTruePositiveRateResultName, "Sensitivity/True positive rate of the model on the test partition\n(TP/(TP+FN)).", new PercentValue())); 126 148 Add(new Result(TestTrueNegativeRateResultName, "Specificity/True negative rate of the model on the test partition\n(TN/(FP+TN)).", new PercentValue())); … … 129 151 Add(new Result(TestFalsePositiveRateResultName, "The false positive rate is the complement of the true negative rate of the model on the test partition.", new PercentValue())); 130 152 Add(new Result(TestFalseDiscoveryRateResultName, "The false discovery rate is the complement of the positive predictive value of the model on the test partition.", new PercentValue())); 153 Add(new Result(TestF1ScoreResultName, "The F1 score of the model on the test partition.", new DoubleValue())); 154 Add(new Result(TestMatthewsCorrelationResultName, "The Matthews correlation value of the model on the test partition.", new DoubleValue())); 155 156 Reset(); 157 } 158 159 160 public void Reset() { 131 161 TrainingTruePositiveRate = double.NaN; 132 162 TrainingTrueNegativeRate = double.NaN; … … 135 165 TrainingFalsePositiveRate = double.NaN; 136 166 TrainingFalseDiscoveryRate = double.NaN; 167 TrainingF1Score = double.NaN; 168 TrainingMatthewsCorrelation = double.NaN; 137 169 TestTruePositiveRate = double.NaN; 138 170 TestTrueNegativeRate = double.NaN; … … 141 173 TestFalsePositiveRate = double.NaN; 142 174 TestFalseDiscoveryRate = double.NaN; 175 TestF1Score = double.NaN; 176 TestMatthewsCorrelation = double.NaN; 143 177 } 144 178 -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblem.cs
r12504 r14029 35 35 public override IDeepCloneable Clone(Cloner cloner) { return new ClassificationProblem(this, cloner); } 36 36 37 public ClassificationProblem() 38 : base() { 39 ProblemData = new ClassificationProblemData(); 40 } 37 public ClassificationProblem() : base(new ClassificationProblemData()) { } 41 38 } 42 39 } -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs
r12509 r14029 283 283 private void AfterDeserialization() { 284 284 RegisterParameterEvents(); 285 286 classNamesCache = new List<string>(); 287 for (int i = 0; i < ClassNamesParameter.Value.Rows; i++) 288 classNamesCache.Add(ClassNamesParameter.Value[i, 0]); 289 285 290 // BackwardsCompatibility3.4 286 291 #region Backwards compatible code, remove with 3.5 … … 297 302 : base(original, cloner) { 298 303 RegisterParameterEvents(); 304 classNamesCache = new List<string>(); 305 for (int i = 0; i < ClassNamesParameter.Value.Rows; i++) 306 classNamesCache.Add(ClassNamesParameter.Value[i, 0]); 299 307 } 300 308 public override IDeepCloneable Clone(Cloner cloner) { -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionBase.cs
r12012 r14029 26 26 using HeuristicLab.Optimization; 27 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 using HeuristicLab.Problems.DataAnalysis.OnlineCalculators; 28 29 29 30 namespace HeuristicLab.Problems.DataAnalysis { … … 128 129 TestNormalizedGiniCoefficient = testNormalizedGini; 129 130 131 ClassificationPerformanceMeasures.Reset(); 132 130 133 trainingPerformanceCalculator.Calculate(originalTrainingClassValues, estimatedTrainingClassValues); 131 134 if (trainingPerformanceCalculator.ErrorState == OnlineCalculatorError.None) … … 135 138 if (testPerformanceCalculator.ErrorState == OnlineCalculatorError.None) 136 139 ClassificationPerformanceMeasures.SetTestResults(testPerformanceCalculator); 140 141 if (ProblemData.Classes == 2) { 142 var f1Training = FOneScoreCalculator.Calculate(originalTrainingClassValues, estimatedTrainingClassValues, out errorState); 143 if (errorState == OnlineCalculatorError.None) ClassificationPerformanceMeasures.TrainingF1Score = f1Training; 144 var f1Test = FOneScoreCalculator.Calculate(originalTestClassValues, estimatedTestClassValues, out errorState); 145 if (errorState == OnlineCalculatorError.None) ClassificationPerformanceMeasures.TestF1Score = f1Test; 146 } 147 148 var mccTraining = MatthewsCorrelationCoefficientCalculator.Calculate(originalTrainingClassValues, estimatedTrainingClassValues, out errorState); 149 if (errorState == OnlineCalculatorError.None) ClassificationPerformanceMeasures.TrainingMatthewsCorrelation = mccTraining; 150 var mccTest = MatthewsCorrelationCoefficientCalculator.Calculate(originalTestClassValues, estimatedTestClassValues, out errorState); 151 if (errorState == OnlineCalculatorError.None) ClassificationPerformanceMeasures.TestMatthewsCorrelation = mccTest; 137 152 } 138 153 -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
r12509 r14029 33 33 [StorableClass] 34 34 [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")] 35 public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel { 35 public class DiscriminantFunctionClassificationModel : ClassificationModel, IDiscriminantFunctionClassificationModel { 36 public override IEnumerable<string> VariablesUsedForPrediction { 37 get { return model.VariablesUsedForPrediction; } 38 } 39 36 40 [Storable] 37 41 private IRegressionModel model; … … 73 77 74 78 public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator) 75 : base( ) {79 : base(model.TargetVariable) { 76 80 this.name = ItemName; 77 81 this.description = ItemDescription; 82 78 83 this.model = model; 79 84 this.classValues = new double[0]; … … 115 120 } 116 121 117 public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {122 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 118 123 if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model."); 119 124 foreach (var x in GetEstimatedValues(dataset, rows)) { … … 135 140 #endregion 136 141 137 public virtual IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData) { 142 public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 143 return CreateDiscriminantFunctionClassificationSolution(problemData); 144 } 145 public virtual IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution( 146 IClassificationProblemData problemData) { 138 147 return new DiscriminantFunctionClassificationSolution(this, new ClassificationProblemData(problemData)); 139 }140 141 public virtual IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {142 return CreateDiscriminantFunctionClassificationSolution(problemData);143 148 } 144 149 }
Note: See TracChangeset
for help on using the changeset viewer.