Changeset 16536 for branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp:
- 01/17/19 14:36:59 (6 years ago)
- Location:
- branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
-
branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis/3.4
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
/branches/2839_HiveProjectManagement/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible /branches/2915-AbsoluteSymbol/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible /branches/2947_ConfigurableIndexedDataTable/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible /branches/2965_CancelablePersistence/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible /stable/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible /trunk/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible /branches/2892_LR-prediction-intervals/HeuristicLab.Problems.DataAnalysis/3.4 15743-16388 /branches/2904_CalculateImpacts/3.4 15808-16421 /branches/2966_interval_calculation/HeuristicLab.Problems.DataAnalysis/3.4 16320-16406 /branches/Async/HeuristicLab.Problems.DataAnalysis/3.4 13329-15286 /branches/Classification-Extensions/HeuristicLab.Problems.DataAnalysis/3.4 11606-11761 /branches/ClassificationModelComparison/HeuristicLab.Problems.DataAnalysis/3.4 9073-13099 /branches/CloningRefactoring/HeuristicLab.Problems.DataAnalysis/3.4 4656-4721 /branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4 5471-5808 /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Problems.DataAnalysis/3.4 5815-6180 /branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.4 4220,4226,4236-4238,4389,4458-4459,4462,4464 /branches/DataAnalysisCSVImport/HeuristicLab.Problems.DataAnalysis/3.4 8713-8875 /branches/DataPreprocessing/HeuristicLab.Problems.DataAnalysis/3.4 10085-11101 /branches/DatasetFeatureCorrelation/HeuristicLab.Problems.DataAnalysis/3.4 8035-8538 /branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4 6284-6795 /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Problems.DataAnalysis/3.4 5060 /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Problems.DataAnalysis/3.4 11570-12508 /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Problems.DataAnalysis/3.4 11130-12721 /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Problems.DataAnalysis/3.4 13819-14091 /branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4 7098-8789 /branches/LogResidualEvaluator/HeuristicLab.Problems.DataAnalysis/3.4 10202-10483 /branches/NET40/sources/HeuristicLab.Problems.DataAnalysis/3.4 5138-5162 /branches/ParallelEngine/HeuristicLab.Problems.DataAnalysis/3.4 5175-5192 /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Problems.DataAnalysis/3.4 7570-7810 /branches/QAPAlgorithms/HeuristicLab.Problems.DataAnalysis/3.4 6350-6627 /branches/Restructure trunk solution/HeuristicLab.Problems.DataAnalysis/3.4 6828 /branches/SimplifierViewsProgress/HeuristicLab.Problems.DataAnalysis/3.4 15318-15370 /branches/SpectralKernelForGaussianProcesses/HeuristicLab.Problems.DataAnalysis/3.4 10204-10479 /branches/Trunk/HeuristicLab.Problems.DataAnalysis/3.4 6829-6865 /branches/histogram/HeuristicLab.Problems.DataAnalysis/3.4 5959-6341 /branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis/3.4 14232-14825
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
-
branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs
r15871 r16536 23 23 24 24 using System; 25 using System.Collections; 25 26 using System.Collections.Generic; 26 27 using System.Linq; … … 36 37 [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")] 37 38 public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem { 39 #region Parameters/Properties 38 40 public enum ReplacementMethodEnum { 39 41 Median, … … 54 56 55 57 private const string ReplacementParameterName = "Replacement Method"; 58 private const string FactorReplacementParameterName = "Factor Replacement Method"; 56 59 private const string DataPartitionParameterName = "DataPartition"; 57 60 58 61 public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter { 59 62 get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; } 63 } 64 public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter { 65 get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; } 60 66 } 61 67 public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter { … … 67 73 set { ReplacementParameter.Value.Value = value; } 68 74 } 75 public FactorReplacementMethodEnum FactorReplacementMethod { 76 get { return FactorReplacementParameter.Value.Value; } 77 set { FactorReplacementParameter.Value.Value = value; } 78 } 69 79 public DataPartitionEnum DataPartition { 70 80 get { return DataPartitionParameter.Value.Value; } 71 81 set { DataPartitionParameter.Value.Value = value; } 72 82 } 73 74 83 #endregion 84 85 #region Ctor/Cloner 75 86 [StorableConstructor] 76 87 private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { } 77 88 private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner) 78 89 : base(original, cloner) { } 90 public ClassificationSolutionVariableImpactsCalculator() 91 : base() { 92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle))); 93 Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best))); 94 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 95 } 96 79 97 public override IDeepCloneable Clone(Cloner cloner) { 80 98 return new ClassificationSolutionVariableImpactsCalculator(this, cloner); 81 99 } 82 83 public ClassificationSolutionVariableImpactsCalculator() 84 : base() { 85 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median))); 86 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 87 } 100 #endregion 88 101 89 102 //mkommend: annoying name clash with static method, open to better naming suggestions 90 103 public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) { 91 return CalculateImpacts(solution, DataPartition, ReplacementMethod);104 return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition); 92 105 } 93 106 94 107 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 95 108 IClassificationSolution solution, 96 DataPartitionEnum data = DataPartitionEnum.Training, 97 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 109 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 110 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 111 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 112 113 IEnumerable<int> rows = GetPartitionRows(dataPartition, solution.ProblemData); 114 IEnumerable<double> estimatedClassValues = solution.GetEstimatedClassValues(rows); 115 var model = (IClassificationModel)solution.Model.Clone(); //mkommend: clone of model is necessary, because the thresholds for IDiscriminantClassificationModels are updated 116 117 return CalculateImpacts(model, solution.ProblemData, estimatedClassValues, rows, replacementMethod, factorReplacementMethod); 118 } 119 120 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 121 IClassificationModel model, 122 IClassificationProblemData problemData, 123 IEnumerable<double> estimatedClassValues, 124 IEnumerable<int> rows, 125 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 126 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 127 128 //fholzing: try and catch in case a different dataset is loaded, otherwise statement is neglectable 129 var missingVariables = model.VariablesUsedForPrediction.Except(problemData.Dataset.VariableNames); 130 if (missingVariables.Any()) { 131 throw new InvalidOperationException(string.Format("Can not calculate variable impacts, because the model uses inputs missing in the dataset ({0})", string.Join(", ", missingVariables))); 132 } 133 IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 134 var originalQuality = CalculateQuality(targetValues, estimatedClassValues); 135 136 var impacts = new Dictionary<string, double>(); 137 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction)); 138 var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable(); 139 140 foreach (var inputVariable in inputvariables) { 141 impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData, modifiableDataset, rows, replacementMethod, factorReplacementMethod, targetValues, originalQuality); 142 } 143 144 return impacts.Select(i => Tuple.Create(i.Key, i.Value)); 145 } 146 147 public static double CalculateImpact(string variableName, 148 IClassificationModel model, 149 IClassificationProblemData problemData, 150 ModifiableDataset modifiableDataset, 151 IEnumerable<int> rows, 152 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 153 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 154 IEnumerable<double> targetValues = null, 155 double quality = double.NaN) { 156 157 if (!model.VariablesUsedForPrediction.Contains(variableName)) { return 0.0; } 158 if (!problemData.Dataset.VariableNames.Contains(variableName)) { 159 throw new InvalidOperationException(string.Format("Can not calculate variable impact, because the model uses inputs missing in the dataset ({0})", variableName)); 160 } 161 162 if (targetValues == null) { 163 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 164 } 165 if (quality == double.NaN) { 166 quality = CalculateQuality(model.GetEstimatedClassValues(modifiableDataset, rows), targetValues); 167 } 168 169 IList originalValues = null; 170 IList replacementValues = GetReplacementValues(modifiableDataset, variableName, model, rows, targetValues, out originalValues, replacementMethod, factorReplacementMethod); 171 172 double newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues); 173 double impact = quality - newValue; 174 175 return impact; 176 } 177 178 private static IList GetReplacementValues(ModifiableDataset modifiableDataset, 179 string variableName, 180 IClassificationModel model, 181 IEnumerable<int> rows, 182 IEnumerable<double> targetValues, 183 out IList originalValues, 184 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 98 185 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 99 186 100 var problemData = solution.ProblemData; 101 var dataset = problemData.Dataset; 102 var model = (IClassificationModel)solution.Model.Clone(); //mkommend: clone of model is necessary, because the thresholds for IDiscriminantClassificationModels are updated 103 104 IEnumerable<int> rows; 105 IEnumerable<double> targetValues; 106 double originalAccuracy; 107 108 OnlineCalculatorError error; 109 110 switch (data) { 111 case DataPartitionEnum.All: 112 rows = problemData.AllIndices; 113 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.AllIndices).ToList(); 114 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedClassValues, out error); 115 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation."); 116 break; 117 case DataPartitionEnum.Training: 118 rows = problemData.TrainingIndices; 119 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList(); 120 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTrainingClassValues, out error); 121 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation."); 122 break; 123 case DataPartitionEnum.Test: 124 rows = problemData.TestIndices; 125 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToList(); 126 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTestClassValues, out error); 127 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation."); 128 break; 129 default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data)); 130 } 131 132 var impacts = new Dictionary<string, double>(); 133 var modifiableDataset = ((Dataset)dataset).ToModifiable(); 134 135 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction)); 136 var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList(); 137 138 // calculate impacts for double variables 139 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) { 140 var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows, replacementMethod); 141 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error); 142 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs."); 143 144 impacts[inputVariable] = originalAccuracy - newAccuracy; 145 } 146 147 // calculate impacts for string variables 148 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) { 149 if (factorReplacementMethod == FactorReplacementMethodEnum.Best) { 150 // try replacing with all possible values and find the best replacement value 151 var smallestImpact = double.PositiveInfinity; 152 foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) { 153 var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows, 154 Enumerable.Repeat(repl, dataset.Rows)); 155 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error); 156 if (error != OnlineCalculatorError.None) 157 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs."); 158 159 var impact = originalAccuracy - newAccuracy; 160 if (impact < smallestImpact) smallestImpact = impact; 161 } 162 impacts[inputVariable] = smallestImpact; 163 } else { 164 // for replacement methods shuffle and mode 165 // calculate impacts for factor variables 166 167 var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows, 168 factorReplacementMethod); 169 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error); 170 if (error != OnlineCalculatorError.None) 171 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs."); 172 173 impacts[inputVariable] = originalAccuracy - newAccuracy; 174 } 175 } // foreach 176 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 177 } 178 179 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) { 180 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 187 IList replacementValues = null; 188 if (modifiableDataset.VariableHasType<double>(variableName)) { 189 originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 190 replacementValues = GetReplacementValuesForDouble(modifiableDataset, rows, (List<double>)originalValues, replacementMethod); 191 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 192 originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 193 replacementValues = GetReplacementValuesForString(model, modifiableDataset, variableName, rows, (List<string>)originalValues, targetValues, factorReplacementMethod); 194 } else { 195 throw new NotSupportedException("Variable not supported"); 196 } 197 198 return replacementValues; 199 } 200 201 private static IList GetReplacementValuesForDouble(ModifiableDataset modifiableDataset, 202 IEnumerable<int> rows, 203 List<double> originalValues, 204 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle) { 205 206 IRandom random = new FastRandom(31415); 207 List<double> replacementValues; 181 208 double replacementValue; 182 List<double> replacementValues; 183 IRandom rand; 184 185 switch (replacement) { 209 210 switch (replacementMethod) { 186 211 case ReplacementMethodEnum.Median: 187 212 replacementValue = rows.Select(r => originalValues[r]).Median(); 188 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();213 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 189 214 break; 190 215 case ReplacementMethodEnum.Average: 191 216 replacementValue = rows.Select(r => originalValues[r]).Average(); 192 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();217 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 193 218 break; 194 219 case ReplacementMethodEnum.Shuffle: 195 220 // new var has same empirical distribution but the relation to y is broken 196 rand = new FastRandom(31415);197 221 // prepare a complete column for the dataset 198 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();222 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 199 223 // shuffle only the selected rows 200 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand ).ToList();224 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 201 225 int i = 0; 202 226 // update column values … … 208 232 var avg = rows.Select(r => originalValues[r]).Average(); 209 233 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 210 rand = new FastRandom(31415);211 234 // prepare a complete column for the dataset 212 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();235 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 213 236 // update column values 214 237 foreach (var r in rows) { 215 replacementValues[r] = NormalDistributedRandom.NextDouble(rand , avg, stdDev);238 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 216 239 } 217 240 break; 218 241 219 242 default: 220 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement)); 221 } 222 223 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues); 224 } 225 226 private static IEnumerable<double> EvaluateModelWithReplacedVariable( 227 IClassificationModel model, string variable, ModifiableDataset dataset, 228 IEnumerable<int> rows, 229 FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) { 230 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 231 List<string> replacementValues; 232 IRandom rand; 233 234 switch (replacement) { 243 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 244 } 245 246 return replacementValues; 247 } 248 249 private static IList GetReplacementValuesForString(IClassificationModel model, 250 ModifiableDataset modifiableDataset, 251 string variableName, 252 IEnumerable<int> rows, 253 List<string> originalValues, 254 IEnumerable<double> targetValues, 255 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Shuffle) { 256 257 List<string> replacementValues = null; 258 IRandom random = new FastRandom(31415); 259 260 switch (factorReplacementMethod) { 261 case FactorReplacementMethodEnum.Best: 262 // try replacing with all possible values and find the best replacement value 263 var bestQuality = double.NegativeInfinity; 264 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 265 List<string> curReplacementValues = Enumerable.Repeat(repl, modifiableDataset.Rows).ToList(); 266 //fholzing: this result could be used later on (theoretically), but is neglected for better readability/method consistency 267 var newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, curReplacementValues, targetValues); 268 var curQuality = newValue; 269 270 if (curQuality > bestQuality) { 271 bestQuality = curQuality; 272 replacementValues = curReplacementValues; 273 } 274 } 275 break; 235 276 case FactorReplacementMethodEnum.Mode: 236 277 var mostCommonValue = rows.Select(r => originalValues[r]) … … 238 279 .OrderByDescending(g => g.Count()) 239 280 .First().Key; 240 replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();281 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 241 282 break; 242 283 case FactorReplacementMethodEnum.Shuffle: 243 284 // new var has same empirical distribution but the relation to y is broken 244 rand = new FastRandom(31415);245 285 // prepare a complete column for the dataset 246 replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();286 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 247 287 // shuffle only the selected rows 248 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand ).ToList();288 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 249 289 int i = 0; 250 290 // update column values … … 254 294 break; 255 295 default: 256 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement)); 257 } 258 259 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues); 260 } 261 262 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 263 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) { 264 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 265 dataset.ReplaceVariable(variable, replacementValues.ToList()); 266 296 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 297 } 298 299 return replacementValues; 300 } 301 302 private static double CalculateQualityForReplacement( 303 IClassificationModel model, 304 ModifiableDataset modifiableDataset, 305 string variableName, 306 IList originalValues, 307 IEnumerable<int> rows, 308 IList replacementValues, 309 IEnumerable<double> targetValues) { 310 311 modifiableDataset.ReplaceVariable(variableName, replacementValues); 267 312 var discModel = model as IDiscriminantFunctionClassificationModel; 268 313 if (discModel != null) { 269 var problemData = new ClassificationProblemData( dataset, dataset.VariableNames, model.TargetVariable);314 var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable); 270 315 discModel.RecalculateModelParameters(problemData, rows); 271 316 } 272 317 273 318 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 274 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 275 dataset.ReplaceVariable(variable, originalValues); 276 277 return estimates; 278 } 279 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 280 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) { 281 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 282 dataset.ReplaceVariable(variable, replacementValues.ToList()); 283 284 285 var discModel = model as IDiscriminantFunctionClassificationModel; 286 if (discModel != null) { 287 var problemData = new ClassificationProblemData(dataset, dataset.VariableNames, model.TargetVariable); 288 discModel.RecalculateModelParameters(problemData, rows); 289 } 290 291 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 292 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 293 dataset.ReplaceVariable(variable, originalValues); 294 295 return estimates; 319 var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList(); 320 var ret = CalculateQuality(targetValues, estimates); 321 modifiableDataset.ReplaceVariable(variableName, originalValues); 322 323 return ret; 324 } 325 326 public static double CalculateQuality(IEnumerable<double> targetValues, IEnumerable<double> estimatedClassValues) { 327 OnlineCalculatorError errorState; 328 var ret = OnlineAccuracyCalculator.Calculate(targetValues, estimatedClassValues, out errorState); 329 if (errorState != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 330 return ret; 331 } 332 333 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) { 334 IEnumerable<int> rows; 335 336 switch (dataPartition) { 337 case DataPartitionEnum.All: 338 rows = problemData.AllIndices; 339 break; 340 case DataPartitionEnum.Test: 341 rows = problemData.TestIndices; 342 break; 343 case DataPartitionEnum.Training: 344 rows = problemData.TrainingIndices; 345 break; 346 default: 347 throw new NotSupportedException("DataPartition not supported"); 348 } 349 350 return rows; 296 351 } 297 352 }
Note: See TracChangeset
for help on using the changeset viewer.