Changeset 8430 for branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp:
- 08/08/12 14:04:17 (12 years ago)
- Location:
- branches/HeuristicLab.TimeSeries
- Files:
-
- 9 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.TimeSeries
- Property svn:ignore
-
old new 20 20 bin 21 21 protoc.exe 22 _ReSharper.HeuristicLab.TimeSeries-3.3
-
- Property svn:ignore
-
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Problems.DataAnalysis merged: 7921,7969,8113,8121,8126,8139,8151-8153,8167,8174,8246,8355
- Property svn:mergeinfo changed
-
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r7268 r8430 37 37 [Creatable("Data Analysis - Ensembles")] 38 38 public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 39 private readonly Dictionary<int, double> trainingEvaluationCache = new Dictionary<int, double>(); 40 private readonly Dictionary<int, double> testEvaluationCache = new Dictionary<int, double>(); 41 39 42 public new IClassificationEnsembleModel Model { 40 43 get { return (IClassificationEnsembleModel)base.Model; } … … 85 88 } 86 89 90 trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count()); 91 testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count()); 92 87 93 classificationSolutions = cloner.Clone(original.classificationSolutions); 88 94 RegisterClassificationSolutionsEventHandler(); … … 128 134 } 129 135 136 trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count()); 137 testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count()); 138 130 139 RegisterClassificationSolutionsEventHandler(); 131 140 classificationSolutions.AddRange(solutions); … … 148 157 public override IEnumerable<double> EstimatedTrainingClassValues { 149 158 get { 150 var rows = ProblemData.TrainingIndizes; 151 var estimatedValuesEnumerators = (from model in Model.Models 152 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 153 .ToList(); 154 var rowsEnumerator = rows.GetEnumerator(); 155 // aggregate to make sure that MoveNext is called for all enumerators 156 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 157 int currentRow = rowsEnumerator.Current; 158 159 var selectedEnumerators = from pair in estimatedValuesEnumerators 160 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 161 select pair.EstimatedValuesEnumerator; 162 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 159 var rows = ProblemData.TrainingIndices; 160 var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys); 161 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 162 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); 163 164 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 165 trainingEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 163 166 } 167 168 return rows.Select(row => trainingEvaluationCache[row]); 164 169 } 165 170 } … … 167 172 public override IEnumerable<double> EstimatedTestClassValues { 168 173 get { 169 var rows = ProblemData.TestIndizes; 170 var estimatedValuesEnumerators = (from model in Model.Models 171 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 172 .ToList(); 173 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 174 // aggregate to make sure that MoveNext is called for all enumerators 175 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 176 int currentRow = rowsEnumerator.Current; 177 178 var selectedEnumerators = from pair in estimatedValuesEnumerators 179 where RowIsTestForModel(currentRow, pair.Model) 180 select pair.EstimatedValuesEnumerator; 181 182 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 174 var rows = ProblemData.TestIndices; 175 var rowsToEvaluate = rows.Except(testEvaluationCache.Keys); 176 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 177 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator(); 178 179 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 180 testEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 183 181 } 182 183 return rows.Select(row => testEvaluationCache[row]); 184 } 185 } 186 187 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IClassificationModel, bool> modelSelectionPredicate) { 188 var estimatedValuesEnumerators = (from model in Model.Models 189 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 190 .ToList(); 191 var rowsEnumerator = rows.GetEnumerator(); 192 // aggregate to make sure that MoveNext is called for all enumerators 193 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 194 int currentRow = rowsEnumerator.Current; 195 196 var selectedEnumerators = from pair in estimatedValuesEnumerators 197 where modelSelectionPredicate(currentRow, pair.Model) 198 select pair.EstimatedValuesEnumerator; 199 200 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 184 201 } 185 202 } … … 196 213 197 214 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 198 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 199 select AggregateEstimatedClassValues(xs); 215 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 216 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 217 var valuesEnumerator = (from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rowsToEvaluate) 218 select AggregateEstimatedClassValues(xs)) 219 .GetEnumerator(); 220 221 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 222 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 223 } 224 225 return rows.Select(row => evaluationCache[row]); 200 226 } 201 227 … … 223 249 224 250 protected override void OnProblemDataChanged() { 251 trainingEvaluationCache.Clear(); 252 testEvaluationCache.Clear(); 253 evaluationCache.Clear(); 254 225 255 IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset, 226 256 ProblemData.AllowedInputVariables, … … 251 281 public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 252 282 classificationSolutions.AddRange(solutions); 283 284 trainingEvaluationCache.Clear(); 285 testEvaluationCache.Clear(); 286 evaluationCache.Clear(); 253 287 } 254 288 public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 255 289 classificationSolutions.RemoveRange(solutions); 290 291 trainingEvaluationCache.Clear(); 292 testEvaluationCache.Clear(); 293 evaluationCache.Clear(); 256 294 } 257 295 … … 275 313 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 276 314 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 315 316 trainingEvaluationCache.Clear(); 317 testEvaluationCache.Clear(); 318 evaluationCache.Clear(); 277 319 } 278 320 … … 282 324 trainingPartitions.Remove(solution.Model); 283 325 testPartitions.Remove(solution.Model); 326 327 trainingEvaluationCache.Clear(); 328 testEvaluationCache.Clear(); 329 evaluationCache.Clear(); 284 330 } 285 331 } -
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs
r7842 r8430 207 207 208 208 #region parameter properties 209 public ConstrainedValueParameter<StringValue> TargetVariableParameter {210 get { return ( ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }209 public IConstrainedValueParameter<StringValue> TargetVariableParameter { 210 get { return (IConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 211 211 } 212 212 public IFixedValueParameter<StringMatrix> ClassNamesParameter { -
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r7268 r8430 44 44 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData) 45 45 : base(model, problemData) { 46 evaluationCache = new Dictionary<int, double>( );46 evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows); 47 47 } 48 48 … … 51 51 } 52 52 public override IEnumerable<double> EstimatedTrainingClassValues { 53 get { return GetEstimatedClassValues(ProblemData.TrainingIndi zes); }53 get { return GetEstimatedClassValues(ProblemData.TrainingIndices); } 54 54 } 55 55 public override IEnumerable<double> EstimatedTestClassValues { 56 get { return GetEstimatedClassValues(ProblemData.TestIndi zes); }56 get { return GetEstimatedClassValues(ProblemData.TestIndices); } 57 57 } 58 58 -
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionBase.cs
r7268 r8430 87 87 protected void CalculateResults() { 88 88 double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values 89 double[] originalTrainingClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndi zes).ToArray();89 double[] originalTrainingClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).ToArray(); 90 90 double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values 91 double[] originalTestClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndi zes).ToArray();91 double[] originalTestClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices).ToArray(); 92 92 93 93 OnlineCalculatorError errorState; -
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r7268 r8430 59 59 } 60 60 public override IEnumerable<double> EstimatedTrainingClassValues { 61 get { return GetEstimatedClassValues(ProblemData.TrainingIndi zes); }61 get { return GetEstimatedClassValues(ProblemData.TrainingIndices); } 62 62 } 63 63 public override IEnumerable<double> EstimatedTestClassValues { 64 get { return GetEstimatedClassValues(ProblemData.TestIndi zes); }64 get { return GetEstimatedClassValues(ProblemData.TestIndices); } 65 65 } 66 66 … … 82 82 } 83 83 public override IEnumerable<double> EstimatedTrainingValues { 84 get { return GetEstimatedValues(ProblemData.TrainingIndi zes); }84 get { return GetEstimatedValues(ProblemData.TrainingIndices); } 85 85 } 86 86 public override IEnumerable<double> EstimatedTestValues { 87 get { return GetEstimatedValues(ProblemData.TestIndi zes); }87 get { return GetEstimatedValues(ProblemData.TestIndices); } 88 88 } 89 89 -
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolutionBase.cs
r7268 r8430 103 103 protected void CalculateRegressionResults() { 104 104 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 105 double[] originalTrainingValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndi zes).ToArray();105 double[] originalTrainingValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).ToArray(); 106 106 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 107 double[] originalTestValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndi zes).ToArray();107 double[] originalTestValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices).ToArray(); 108 108 109 109 OnlineCalculatorError errorState; … … 140 140 double[] classValues; 141 141 double[] thresholds; 142 var targetClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndi zes);142 var targetClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices); 143 143 AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 144 144 … … 149 149 double[] classValues; 150 150 double[] thresholds; 151 var targetClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndi zes);151 var targetClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices); 152 152 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 153 153 -
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/AccuracyMaximizationThresholdCalculator.cs
r7268 r8430 54 54 public static void CalculateThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) { 55 55 int slices = 100; 56 double minThresholdInc = 10e-5; // necessary to prevent infinite loop when maxEstimated - minEstimated is effectively zero (constant model) 56 57 List<double> estimatedValuesList = estimatedValues.ToList(); 57 58 double maxEstimatedValue = estimatedValuesList.Max(); 58 59 double minEstimatedValue = estimatedValuesList.Min(); 59 double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices;60 double thresholdIncrement = Math.Max((maxEstimatedValue - minEstimatedValue) / slices, minThresholdInc); 60 61 var estimatedAndTargetValuePairs = 61 62 estimatedValuesList.Zip(targetClassValues, (x, y) => new { EstimatedValue = x, TargetClassValue = y }) … … 70 71 71 72 // incrementally calculate accuracy of all possible thresholds 72 int[,] confusionMatrix = new int[nClasses, nClasses];73 74 73 for (int i = 1; i < thresholds.Length; i++) { 75 74 double lowerThreshold = thresholds[i - 1];
Note: See TracChangeset
for help on using the changeset viewer.