Changeset 5736 for branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp:
- 03/17/11 15:14:45 (14 years ago)
- Location:
- branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r5730 r5736 82 82 } 83 83 84 pr ivatevoid RecalculateResults() {84 protected void RecalculateResults() { 85 85 double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values 86 86 IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
r5730 r5736 40 40 [Storable] 41 41 private IRegressionModel model; 42 42 43 [Storable] 43 44 private double[] classValues; 44 // class values are not necessarily sorted in ascending order45 45 public IEnumerable<double> ClassValues { 46 46 get { return (double[])classValues.Clone(); } 47 set { 48 if (value == null) throw new ArgumentException(); 49 double[] newValue = value.ToArray(); 50 if (newValue.Length != classValues.Length) throw new ArgumentException(); 51 classValues = newValue; 52 } 47 private set { classValues = value.ToArray(); } 53 48 } 49 54 50 [Storable] 55 51 private double[] thresholds; 56 52 public IEnumerable<double> Thresholds { 57 53 get { return (IEnumerable<double>)thresholds.Clone(); } 58 set { 59 thresholds = value.ToArray(); 60 OnThresholdsChanged(EventArgs.Empty); 61 } 54 private set { thresholds = value.ToArray(); } 62 55 } 63 56 … … 71 64 thresholds = (double[])original.thresholds.Clone(); 72 65 } 73 public DiscriminantFunctionClassificationModel(IRegressionModel model, IEnumerable<double> classValues, IEnumerable<double> thresholds) 66 67 public DiscriminantFunctionClassificationModel(IRegressionModel model) 74 68 : base() { 75 69 this.name = ItemName; 76 70 this.description = ItemDescription; 77 71 this.model = model; 78 this.classValues = classValues.ToArray();79 this.thresholds = thresholds.ToArray();72 this.classValues = new double[] { 0.0 }; 73 this.thresholds = new double[] { double.NegativeInfinity }; 80 74 } 81 75 82 76 public override IDeepCloneable Clone(Cloner cloner) { 83 77 return new DiscriminantFunctionClassificationModel(this, cloner); 78 } 79 80 public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) { 81 var classValuesArr = classValues.ToArray(); 82 var thresholdsArr = thresholds.ToArray(); 83 if (thresholdsArr.Length != classValuesArr.Length) throw new ArgumentException(); 84 85 this.classValues = classValuesArr; 86 this.thresholds = thresholdsArr; 87 OnThresholdsChanged(EventArgs.Empty); 84 88 } 85 89 … … 96 100 else break; 97 101 } 98 yield return classValues.ElementAt(classIndex );102 yield return classValues.ElementAt(classIndex - 1); 99 103 } 100 104 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r5730 r5736 40 40 public new IDiscriminantFunctionClassificationModel Model { 41 41 get { return (IDiscriminantFunctionClassificationModel)base.Model; } 42 protected set { base.Model = value; } 42 protected set { 43 if (value != null && value != Model) { 44 if (Model != null) { 45 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 46 } 47 value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 48 base.Model = value; 49 } 50 } 43 51 } 44 52 … … 47 55 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 48 56 : base(original, cloner) { 57 RegisterEventHandler(); 49 58 } 50 public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData , IEnumerable<double> classValues, IEnumerable<double> thresholds)51 : this(new DiscriminantFunctionClassificationModel(model , classValues, thresholds), problemData) {59 public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData) 60 : this(new DiscriminantFunctionClassificationModel(model), problemData) { 52 61 } 53 62 public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 54 63 : base(model, problemData) { 64 RegisterEventHandler(); 65 SetAccuracyMaximizingThresholds(); 66 } 67 68 [StorableHook(HookType.AfterDeserialization)] 69 private void AfterDeserialization() { 70 RegisterEventHandler(); 71 } 72 73 private void RegisterEventHandler() { 74 Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 75 } 76 private void Model_ThresholdsChanged(object sender, EventArgs e) { 77 OnModelThresholdsChanged(e); 78 } 79 80 public void SetAccuracyMaximizingThresholds() { 81 double[] classValues; 82 double[] thresholds; 83 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 84 AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 85 86 Model.SetThresholdsAndClassValues(thresholds, classValues); 87 } 88 89 public void SetClassDistibutionCutPointThresholds() { 90 double[] classValues; 91 double[] thresholds; 92 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 93 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 94 95 Model.SetThresholdsAndClassValues(thresholds, classValues); 96 } 97 98 protected override void OnModelChanged(EventArgs e) { 99 base.OnModelChanged(e); 100 SetAccuracyMaximizingThresholds(); 101 } 102 103 protected override void OnProblemDataChanged(EventArgs e) { 104 base.OnProblemDataChanged(e); 105 SetAccuracyMaximizingThresholds(); 106 } 107 protected virtual void OnModelThresholdsChanged(EventArgs e) { 108 RecalculateResults(); 55 109 } 56 110 -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/AccuracyMaximizationThresholdCalculator.cs
r5730 r5736 69 69 classValues = problemData.ClassValues.OrderBy(x => x).ToArray(); 70 70 int nClasses = classValues.Length; 71 thresholds = new double[nClasses + 1];71 thresholds = new double[nClasses]; 72 72 thresholds[0] = double.NegativeInfinity; 73 thresholds[thresholds.Length - 1] = double.PositiveInfinity;73 // thresholds[thresholds.Length - 1] = double.PositiveInfinity; 74 74 75 75 // incrementally calculate accuracy of all possible thresholds 76 76 int[,] confusionMatrix = new int[nClasses, nClasses]; 77 77 78 for (int i = 1; i < thresholds.Length - 1; i++) {78 for (int i = 1; i < thresholds.Length; i++) { 79 79 double lowerThreshold = thresholds[i - 1]; 80 80 double actualThreshold = Math.Max(lowerThreshold, minEstimatedValue); -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs
r5730 r5736 88 88 thresholdList.Sort(); 89 89 thresholdList.Insert(0, double.NegativeInfinity); 90 thresholdList.Add(double.PositiveInfinity);91 90 92 91 // determine class values for each partition separated by a threshold by calculating the density of all class distributions 93 92 // all points in the partition are classified as the class with the maximal density in the parition 94 93 List<double> classValuesList = new List<double>(); 95 for (int i = 0; i < thresholdList.Count - 1; i++) {94 for (int i = 0; i < thresholdList.Count; i++) { 96 95 double m; 97 96 if (double.IsNegativeInfinity(thresholdList[i])) { 98 97 m = thresholdList[i + 1] - 1.0; // smaller than the smalles non-infinity threshold 99 } else if (double.IsPositiveInfinity(thresholdList[i + 1])) { 100 m = thresholdList[i] + 1.0; // larger than the largest non-infinity threshold 98 } else if (i == thresholdList.Count - 1) { 99 // last threshold 100 m = thresholdList[i] + 1.0; // larger than the last threshold 101 101 } else { 102 102 m = thresholdList[i] + (thresholdList[i + 1] - thresholdList[i]) / 2.0; // middle of partition … … 135 135 } 136 136 } 137 filteredThresholds.Add(double.PositiveInfinity);138 137 thresholds = filteredThresholds.ToArray(); 139 138 classValues = filteredClassValues.ToArray();
Note: See TracChangeset
for help on using the changeset viewer.