Changeset 16189
- Timestamp:
- 09/27/18 09:51:35 (6 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs
r16188 r16189 23 23 24 24 using System; 25 using System.Collections; 25 26 using System.Collections.Generic; 26 27 using System.Linq; … … 36 37 [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")] 37 38 public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem { 39 #region Parameters/Properties 38 40 public enum ReplacementMethodEnum { 39 41 Median, … … 54 56 55 57 private const string ReplacementParameterName = "Replacement Method"; 58 private const string FactorReplacementParameterName = "Factor Replacement Method"; 56 59 private const string DataPartitionParameterName = "DataPartition"; 57 60 58 61 public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter { 59 62 get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; } 63 } 64 public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter { 65 get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; } 60 66 } 61 67 public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter { … … 67 73 set { ReplacementParameter.Value.Value = value; } 68 74 } 75 public FactorReplacementMethodEnum FactorReplacementMethod { 76 get { return FactorReplacementParameter.Value.Value; } 77 set { FactorReplacementParameter.Value.Value = value; } 78 } 69 79 public DataPartitionEnum DataPartition { 70 80 get { return DataPartitionParameter.Value.Value; } 71 81 set { DataPartitionParameter.Value.Value = value; } 72 82 } 73 74 83 #endregion 84 85 #region Ctor/Cloner 75 86 [StorableConstructor] 76 87 private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { } 77 88 private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner) 78 89 : base(original, cloner) { } 90 public ClassificationSolutionVariableImpactsCalculator() 91 : base() { 92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle))); 93 Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best))); 94 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 95 } 96 79 97 public override IDeepCloneable Clone(Cloner cloner) { 80 98 return new ClassificationSolutionVariableImpactsCalculator(this, cloner); 81 99 } 82 83 public ClassificationSolutionVariableImpactsCalculator() 84 : base() { 85 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median))); 86 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 87 } 100 #endregion 88 101 89 102 //mkommend: annoying name clash with static method, open to better naming suggestions 90 103 public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) { 91 return CalculateImpacts(solution, DataPartition, ReplacementMethod);104 return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition); 92 105 } 93 106 94 107 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 95 108 IClassificationSolution solution, 96 DataPartitionEnum data = DataPartitionEnum.Training, 97 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 109 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 110 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 111 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 112 113 IEnumerable<int> rows = GetPartitionRows(dataPartition, solution.ProblemData); 114 IEnumerable<double> estimatedClassValues = solution.GetEstimatedClassValues(rows); 115 return CalculateImpacts(solution.Model, solution.ProblemData, estimatedClassValues, rows, replacementMethod, factorReplacementMethod); 116 } 117 118 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 119 IClassificationModel model, 120 IClassificationProblemData problemData, 121 IEnumerable<double> estimatedClassValues, 122 IEnumerable<int> rows, 123 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 124 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 125 126 //fholzing: try and catch in case a different dataset is loaded, otherwise statement is neglectable 127 var missingVariables = model.VariablesUsedForPrediction.Except(problemData.Dataset.VariableNames); 128 if (missingVariables.Any()) { 129 throw new InvalidOperationException(string.Format("Can not calculate variable impacts, because the model uses inputs missing in the dataset ({0})", string.Join(", ", missingVariables))); 130 } 131 IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 132 var originalQuality = CalculateQuality(targetValues, estimatedClassValues); 133 134 var impacts = new Dictionary<string, double>(); 135 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction)); 136 var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable(); 137 138 foreach (var inputVariable in inputvariables) { 139 impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData, modifiableDataset, rows, replacementMethod, factorReplacementMethod, targetValues, originalQuality); 140 } 141 142 return impacts.Select(i => Tuple.Create(i.Key, i.Value)); 143 } 144 145 public static double CalculateImpact(string variableName, 146 IClassificationModel model, 147 IClassificationProblemData problemData, 148 ModifiableDataset modifiableDataset, 149 IEnumerable<int> rows, 150 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 151 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 152 IEnumerable<double> targetValues = null, 153 double quality = double.NaN) { 154 155 if (!model.VariablesUsedForPrediction.Contains(variableName)) { return 0.0; } 156 if (!problemData.Dataset.VariableNames.Contains(variableName)) { 157 throw new InvalidOperationException(string.Format("Can not calculate variable impact, because the model uses inputs missing in the dataset ({0})", variableName)); 158 } 159 160 if (targetValues == null) { 161 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 162 } 163 if (quality == double.NaN) { 164 quality = CalculateQuality(model.GetEstimatedClassValues(modifiableDataset, rows), targetValues); 165 } 166 167 IList originalValues = null; 168 IList replacementValues = GetReplacementValues(modifiableDataset, variableName, model, rows, targetValues, out originalValues, replacementMethod, factorReplacementMethod); 169 170 double newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues); 171 double impact = quality - newValue; 172 173 return impact; 174 } 175 176 private static IList GetReplacementValues(ModifiableDataset modifiableDataset, 177 string variableName, 178 IClassificationModel model, 179 IEnumerable<int> rows, 180 IEnumerable<double> targetValues, 181 out IList originalValues, 182 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 98 183 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 99 184 100 var problemData = solution.ProblemData; 101 var dataset = problemData.Dataset; 102 103 IEnumerable<int> rows; 104 IEnumerable<double> targetValues; 105 double originalAccuracy; 106 107 OnlineCalculatorError error; 108 109 switch (data) { 110 case DataPartitionEnum.All: 111 rows = problemData.AllIndices; 112 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.AllIndices).ToList(); 113 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedClassValues, out error); 114 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation."); 115 break; 116 case DataPartitionEnum.Training: 117 rows = problemData.TrainingIndices; 118 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList(); 119 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTrainingClassValues, out error); 120 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation."); 121 break; 122 case DataPartitionEnum.Test: 123 rows = problemData.TestIndices; 124 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToList(); 125 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTestClassValues, out error); 126 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation."); 127 break; 128 default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data)); 129 } 130 131 var impacts = new Dictionary<string, double>(); 132 var modifiableDataset = ((Dataset)dataset).ToModifiable(); 133 134 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction)); 135 var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList(); 136 137 // calculate impacts for double variables 138 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) { 139 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod); 140 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error); 141 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs."); 142 143 impacts[inputVariable] = originalAccuracy - newAccuracy; 144 } 145 146 // calculate impacts for string variables 147 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) { 148 if (factorReplacementMethod == FactorReplacementMethodEnum.Best) { 149 // try replacing with all possible values and find the best replacement value 150 var smallestImpact = double.PositiveInfinity; 151 foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) { 152 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, 153 Enumerable.Repeat(repl, dataset.Rows)); 154 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error); 155 if (error != OnlineCalculatorError.None) 156 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs."); 157 158 var impact = originalAccuracy - newAccuracy; 159 if (impact < smallestImpact) smallestImpact = impact; 160 } 161 impacts[inputVariable] = smallestImpact; 162 } else { 163 // for replacement methods shuffle and mode 164 // calculate impacts for factor variables 165 166 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, 167 factorReplacementMethod); 168 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error); 169 if (error != OnlineCalculatorError.None) 170 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs."); 171 172 impacts[inputVariable] = originalAccuracy - newAccuracy; 173 } 174 } // foreach 175 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 176 } 177 178 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) { 179 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 185 IList replacementValues = null; 186 if (modifiableDataset.VariableHasType<double>(variableName)) { 187 originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 188 replacementValues = GetReplacementValuesForDouble(modifiableDataset, rows, (List<double>)originalValues, replacementMethod); 189 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 190 originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 191 replacementValues = GetReplacementValuesForString(model, modifiableDataset, variableName, rows, (List<string>)originalValues, targetValues, factorReplacementMethod); 192 } else { 193 throw new NotSupportedException("Variable not supported"); 194 } 195 196 return replacementValues; 197 } 198 199 private static IList GetReplacementValuesForDouble(ModifiableDataset modifiableDataset, 200 IEnumerable<int> rows, 201 List<double> originalValues, 202 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle) { 203 204 IRandom random = new FastRandom(31415); 205 List<double> replacementValues; 180 206 double replacementValue; 181 List<double> replacementValues; 182 IRandom rand; 183 184 switch (replacement) { 207 208 switch (replacementMethod) { 185 209 case ReplacementMethodEnum.Median: 186 210 replacementValue = rows.Select(r => originalValues[r]).Median(); 187 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();211 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 188 212 break; 189 213 case ReplacementMethodEnum.Average: 190 214 replacementValue = rows.Select(r => originalValues[r]).Average(); 191 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();215 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 192 216 break; 193 217 case ReplacementMethodEnum.Shuffle: 194 218 // new var has same empirical distribution but the relation to y is broken 195 rand = new FastRandom(31415);196 219 // prepare a complete column for the dataset 197 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();220 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 198 221 // shuffle only the selected rows 199 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand ).ToList();222 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 200 223 int i = 0; 201 224 // update column values … … 207 230 var avg = rows.Select(r => originalValues[r]).Average(); 208 231 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 209 rand = new FastRandom(31415);210 232 // prepare a complete column for the dataset 211 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();233 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 212 234 // update column values 213 235 foreach (var r in rows) { 214 replacementValues[r] = NormalDistributedRandom.NextDouble(rand , avg, stdDev);236 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 215 237 } 216 238 break; 217 239 218 240 default: 219 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement)); 220 } 221 222 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues); 223 } 224 225 private static IEnumerable<double> EvaluateModelWithReplacedVariable( 226 IClassificationModel model, string variable, ModifiableDataset dataset, 227 IEnumerable<int> rows, 228 FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) { 229 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 230 List<string> replacementValues; 231 IRandom rand; 232 233 switch (replacement) { 241 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 242 } 243 244 return replacementValues; 245 } 246 247 private static IList GetReplacementValuesForString(IClassificationModel model, 248 ModifiableDataset modifiableDataset, 249 string variableName, 250 IEnumerable<int> rows, 251 List<string> originalValues, 252 IEnumerable<double> targetValues, 253 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Shuffle) { 254 255 List<string> replacementValues = null; 256 IRandom random = new FastRandom(31415); 257 258 switch (factorReplacementMethod) { 259 case FactorReplacementMethodEnum.Best: 260 // try replacing with all possible values and find the best replacement value 261 var bestQuality = double.NegativeInfinity; 262 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 263 List<string> curReplacementValues = Enumerable.Repeat(repl, modifiableDataset.Rows).ToList(); 264 //fholzing: this result could be used later on (theoretically), but is neglected for better readability/method consistency 265 var newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, curReplacementValues, targetValues); 266 var curQuality = newValue; 267 268 if (curQuality > bestQuality) { 269 bestQuality = curQuality; 270 replacementValues = curReplacementValues; 271 } 272 } 273 break; 234 274 case FactorReplacementMethodEnum.Mode: 235 275 var mostCommonValue = rows.Select(r => originalValues[r]) … … 237 277 .OrderByDescending(g => g.Count()) 238 278 .First().Key; 239 replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();279 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 240 280 break; 241 281 case FactorReplacementMethodEnum.Shuffle: 242 282 // new var has same empirical distribution but the relation to y is broken 243 rand = new FastRandom(31415);244 283 // prepare a complete column for the dataset 245 replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();284 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 246 285 // shuffle only the selected rows 247 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand ).ToList();286 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 248 287 int i = 0; 249 288 // update column values … … 253 292 break; 254 293 default: 255 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement)); 256 } 257 258 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues); 259 } 260 261 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 262 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) { 263 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 264 dataset.ReplaceVariable(variable, replacementValues.ToList()); 294 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 295 } 296 297 return replacementValues; 298 } 299 300 private static double CalculateQualityForReplacement( 301 IClassificationModel model, 302 ModifiableDataset modifiableDataset, 303 string variableName, 304 IList originalValues, 305 IEnumerable<int> rows, 306 IList replacementValues, 307 IEnumerable<double> targetValues) { 308 309 modifiableDataset.ReplaceVariable(variableName, replacementValues); 310 var discModel = model as IDiscriminantFunctionClassificationModel; 311 if (discModel != null) { 312 var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable); 313 discModel.RecalculateModelParameters(problemData, rows); 314 } 315 265 316 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 266 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 267 dataset.ReplaceVariable(variable, originalValues); 268 269 return estimates; 270 } 271 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 272 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) { 273 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 274 dataset.ReplaceVariable(variable, replacementValues.ToList()); 275 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 276 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 277 dataset.ReplaceVariable(variable, originalValues); 278 279 return estimates; 317 var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList(); 318 var ret = CalculateQuality(targetValues, estimates); 319 modifiableDataset.ReplaceVariable(variableName, originalValues); 320 321 return ret; 322 } 323 324 public static double CalculateQuality(IEnumerable<double> targetValues, IEnumerable<double> estimatedClassValues) { 325 OnlineCalculatorError errorState; 326 var ret = OnlineAccuracyCalculator.Calculate(targetValues, estimatedClassValues, out errorState); 327 if (errorState != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 328 return ret; 329 } 330 331 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) { 332 IEnumerable<int> rows; 333 334 switch (dataPartition) { 335 case DataPartitionEnum.All: 336 rows = problemData.AllIndices; 337 break; 338 case DataPartitionEnum.Test: 339 rows = problemData.TestIndices; 340 break; 341 case DataPartitionEnum.Training: 342 rows = problemData.TrainingIndices; 343 break; 344 default: 345 throw new NotSupportedException("DataPartition not supported"); 346 } 347 348 return rows; 280 349 } 281 350 }
Note: See TracChangeset
for help on using the changeset viewer.