- Timestamp:
- 08/06/18 17:35:11 (6 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs
r16041 r16055 90 90 public ClassificationSolutionVariableImpactsCalculator() 91 91 : base() { 92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum. Median)));92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle))); 93 93 Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best))); 94 94 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); … … 110 110 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 111 111 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 112 return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition); 113 } 114 115 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 116 IClassificationModel model, 117 IClassificationProblemData problemData, 118 IEnumerable<double> estimatedValues, 119 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 120 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 121 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 122 IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData); 123 return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod); 112 113 IEnumerable<int> rows = GetPartitionRows(dataPartition, solution.ProblemData); 114 IEnumerable<double> estimatedClassValues = solution.GetEstimatedClassValues(rows); 115 return CalculateImpacts(solution.Model, solution.ProblemData, estimatedClassValues, rows, replacementMethod, factorReplacementMethod); 124 116 } 125 117 … … 131 123 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 132 124 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 133 //Calculate original quality-values (via calculator, default is Accuracy) 134 OnlineCalculatorError error; 135 IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 136 IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedClassValues.ElementAt(v)); 137 var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error); 138 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation."); 125 126 //fholzing: try and catch in case a different dataset is loaded, otherwise statement is neglectable 127 var missingVariables = model.VariablesUsedForPrediction.Except(problemData.Dataset.VariableNames); 128 if (missingVariables.Any()) { 129 throw new InvalidOperationException(string.Format("Can not calculate variable impacts, because the model uses inputs missing in the dataset ({0})", string.Join(", ", missingVariables))); 130 } 131 IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 132 var originalQuality = CalculateQuality(targetValues, estimatedClassValues); 139 133 140 134 var impacts = new Dictionary<string, double>(); 141 135 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction)); 142 var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();143 136 var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable(); 144 137 145 foreach (var inputVariable in allowedInputVariables) { 146 if (model.VariablesUsedForPrediction.Contains(inputVariable)) { 147 impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod); 148 } else { 149 impacts[inputVariable] = 0; 150 } 151 } 152 153 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 138 foreach (var inputVariable in inputvariables) { 139 impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData, modifiableDataset, rows, replacementMethod, factorReplacementMethod, targetValues, originalQuality); 140 } 141 142 return impacts.Select(i => Tuple.Create(i.Key, i.Value)); 154 143 } 155 144 156 145 public static double CalculateImpact(string variableName, 157 146 IClassificationModel model, 147 IClassificationProblemData problemData, 158 148 ModifiableDataset modifiableDataset, 159 149 IEnumerable<int> rows, 150 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 151 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 152 IEnumerable<double> targetValues = null, 153 double quality = double.NaN) { 154 155 if (!model.VariablesUsedForPrediction.Contains(variableName)) { return 0.0; } 156 if (!problemData.Dataset.VariableNames.Contains(variableName)) { 157 throw new InvalidOperationException(string.Format("Can not calculate variable impact, because the model uses inputs missing in the dataset ({0})", variableName)); 158 } 159 160 if (targetValues == null) { 161 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 162 } 163 if (quality == double.NaN) { 164 quality = CalculateQuality(model.GetEstimatedClassValues(modifiableDataset, rows), targetValues); 165 } 166 167 IList originalValues = null; 168 IList replacementValues = GetReplacementValues(modifiableDataset, variableName, model, rows, targetValues, out originalValues, replacementMethod, factorReplacementMethod); 169 170 double newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues); 171 double impact = quality - newValue; 172 173 return impact; 174 } 175 176 private static IList GetReplacementValues(ModifiableDataset modifiableDataset, 177 string variableName, 178 IClassificationModel model, 179 IEnumerable<int> rows, 160 180 IEnumerable<double> targetValues, 161 double originalValue,181 out IList originalValues, 162 182 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 163 183 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 164 double impact = 0; 165 OnlineCalculatorError error; 166 IRandom random; 167 double replacementValue; 168 IEnumerable<double> newEstimates = null; 169 double newValue = 0; 170 184 185 IList replacementValues = null; 171 186 if (modifiableDataset.VariableHasType<double>(variableName)) { 172 #region NumericalVariable 173 var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 174 List<double> replacementValues; 175 176 switch (replacementMethod) { 177 case ReplacementMethodEnum.Median: 178 replacementValue = rows.Select(r => originalValues[r]).Median(); 179 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 180 break; 181 case ReplacementMethodEnum.Average: 182 replacementValue = rows.Select(r => originalValues[r]).Average(); 183 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 184 break; 185 case ReplacementMethodEnum.Shuffle: 186 // new var has same empirical distribution but the relation to y is broken 187 random = new FastRandom(31415); 188 // prepare a complete column for the dataset 189 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 190 // shuffle only the selected rows 191 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 192 int i = 0; 193 // update column values 194 foreach (var r in rows) { 195 replacementValues[r] = shuffledValues[i++]; 196 } 197 break; 198 case ReplacementMethodEnum.Noise: 199 var avg = rows.Select(r => originalValues[r]).Average(); 200 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 201 random = new FastRandom(31415); 202 // prepare a complete column for the dataset 203 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 204 // update column values 205 foreach (var r in rows) { 206 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 207 } 208 break; 209 210 default: 211 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 212 } 213 214 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 215 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 216 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 217 218 impact = originalValue - newValue; 219 #endregion 187 originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 188 replacementValues = GetReplacementValuesForDouble(modifiableDataset, rows, (List<double>)originalValues, replacementMethod); 220 189 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 221 #region FactorVariable 222 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 223 List<string> replacementValues; 224 225 switch (factorReplacementMethod) { 226 case FactorReplacementMethodEnum.Best: 227 // try replacing with all possible values and find the best replacement value 228 var smallestImpact = double.PositiveInfinity; 229 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 230 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 231 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 232 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 233 234 var curImpact = originalValue - newValue; 235 if (curImpact < smallestImpact) smallestImpact = curImpact; 236 } 237 impact = smallestImpact; 238 break; 239 case FactorReplacementMethodEnum.Mode: 240 var mostCommonValue = rows.Select(r => originalValues[r]) 241 .GroupBy(v => v) 242 .OrderByDescending(g => g.Count()) 243 .First().Key; 244 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 245 246 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 247 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 248 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 249 250 impact = originalValue - newValue; 251 break; 252 case FactorReplacementMethodEnum.Shuffle: 253 // new var has same empirical distribution but the relation to y is broken 254 random = new FastRandom(31415); 255 // prepare a complete column for the dataset 256 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 257 // shuffle only the selected rows 258 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 259 int i = 0; 260 // update column values 261 foreach (var r in rows) { 262 replacementValues[r] = shuffledValues[i++]; 263 } 264 265 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 266 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 267 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 268 269 impact = originalValue - newValue; 270 break; 271 default: 272 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 273 } 274 #endregion 190 originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 191 replacementValues = GetReplacementValuesForString(model, modifiableDataset, variableName, rows, originalValues, targetValues, factorReplacementMethod); 275 192 } else { 276 193 throw new NotSupportedException("Variable not supported"); 277 194 } 278 195 279 return impact; 280 } 281 282 /// <summary> 283 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values 284 /// and changes the value of the model-variables back to the original ones. 285 /// </summary> 286 /// <param name="originalValues"></param> 287 /// <param name="model"></param> 288 /// <param name="variableName"></param> 289 /// <param name="modifiableDataset"></param> 290 /// <param name="rows"></param> 291 /// <param name="replacementValues"></param> 292 /// <returns></returns> 293 private static IEnumerable<double> GetReplacedEstimates( 196 return replacementValues; 197 } 198 199 private static IList GetReplacementValuesForDouble(ModifiableDataset modifiableDataset, 200 IEnumerable<int> rows, 201 List<double> originalValues, 202 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle) { 203 204 IRandom random = new FastRandom(31415); 205 List<double> replacementValues; 206 double replacementValue; 207 208 switch (replacementMethod) { 209 case ReplacementMethodEnum.Median: 210 replacementValue = rows.Select(r => originalValues[r]).Median(); 211 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 212 break; 213 case ReplacementMethodEnum.Average: 214 replacementValue = rows.Select(r => originalValues[r]).Average(); 215 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 216 break; 217 case ReplacementMethodEnum.Shuffle: 218 // new var has same empirical distribution but the relation to y is broken 219 // prepare a complete column for the dataset 220 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 221 // shuffle only the selected rows 222 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 223 int i = 0; 224 // update column values 225 foreach (var r in rows) { 226 replacementValues[r] = shuffledValues[i++]; 227 } 228 break; 229 case ReplacementMethodEnum.Noise: 230 var avg = rows.Select(r => originalValues[r]).Average(); 231 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 232 // prepare a complete column for the dataset 233 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 234 // update column values 235 foreach (var r in rows) { 236 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 237 } 238 break; 239 240 default: 241 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 242 } 243 244 return replacementValues; 245 } 246 247 private static IList GetReplacementValuesForString(IClassificationModel model, 248 ModifiableDataset modifiableDataset, 249 string variableName, 250 IEnumerable<int> rows, 294 251 IList originalValues, 252 IEnumerable<double> targetValues, 253 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Shuffle) { 254 255 IList replacementValues = null; 256 IRandom random = new FastRandom(31415); 257 258 switch (factorReplacementMethod) { 259 case FactorReplacementMethodEnum.Best: 260 // try replacing with all possible values and find the best replacement value 261 var bestQuality = double.NegativeInfinity; 262 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 263 List<string> curReplacementValues = Enumerable.Repeat(repl, modifiableDataset.Rows).ToList(); 264 //fholzing: this result could be used later on (theoretically), but is neglected for better readability/method consistency 265 var newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, curReplacementValues, targetValues); 266 var curQuality = newValue; 267 268 if (curQuality > bestQuality) { 269 bestQuality = curQuality; 270 replacementValues = curReplacementValues; 271 } 272 } 273 break; 274 case FactorReplacementMethodEnum.Mode: 275 var mostCommonValue = rows.Select(r => originalValues[r]) 276 .GroupBy(v => v) 277 .OrderByDescending(g => g.Count()) 278 .First().Key; 279 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 280 break; 281 case FactorReplacementMethodEnum.Shuffle: 282 // new var has same empirical distribution but the relation to y is broken 283 // prepare a complete column for the dataset 284 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 285 // shuffle only the selected rows 286 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 287 int i = 0; 288 // update column values 289 foreach (var r in rows) { 290 replacementValues[r] = shuffledValues[i++]; 291 } 292 break; 293 default: 294 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 295 } 296 297 return replacementValues; 298 } 299 300 private static double CalculateQualityForReplacement( 295 301 IClassificationModel model, 302 ModifiableDataset modifiableDataset, 296 303 string variableName, 297 ModifiableDataset modifiableDataset, 298 IEnumerable<int> rows, 299 IList replacementValues) { 304 IList originalValues, 305 IEnumerable<int> rows, 306 IList replacementValues, 307 IEnumerable<double> targetValues) { 308 300 309 modifiableDataset.ReplaceVariable(variableName, replacementValues); 301 302 310 var discModel = model as IDiscriminantFunctionClassificationModel; 303 311 if (discModel != null) { … … 308 316 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 309 317 var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList(); 318 var ret = CalculateQuality(targetValues, estimates); 310 319 modifiableDataset.ReplaceVariable(variableName, originalValues); 311 320 312 return estimates; 313 } 314 315 /// <summary> 316 /// Calculates and returns the VariableImpact (calculated via Accuracy). 317 /// </summary> 318 /// <param name="targetValues">The actual values</param> 319 /// <param name="estimatedValues">The calculated/replaced values</param> 320 /// <param name="errorState"></param> 321 /// <returns></returns> 322 public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 323 //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality 324 //as the code below does. But this way we can easily swap the calculator later on, so the user 325 //could choose a Calculator during runtime in future versions. 326 IOnlineCalculator calculator = new OnlineAccuracyCalculator(); 327 IEnumerator<double> firstEnumerator = targetValues.GetEnumerator(); 328 IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator(); 329 330 // always move forward both enumerators (do not use short-circuit evaluation!) 331 while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) { 332 double original = firstEnumerator.Current; 333 double estimated = secondEnumerator.Current; 334 calculator.Add(original, estimated); 335 if (calculator.ErrorState != OnlineCalculatorError.None) break; 336 } 337 338 // check if both enumerators are at the end to make sure both enumerations have the same length 339 if (calculator.ErrorState == OnlineCalculatorError.None && 340 (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) { 341 throw new ArgumentException("Number of elements in first and second enumeration doesn't match."); 342 } else { 343 errorState = calculator.ErrorState; 344 return calculator.Value; 345 } 346 } 347 348 /// <summary> 349 /// Returns a collection of the row-indices for a given DataPartition (training or test) 350 /// </summary> 351 /// <param name="dataPartition"></param> 352 /// <param name="problemData"></param> 353 /// <returns></returns> 321 return ret; 322 } 323 324 public static double CalculateQuality(IEnumerable<double> targetValues, IEnumerable<double> estimatedClassValues) { 325 OnlineCalculatorError errorState; 326 var ret = OnlineAccuracyCalculator.Calculate(targetValues, estimatedClassValues, out errorState); 327 if (errorState != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 328 return ret; 329 } 330 354 331 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) { 355 332 IEnumerable<int> rows;
Note: See TracChangeset
for help on using the changeset viewer.