Changeset 16036 for branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs
- Timestamp:
- 08/01/18 14:01:08 (3 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs
r16035 r16036 100 100 #endregion 101 101 102 #region Public Methods/Wrappers103 102 //mkommend: annoying name clash with static method, open to better naming suggestions 104 103 public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) { … … 159 158 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 160 159 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 161 162 160 double impact = 0; 163 164 // calculate impacts for double variables 161 OnlineCalculatorError error; 162 IRandom random; 163 double replacementValue; 164 IEnumerable<double> newEstimates = null; 165 double newValue = 0; 166 165 167 if (modifiableDataset.VariableHasType<double>(variableName)) { 166 impact = CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod); 168 #region NumericalVariable 169 var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 170 List<double> replacementValues; 171 172 switch (replacementMethod) { 173 case ReplacementMethodEnum.Median: 174 replacementValue = rows.Select(r => originalValues[r]).Median(); 175 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 176 break; 177 case ReplacementMethodEnum.Average: 178 replacementValue = rows.Select(r => originalValues[r]).Average(); 179 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 180 break; 181 case ReplacementMethodEnum.Shuffle: 182 // new var has same empirical distribution but the relation to y is broken 183 random = new FastRandom(31415); 184 // prepare a complete column for the dataset 185 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 186 // shuffle only the selected rows 187 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 188 int i = 0; 189 // update column values 190 foreach (var r in rows) { 191 replacementValues[r] = shuffledValues[i++]; 192 } 193 break; 194 case ReplacementMethodEnum.Noise: 195 var avg = rows.Select(r => originalValues[r]).Average(); 196 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 197 random = new FastRandom(31415); 198 // prepare a complete column for the dataset 199 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 200 // update column values 201 foreach (var r in rows) { 202 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 203 } 204 break; 205 206 default: 207 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 208 } 209 210 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 211 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 212 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 213 214 impact = originalValue - newValue; 215 #endregion 167 216 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 168 impact = CalculateImpactForFactorVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod); 217 #region FactorVariable 218 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 219 List<string> replacementValues; 220 221 switch (factorReplacementMethod) { 222 case FactorReplacementMethodEnum.Best: 223 // try replacing with all possible values and find the best replacement value 224 var smallestImpact = double.PositiveInfinity; 225 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 226 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 227 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 228 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 229 230 var curImpact = originalValue - newValue; 231 if (curImpact < smallestImpact) smallestImpact = curImpact; 232 } 233 impact = smallestImpact; 234 break; 235 case FactorReplacementMethodEnum.Mode: 236 var mostCommonValue = rows.Select(r => originalValues[r]) 237 .GroupBy(v => v) 238 .OrderByDescending(g => g.Count()) 239 .First().Key; 240 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 241 242 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 243 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 244 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 245 246 impact = originalValue - newValue; 247 break; 248 case FactorReplacementMethodEnum.Shuffle: 249 // new var has same empirical distribution but the relation to y is broken 250 random = new FastRandom(31415); 251 // prepare a complete column for the dataset 252 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 253 // shuffle only the selected rows 254 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 255 int i = 0; 256 // update column values 257 foreach (var r in rows) { 258 replacementValues[r] = shuffledValues[i++]; 259 } 260 261 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 262 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 263 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 264 265 impact = originalValue - newValue; 266 break; 267 default: 268 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 269 } 270 #endregion 169 271 } else { 170 272 throw new NotSupportedException("Variable not supported"); 171 273 } 274 172 275 return impact; 173 276 } 174 #endregion 175 176 private static double CalculateImpactForNumericalVariables(string variableName, 177 IRegressionModel model, 178 ModifiableDataset modifiableDataset, 179 IEnumerable<int> rows, 180 IEnumerable<double> targetValues, 181 double originalValue, 182 ReplacementMethodEnum replacementMethod) { 183 OnlineCalculatorError error; 184 var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod); 185 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 186 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 187 return originalValue - newValue; 188 } 189 190 private static double CalculateImpactForFactorVariables(string variableName, 191 IRegressionModel model, 192 ModifiableDataset modifiableDataset, 193 IEnumerable<int> rows, 194 IEnumerable<double> targetValues, 195 double originalValue, 196 FactorReplacementMethodEnum factorReplacementMethod) { 197 198 OnlineCalculatorError error; 199 if (factorReplacementMethod == FactorReplacementMethodEnum.Best) { 200 // try replacing with all possible values and find the best replacement value 201 var smallestImpact = double.PositiveInfinity; 202 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 203 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 204 var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 205 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 206 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 207 208 var curImpact = originalValue - newValue; 209 if (curImpact < smallestImpact) smallestImpact = curImpact; 210 } 211 return smallestImpact; 212 } else { 213 // for replacement methods shuffle and mode 214 // calculate impacts for factor variables 215 var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod); 216 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 217 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 218 219 return originalValue - newValue; 220 } 221 } 222 223 private static IEnumerable<double> GetReplacedValuesForNumericalVariables( 224 IRegressionModel model, 225 string variable, 226 ModifiableDataset dataset, 227 IEnumerable<int> rows, 228 ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) { 229 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 230 double replacementValue; 231 List<double> replacementValues; 232 IRandom rand; 233 234 switch (replacement) { 235 case ReplacementMethodEnum.Median: 236 replacementValue = rows.Select(r => originalValues[r]).Median(); 237 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList(); 238 break; 239 case ReplacementMethodEnum.Average: 240 replacementValue = rows.Select(r => originalValues[r]).Average(); 241 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList(); 242 break; 243 case ReplacementMethodEnum.Shuffle: 244 // new var has same empirical distribution but the relation to y is broken 245 rand = new FastRandom(31415); 246 // prepare a complete column for the dataset 247 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList(); 248 // shuffle only the selected rows 249 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 250 int i = 0; 251 // update column values 252 foreach (var r in rows) { 253 replacementValues[r] = shuffledValues[i++]; 254 } 255 break; 256 case ReplacementMethodEnum.Noise: 257 var avg = rows.Select(r => originalValues[r]).Average(); 258 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 259 rand = new FastRandom(31415); 260 // prepare a complete column for the dataset 261 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList(); 262 // update column values 263 foreach (var r in rows) { 264 replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev); 265 } 266 break; 267 268 default: 269 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement)); 270 } 271 272 return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues); 273 } 274 275 private static IEnumerable<double> GetReplacedValuesForFactorVariables( 276 IRegressionModel model, 277 string variable, 278 ModifiableDataset dataset, 279 IEnumerable<int> rows, 280 FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) { 281 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 282 List<string> replacementValues; 283 IRandom rand; 284 285 switch (replacement) { 286 case FactorReplacementMethodEnum.Mode: 287 var mostCommonValue = rows.Select(r => originalValues[r]) 288 .GroupBy(v => v) 289 .OrderByDescending(g => g.Count()) 290 .First().Key; 291 replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList(); 292 break; 293 case FactorReplacementMethodEnum.Shuffle: 294 // new var has same empirical distribution but the relation to y is broken 295 rand = new FastRandom(31415); 296 // prepare a complete column for the dataset 297 replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList(); 298 // shuffle only the selected rows 299 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 300 int i = 0; 301 // update column values 302 foreach (var r in rows) { 303 replacementValues[r] = shuffledValues[i++]; 304 } 305 break; 306 default: 307 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement)); 308 } 309 310 return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues); 311 } 312 277 278 /// <summary> 279 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values 280 /// and changes the value of the model-variables back to the original ones. 281 /// </summary> 282 /// <param name="originalValues"></param> 283 /// <param name="model"></param> 284 /// <param name="variableName"></param> 285 /// <param name="modifiableDataset"></param> 286 /// <param name="rows"></param> 287 /// <param name="replacementValues"></param> 288 /// <returns></returns> 313 289 private static IEnumerable<double> GetReplacedEstimates( 314 290 IList originalValues, 315 291 IRegressionModel model, 316 string variable ,317 ModifiableDataset dataset,292 string variableName, 293 ModifiableDataset modifiableDataset, 318 294 IEnumerable<int> rows, 319 295 IList replacementValues) { 320 dataset.ReplaceVariable(variable, replacementValues);296 modifiableDataset.ReplaceVariable(variableName, replacementValues); 321 297 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 322 var estimates = model.GetEstimatedValues( dataset, rows).ToList();323 dataset.ReplaceVariable(variable, originalValues);298 var estimates = model.GetEstimatedValues(modifiableDataset, rows).ToList(); 299 modifiableDataset.ReplaceVariable(variableName, originalValues); 324 300 325 301 return estimates; 326 302 } 327 303 328 public static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 329 IEnumerator<double> firstEnumerator = originalValues.GetEnumerator(); 304 /// <summary> 305 /// Calculates and returns the VariableImpact (calculated via Pearsons R²). 306 /// </summary> 307 /// <param name="targetValues">The actual values</param> 308 /// <param name="estimatedValues">The calculated/replaced values</param> 309 /// <param name="errorState"></param> 310 /// <returns></returns> 311 public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 312 //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality 313 //as the code below does. But this way we can easily swap the calculator later on, so the user 314 //could choose a Calculator during runtime in future versions. 315 IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator(); 316 IEnumerator<double> firstEnumerator = targetValues.GetEnumerator(); 330 317 IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator(); 331 var calculator = new OnlinePearsonsRSquaredCalculator();332 318 333 319 // always move forward both enumerators (do not use short-circuit evaluation!) … … 349 335 } 350 336 337 /// <summary> 338 /// Returns a collection of the row-indices for a given DataPartition (training or test) 339 /// </summary> 340 /// <param name="dataPartition"></param> 341 /// <param name="problemData"></param> 342 /// <returns></returns> 351 343 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) { 352 344 IEnumerable<int> rows;
Note: See TracChangeset
for help on using the changeset viewer.