Changeset 14242 for branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers
- Timestamp:
- 08/08/16 11:39:03 (8 years ago)
- Location:
- branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers
- Files:
-
- 2 added
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneR.cs
r14185 r14242 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; … … 64 65 65 66 public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6) { 67 var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 68 var model1 = FindBestDoubleVariableModel(problemData, minBucketSize); 69 var model2 = FindBestFactorModel(problemData); 70 71 if (model1 == null && model2 == null) throw new InvalidProgramException("Could not create OneR solution"); 72 else if (model1 == null) return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone()); 73 else if (model2 == null) return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone()); 74 else { 75 var model1EstimatedValues = model1.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices); 76 var model1NumCorrect = classValues.Zip(model1EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e); 77 78 var model2EstimatedValues = model2.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices); 79 var model2NumCorrect = classValues.Zip(model2EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e); 80 81 if (model1NumCorrect > model2NumCorrect) { 82 return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone()); 83 } else { 84 return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone()); 85 } 86 } 87 } 88 89 private static OneRClassificationModel FindBestDoubleVariableModel(IClassificationProblemData problemData, int minBucketSize = 6) { 66 90 var bestClassified = 0; 67 91 List<Split> bestSplits = null; … … 70 94 var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 71 95 72 foreach (var variable in problemData.AllowedInputVariables) { 96 var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<double>); 97 98 if (!allowedInputVariables.Any()) return null; 99 100 foreach (var variable in allowedInputVariables) { 73 101 var inputValues = problemData.Dataset.GetDoubleValues(variable, problemData.TrainingIndices); 74 102 var samples = inputValues.Zip(classValues, (i, v) => new Sample(i, v)).OrderBy(s => s.inputValue); 75 103 76 var missingValuesDistribution = samples.Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue).ToDictionary(s => s.Key, s => s.Count()).MaxItems(s => s.Value).FirstOrDefault(); 104 var missingValuesDistribution = samples 105 .Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue) 106 .ToDictionary(s => s.Key, s => s.Count()) 107 .MaxItems(s => s.Value) 108 .FirstOrDefault(); 77 109 78 110 //calculate class distributions for all distinct inputValues … … 119 151 while (sample.inputValue >= splits[splitIndex].thresholdValue) 120 152 splitIndex++; 121 correctClassified += sample.classValue == splits[splitIndex].classValue? 1 : 0;153 correctClassified += sample.classValue.IsAlmost(splits[splitIndex].classValue) ? 1 : 0; 122 154 } 123 155 correctClassified += missingValuesDistribution.Value; … … 133 165 //remove neighboring splits with the same class value 134 166 for (int i = 0; i < bestSplits.Count - 1; i++) { 135 if (bestSplits[i].classValue == bestSplits[i + 1].classValue) {167 if (bestSplits[i].classValue.IsAlmost(bestSplits[i + 1].classValue)) { 136 168 bestSplits.Remove(bestSplits[i]); 137 169 i--; … … 139 171 } 140 172 141 var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable, bestSplits.Select(s => s.thresholdValue).ToArray(), bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass); 142 var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone()); 143 144 return solution; 173 var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable, 174 bestSplits.Select(s => s.thresholdValue).ToArray(), 175 bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass); 176 177 return model; 178 } 179 private static OneFactorClassificationModel FindBestFactorModel(IClassificationProblemData problemData) { 180 var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 181 var defaultClass = FindMostFrequentClassValue(classValues); 182 // only select string variables 183 var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<string>); 184 185 if (!allowedInputVariables.Any()) return null; 186 187 OneFactorClassificationModel bestModel = null; 188 var bestModelNumCorrect = 0; 189 190 foreach (var variable in allowedInputVariables) { 191 var variableValues = problemData.Dataset.GetStringValues(variable, problemData.TrainingIndices); 192 var groupedClassValues = variableValues 193 .Zip(classValues, (v, c) => new KeyValuePair<string, double>(v, c)) 194 .GroupBy(kvp => kvp.Key) 195 .ToDictionary(g => g.Key, g => FindMostFrequentClassValue(g.Select(kvp => kvp.Value))); 196 197 var model = new OneFactorClassificationModel(problemData.TargetVariable, variable, 198 groupedClassValues.Select(kvp => kvp.Key).ToArray(), groupedClassValues.Select(kvp => kvp.Value).ToArray(), defaultClass); 199 200 var modelEstimatedValues = model.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices); 201 var modelNumCorrect = classValues.Zip(modelEstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e); 202 if (modelNumCorrect > bestModelNumCorrect) { 203 bestModelNumCorrect = modelNumCorrect; 204 bestModel = model; 205 } 206 } 207 208 return bestModel; 209 } 210 211 private static double FindMostFrequentClassValue(IEnumerable<double> classValues) { 212 return classValues.GroupBy(c => c).OrderByDescending(g => g.Count()).Select(g => g.Key).First(); 145 213 } 146 214 -
branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneRClassificationModel.cs
r14185 r14242 67 67 this.splits = (double[])original.splits.Clone(); 68 68 this.classes = (double[])original.classes.Clone(); 69 this.missingValuesClass = original.missingValuesClass; 69 70 } 70 71 public override IDeepCloneable Clone(Cloner cloner) { return new OneRClassificationModel(this, cloner); }
Note: See TracChangeset
for help on using the changeset viewer.