- Timestamp:
- 03/14/11 19:00:05 (14 years ago)
- Location:
- branches/DataAnalysis Refactoring
- Files:
-
- 2 added
- 9 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs
r5664 r5678 107 107 addition.AddSubTree(cNode); 108 108 109 var model = new SymbolicDiscriminantFunctionClassificationModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), classValues); 109 110 var model = LinearDiscriminantAnalysis.CreateDiscriminantFunctionModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), problemData, rows); 110 111 SymbolicDiscriminantFunctionClassificationSolution solution = new SymbolicDiscriminantFunctionClassificationSolution(model, problemData); 112 111 113 return solution; 112 114 } 113 115 #endregion 116 117 private static SymbolicDiscriminantFunctionClassificationModel CreateDiscriminantFunctionModel(ISymbolicExpressionTree tree, 118 ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, 119 IClassificationProblemData problemData, 120 IEnumerable<int> rows) { 121 string targetVariable = problemData.TargetVariable; 122 List<double> originalClasses = problemData.ClassValues.ToList(); 123 int nClasses = problemData.Classes; 124 List<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows).ToList(); 125 double maxEstimatedValue = estimatedValues.Max(); 126 double minEstimatedValue = estimatedValues.Min(); 127 var estimatedTargetValues = 128 (from row in problemData.TrainingIndizes 129 select new { EstimatedValue = estimatedValues[row], TargetValue = problemData.Dataset[targetVariable, row] }) 130 .ToList(); 131 132 Dictionary<double, double> classMean = new Dictionary<double, double>(); 133 Dictionary<double, double> classStdDev = new Dictionary<double, double>(); 134 // calculate moments per class 135 foreach (var classValue in originalClasses) { 136 var estimatedValuesForClass = from x in estimatedTargetValues 137 where x.TargetValue == classValue 138 select x.EstimatedValue; 139 double mean, variance; 140 OnlineMeanAndVarianceCalculator.Calculate(estimatedValuesForClass, out mean, out variance); 141 classMean[classValue] = mean; 142 classStdDev[classValue] = Math.Sqrt(variance); 143 } 144 List<double> thresholds = new List<double>(); 145 for (int i = 0; i < nClasses - 1; i++) { 146 for (int j = i + 1; j < nClasses; j++) { 147 double x1, x2; 148 double class0 = originalClasses[i]; 149 double class1 = originalClasses[j]; 150 // calculate all thresholds 151 CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2); 152 if (!thresholds.Any(x => x.IsAlmost(x1))) thresholds.Add(x1); 153 if (!thresholds.Any(x => x.IsAlmost(x2))) thresholds.Add(x2); 154 } 155 } 156 thresholds.Sort(); 157 thresholds.Insert(0, double.NegativeInfinity); 158 thresholds.Add(double.PositiveInfinity); 159 List<double> classValues = new List<double>(); 160 for (int i = 0; i < thresholds.Count - 1; i++) { 161 double m; 162 if (double.IsNegativeInfinity(thresholds[i])) { 163 m = thresholds[i + 1] - 1.0; 164 } else if (double.IsPositiveInfinity(thresholds[i + 1])) { 165 m = thresholds[i] + 1.0; 166 } else { 167 m = thresholds[i] + (thresholds[i + 1] - thresholds[i]) / 2.0; 168 } 169 170 double maxDensity = 0; 171 double maxDensityClassValue = -1; 172 foreach (var classValue in originalClasses) { 173 double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]); 174 if (density > maxDensity) { 175 maxDensity = density; 176 maxDensityClassValue = classValue; 177 } 178 } 179 classValues.Add(maxDensityClassValue); 180 } 181 List<double> filteredThresholds = new List<double>(); 182 List<double> filteredClassValues = new List<double>(); 183 filteredThresholds.Add(thresholds[0]); 184 filteredClassValues.Add(classValues[0]); 185 for (int i = 0; i < classValues.Count - 1; i++) { 186 if (classValues[i] != classValues[i + 1]) { 187 filteredThresholds.Add(thresholds[i + 1]); 188 filteredClassValues.Add(classValues[i + 1]); 189 } 190 } 191 filteredThresholds.Add(double.PositiveInfinity); 192 193 return new SymbolicDiscriminantFunctionClassificationModel(tree, interpreter, filteredClassValues, filteredThresholds); 194 } 195 196 private static double NormalDensity(double x, double mu, double sigma) { 197 return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma)); 198 } 199 200 private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) { 201 double a = (s1 * s1 - s2 * s2); 202 double b = (m1 * s2 * s2 - m2 * s1 * s1); 203 double c = 2 * s1 * s1 * s2 * s2 * Math.Log(s2) - 2 * s1 * s1 * s2 * s2 * Math.Log(s1) - s1 * s1 * m2 * m2 + s2 * s2 * m1 * m1; 204 x1 = -(-m2 * s1 * s1 + m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a; 205 x2 = (m2 * s1 * s1 - m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a; 206 } 114 207 } 115 208 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/MultiObjective/SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer.cs
r5649 r5678 70 70 71 71 protected override ISymbolicClassificationSolution CreateSolution(ISymbolicExpressionTree bestTree, double[] bestQuality) { 72 var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, ProblemData.ClassValues); 72 double[] classValues; 73 double[] thresholds; 74 var estimatedValues = SymbolicDataAnalysisTreeInterpreter.GetSymbolicExpressionTreeValues(bestTree, ProblemData.Dataset, ProblemData.TrainingIndizes); 75 var targetValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 76 DiscriminantFunctionClassificationSolution.CalculateClassThresholds(ProblemData, estimatedValues, targetValues, out classValues, out thresholds); 77 var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, classValues, thresholds); 73 78 return new SymbolicDiscriminantFunctionClassificationSolution(model, ProblemData); 74 } 79 } 75 80 } 76 81 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SingleObjective/SymbolicClassificationSingleObjectiveTrainingBestSolutionAnalyzer.cs
r5649 r5678 68 68 69 69 protected override ISymbolicClassificationSolution CreateSolution(ISymbolicExpressionTree bestTree, double bestQuality) { 70 var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, ProblemData.ClassValues); 70 double[] classValues; 71 double[] thresholds; 72 var estimatedValues = SymbolicDataAnalysisTreeInterpreter.GetSymbolicExpressionTreeValues(bestTree, ProblemData.Dataset, ProblemData.TrainingIndizes); 73 var targetValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 74 DiscriminantFunctionClassificationSolution.CalculateClassThresholds(ProblemData, estimatedValues, targetValues, out classValues, out thresholds); 75 var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, classValues, thresholds); 71 76 return new SymbolicDiscriminantFunctionClassificationSolution(model, ProblemData); 72 77 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationModel.cs
r5657 r5678 39 39 [Item(Name = "SymbolicDiscriminantFunctionClassificationModel", Description = "Represents a symbolic classification model unsing a discriminant function.")] 40 40 public class SymbolicDiscriminantFunctionClassificationModel : SymbolicDataAnalysisModel, ISymbolicDiscriminantFunctionClassificationModel { 41 [Storable]42 private double[] classValues;43 41 44 42 [Storable] … … 51 49 } 52 50 } 53 51 [Storable] 52 private double[] classValues; 53 public IEnumerable<double> ClassValues { 54 get { return (IEnumerable<double>)classValues.Clone(); } 55 set { classValues = value.ToArray(); } 56 } 54 57 [StorableConstructor] 55 58 protected SymbolicDiscriminantFunctionClassificationModel(bool deserializing) : base(deserializing) { } … … 59 62 thresholds = (double[])original.thresholds.Clone(); 60 63 } 61 public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IEnumerable<double> classValues )64 public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IEnumerable<double> classValues, IEnumerable<double> thresholds) 62 65 : base(tree, interpreter) { 63 66 this.classValues = classValues.ToArray(); 64 this.thresholds = new double[0];67 this.thresholds = thresholds.ToArray(); 65 68 } 66 69 … … 76 79 foreach (var x in GetEstimatedValues(dataset, rows)) { 77 80 int classIndex = 0; 78 // find first threshold value which is smaller than x => class index = threshold index + 181 // find first threshold value which is larger than x => class index = threshold index + 1 79 82 for (int i = 0; i < thresholds.Length; i++) { 80 83 if (x > thresholds[i]) classIndex++; … … 91 94 if (listener != null) listener(this, e); 92 95 } 93 #endregion 96 #endregion 94 97 } 95 98 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationSolution.cs
r5657 r5678 41 41 #region ISymbolicClassificationSolution Members 42 42 43 public new ISymbolicClassificationModel Model { 44 get { return (ISymbolicClassificationModel)base.Model; } 43 public new IDiscriminantFunctionClassificationModel Model { 44 get { return (IDiscriminantFunctionClassificationModel)base.Model; } 45 } 46 47 ISymbolicClassificationModel ISymbolicClassificationSolution.Model { 48 get { return (ISymbolicClassificationModel)Model; } 45 49 } 46 50 47 51 ISymbolicDataAnalysisModel ISymbolicDataAnalysisSolution.Model { 48 get { return (ISymbolicDataAnalysisModel) base.Model; }52 get { return (ISymbolicDataAnalysisModel)Model; } 49 53 } 50 54 -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/HeuristicLab.Problems.DataAnalysis.Views-3.4.csproj
r5664 r5678 122 122 <DependentUpon>ClassificationSolutionEstimatedClassValuesView.cs</DependentUpon> 123 123 </Compile> 124 <Compile Include="Classification\DiscriminantFunctionClassificationSolutionEstimatedClassValuesView.cs"> 125 <SubType>UserControl</SubType> 126 </Compile> 127 <Compile Include="Classification\DiscriminantFunctionClassificationSolutionEstimatedClassValuesView.Designer.cs"> 128 <DependentUpon>DiscriminantFunctionClassificationSolutionEstimatedClassValuesView.cs</DependentUpon> 129 </Compile> 124 130 <Compile Include="Classification\DiscriminantFunctionClassificationRocCurvesView.cs"> 125 131 <SubType>UserControl</SubType> -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationModel.cs
r5649 r5678 42 42 [Storable] 43 43 private double[] classValues; 44 45 [StorableConstructor] 46 protected DiscriminantFunctionClassificationModel() : base() { } 47 protected DiscriminantFunctionClassificationModel(DiscriminantFunctionClassificationModel original, Cloner cloner) 48 : base(original, cloner) { 49 model = cloner.Clone(original.model); 50 classValues = (double[])original.classValues.Clone(); 44 // class values are not necessarily sorted in ascending order 45 public IEnumerable<double> ClassValues { 46 get { return (double[])classValues.Clone(); } 47 set { 48 if (value == null) throw new ArgumentException(); 49 double[] newValue = value.ToArray(); 50 if (newValue.Length != classValues.Length) throw new ArgumentException(); 51 classValues = newValue; 52 } 51 53 } 52 public DiscriminantFunctionClassificationModel(IRegressionModel model, IEnumerable<double> classValues) 53 : base() { 54 this.name = ItemName; 55 this.description = ItemDescription; 56 this.model = model; 57 this.classValues = classValues.ToArray(); 58 } 59 60 public override IDeepCloneable Clone(Cloner cloner) { 61 return new DiscriminantFunctionClassificationModel(this, cloner); 62 } 63 64 #region IDiscriminantFunctionClassificationModel Members 65 54 [Storable] 66 55 private double[] thresholds; 67 56 public IEnumerable<double> Thresholds { … … 73 62 } 74 63 75 public event EventHandler ThresholdsChanged; 76 protected virtual void OnThresholdsChanged(EventArgs e) { 77 var listener = ThresholdsChanged; 78 if (listener != null) listener(this, e); 64 65 [StorableConstructor] 66 protected DiscriminantFunctionClassificationModel() : base() { } 67 protected DiscriminantFunctionClassificationModel(DiscriminantFunctionClassificationModel original, Cloner cloner) 68 : base(original, cloner) { 69 model = cloner.Clone(original.model); 70 classValues = (double[])original.classValues.Clone(); 71 thresholds = (double[])original.thresholds.Clone(); 72 } 73 public DiscriminantFunctionClassificationModel(IRegressionModel model, IEnumerable<double> classValues, IEnumerable<double> thresholds) 74 : base() { 75 this.name = ItemName; 76 this.description = ItemDescription; 77 this.model = model; 78 this.classValues = classValues.ToArray(); 79 this.thresholds = thresholds.ToArray(); 80 } 81 82 public override IDeepCloneable Clone(Cloner cloner) { 83 return new DiscriminantFunctionClassificationModel(this, cloner); 79 84 } 80 85 … … 86 91 foreach (var x in GetEstimatedValues(dataset, rows)) { 87 92 int classIndex = 0; 88 // find first threshold value which is smaller than x => class index = threshold index + 193 // find first threshold value which is larger than x => class index = threshold index + 1 89 94 for (int i = 0; i < thresholds.Length; i++) { 90 95 if (x > thresholds[i]) classIndex++; … … 94 99 } 95 100 } 96 101 #region events 102 public event EventHandler ThresholdsChanged; 103 protected virtual void OnThresholdsChanged(EventArgs e) { 104 var listener = ThresholdsChanged; 105 if (listener != null) listener(this, e); 106 } 97 107 #endregion 98 108 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationSolution.cs
r5664 r5678 44 44 } 45 45 public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData) 46 : this(new DiscriminantFunctionClassificationModel(model, problemData.ClassValues ), problemData) {46 : this(new DiscriminantFunctionClassificationModel(model, problemData.ClassValues, CalculateClassThresholds(model, problemData, problemData.TrainingIndizes)), problemData) { 47 47 } 48 48 public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) … … 92 92 #endregion 93 93 94 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 95 if (Model.Thresholds == null || Model.Thresholds.Count() == 0) RecalculateClassIntermediates(); 96 return base.GetEstimatedClassValues(rows); 94 private static double[] CalculateClassThresholds(IRegressionModel model, IClassificationProblemData problemData, IEnumerable<int> rows) { 95 double[] thresholds; 96 double[] classValues; 97 CalculateClassThresholds(problemData, model.GetEstimatedValues(problemData.Dataset, rows), problemData.Dataset.GetEnumeratedVariableValues(problemData.TargetVariable, rows), out classValues, out thresholds); 98 return thresholds; 97 99 } 98 100 99 p rivate void RecalculateClassIntermediates() {101 public static void CalculateClassThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) { 100 102 int slices = 100; 101 List<double> estimatedValues = EstimatedValues.ToList(); 102 List<int> classInstances = (from classValue in ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable) 103 group classValue by classValue into grouping 104 select grouping.Count()).ToList(); 105 double maxEstimatedValue = estimatedValues.Max(); 106 double minEstimatedValue = estimatedValues.Min(); 107 List<KeyValuePair<double, double>> estimatedTargetValues = 108 (from row in ProblemData.TrainingIndizes 109 select new KeyValuePair<double, double>( 110 estimatedValues[row], 111 ProblemData.Dataset[ProblemData.TargetVariable, row])).ToList(); 103 List<double> estimatedValuesList = estimatedValues.ToList(); 104 double maxEstimatedValue = estimatedValuesList.Max(); 105 double minEstimatedValue = estimatedValuesList.Min(); 106 double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices; 107 var estimatedAndTargetValuePairs = 108 estimatedValuesList.Zip(targetClassValues, (x, y) => new { EstimatedValue = x, TargetClassValue = y }) 109 .OrderBy(x => x.EstimatedValue) 110 .ToList(); 112 111 113 List<double> originalClasses = ProblemData.ClassValues.OrderBy(x => x).ToList();114 int nClasses = originalClasses.Distinct().Count();115 double[]thresholds = new double[nClasses + 1];112 classValues = problemData.ClassValues.OrderBy(x => x).ToArray(); 113 int nClasses = classValues.Length; 114 thresholds = new double[nClasses + 1]; 116 115 thresholds[0] = double.NegativeInfinity; 117 116 thresholds[thresholds.Length - 1] = double.PositiveInfinity; 118 117 119 double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices; 118 // incrementally calculate accuracy of all possible thresholds 119 int[,] confusionMatrix = new int[nClasses, nClasses]; 120 120 121 // one threshold is always treated as binary separation of the remaining classes 121 122 for (int i = 1; i < thresholds.Length - 1; i++) { 122 123 double lowerThreshold = thresholds[i - 1]; … … 130 131 double classificationScore = 0.0; 131 132 132 foreach ( KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {133 foreach (var pair in estimatedAndTargetValuePairs) { 133 134 //all positives 134 if ( estimatedTarget.Value.IsAlmost(originalClasses[i - 1])) {135 if ( estimatedTarget.Key > lowerThreshold && estimatedTarget.Key< actualThreshold)135 if (pair.TargetClassValue.IsAlmost(classValues[i - 1])) { 136 if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue < actualThreshold) 136 137 //true positive 137 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i - 1]);138 classificationScore += problemData.GetClassificationPenalty(classValues[i - 1], classValues[i - 1]); 138 139 else 139 140 //false negative 140 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]);141 classificationScore += problemData.GetClassificationPenalty(classValues[i], classValues[i - 1]); 141 142 } 142 143 //all negatives 143 144 else { 144 if ( estimatedTarget.Key > lowerThreshold && estimatedTarget.Key< actualThreshold)145 if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue < actualThreshold) 145 146 //false positive 146 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]);147 classificationScore += problemData.GetClassificationPenalty(classValues[i - 1], classValues[i]); 147 148 else 148 149 //true negative, consider only upper class 149 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i]);150 classificationScore += problemData.GetClassificationPenalty(classValues[i], classValues[i]); 150 151 } 151 152 } … … 167 168 } 168 169 //scale lowest thresholds and highest found optimal threshold according to the misclassification matrix 169 double falseNegativePenalty = ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]);170 double falsePositivePenalty = ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]);170 double falseNegativePenalty = problemData.GetClassificationPenalty(classValues[i], classValues[i - 1]); 171 double falsePositivePenalty = problemData.GetClassificationPenalty(classValues[i - 1], classValues[i]); 171 172 thresholds[i] = (lowestBestThreshold * falsePositivePenalty + highestBestThreshold * falseNegativePenalty) / (falseNegativePenalty + falsePositivePenalty); 172 173 } 173 Thresholds = new List<double>(thresholds);174 174 } 175 175 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs
r5657 r5678 25 25 public interface IDiscriminantFunctionClassificationModel : IClassificationModel { 26 26 IEnumerable<double> Thresholds { get; set; } 27 IEnumerable<double> ClassValues { get; set; } 27 28 event EventHandler ThresholdsChanged; 28 29 IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows);
Note: See TracChangeset
for help on using the changeset viewer.