Changeset 8623 for trunk/sources/HeuristicLab.Problems.DataAnalysis
- Timestamp:
- 09/10/12 17:13:46 (12 years ago)
- Location:
- trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
r7259 r8623 51 51 } 52 52 53 private IDiscriminantFunctionThresholdCalculator thresholdCalculator; 54 [Storable] 55 public IDiscriminantFunctionThresholdCalculator ThresholdCalculator { 56 get { return thresholdCalculator; } 57 private set { thresholdCalculator = value; } 58 } 59 53 60 54 61 [StorableConstructor] … … 61 68 } 62 69 63 public DiscriminantFunctionClassificationModel(IRegressionModel model )70 public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator) 64 71 : base() { 65 72 this.name = ItemName; 66 73 this.description = ItemDescription; 67 74 this.model = model; 68 this.classValues = new double[] { 0.0 }; 69 this.thresholds = new double[] { double.NegativeInfinity }; 75 this.classValues = new double[0]; 76 this.thresholds = new double[0]; 77 this.thresholdCalculator = thresholdCalculator; 78 } 79 80 [StorableHook(HookType.AfterDeserialization)] 81 private void AfterDeserialization() { 82 if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator(); 70 83 } 71 84 … … 80 93 } 81 94 95 public virtual void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) { 96 double[] classValues; 97 double[] thresholds; 98 var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 99 var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows); 100 thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds); 101 SetThresholdsAndClassValues(thresholds, classValues); 102 } 103 104 82 105 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 83 106 return model.GetEstimatedValues(dataset, rows); … … 85 108 86 109 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 110 if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model."); 87 111 foreach (var x in GetEstimatedValues(dataset, rows)) { 88 112 int classIndex = 0; -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs
r7259 r8623 107 107 double maxDensityClassValue = -1; 108 108 foreach (var classValue in originalClasses) { 109 double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]);109 double density = LogNormalDensity(m, classMean[classValue], classStdDev[classValue]); 110 110 if (density > maxDensity) { 111 111 maxDensity = density; … … 139 139 } 140 140 141 private static double NormalDensity(double x, double mu, double sigma) { 142 if (sigma.IsAlmost(0.0)) { 143 if (x.IsAlmost(mu)) return 1.0; else return 0.0; 144 } else { 145 return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma)); 146 } 141 private static double LogNormalDensity(double x, double mu, double sigma) { 142 return -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - ((x - mu) * (x - mu)) / (2.0 * sigma * sigma); 147 143 } 148 144 -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs
r7259 r8623 26 26 IEnumerable<double> Thresholds { get; } 27 27 IEnumerable<double> ClassValues { get; } 28 IDiscriminantFunctionThresholdCalculator ThresholdCalculator { get; } 29 void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows); 28 30 // class values and thresholds can only be assigned simultanously 29 31 void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues);
Note: See TracChangeset
for help on using the changeset viewer.