Changeset 8623 for trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
- Timestamp:
- 09/10/12 17:13:46 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
r7259 r8623 51 51 } 52 52 53 private IDiscriminantFunctionThresholdCalculator thresholdCalculator; 54 [Storable] 55 public IDiscriminantFunctionThresholdCalculator ThresholdCalculator { 56 get { return thresholdCalculator; } 57 private set { thresholdCalculator = value; } 58 } 59 53 60 54 61 [StorableConstructor] … … 61 68 } 62 69 63 public DiscriminantFunctionClassificationModel(IRegressionModel model )70 public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator) 64 71 : base() { 65 72 this.name = ItemName; 66 73 this.description = ItemDescription; 67 74 this.model = model; 68 this.classValues = new double[] { 0.0 }; 69 this.thresholds = new double[] { double.NegativeInfinity }; 75 this.classValues = new double[0]; 76 this.thresholds = new double[0]; 77 this.thresholdCalculator = thresholdCalculator; 78 } 79 80 [StorableHook(HookType.AfterDeserialization)] 81 private void AfterDeserialization() { 82 if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator(); 70 83 } 71 84 … … 80 93 } 81 94 95 public virtual void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) { 96 double[] classValues; 97 double[] thresholds; 98 var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 99 var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows); 100 thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds); 101 SetThresholdsAndClassValues(thresholds, classValues); 102 } 103 104 82 105 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 83 106 return model.GetEstimatedValues(dataset, rows); … … 85 108 86 109 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 110 if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model."); 87 111 foreach (var x in GetEstimatedValues(dataset, rows)) { 88 112 int classIndex = 0;
Note: See TracChangeset
for help on using the changeset viewer.