Changeset 16036 for branches/2904_CalculateImpacts/3.4/Implementation
- Timestamp:
- 08/01/18 14:01:08 (6 years ago)
- Location:
- branches/2904_CalculateImpacts/3.4/Implementation
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs
r15674 r16036 23 23 24 24 using System; 25 using System.Collections; 25 26 using System.Collections.Generic; 26 27 using System.Linq; … … 36 37 [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")] 37 38 public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem { 39 #region Parameters/Properties 38 40 public enum ReplacementMethodEnum { 39 41 Median, … … 54 56 55 57 private const string ReplacementParameterName = "Replacement Method"; 58 private const string FactorReplacementParameterName = "Factor Replacement Method"; 56 59 private const string DataPartitionParameterName = "DataPartition"; 57 60 58 61 public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter { 59 62 get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; } 63 } 64 public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter { 65 get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; } 60 66 } 61 67 public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter { … … 67 73 set { ReplacementParameter.Value.Value = value; } 68 74 } 75 public FactorReplacementMethodEnum FactorReplacementMethod { 76 get { return FactorReplacementParameter.Value.Value; } 77 set { FactorReplacementParameter.Value.Value = value; } 78 } 69 79 public DataPartitionEnum DataPartition { 70 80 get { return DataPartitionParameter.Value.Value; } 71 81 set { DataPartitionParameter.Value.Value = value; } 72 82 } 73 74 83 #endregion 84 85 #region Ctor/Cloner 75 86 [StorableConstructor] 76 87 private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { } 77 88 private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner) 78 89 : base(original, cloner) { } 90 public ClassificationSolutionVariableImpactsCalculator() 91 : base() { 92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle))); 93 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 94 } 95 79 96 public override IDeepCloneable Clone(Cloner cloner) { 80 97 return new ClassificationSolutionVariableImpactsCalculator(this, cloner); 81 98 } 82 83 public ClassificationSolutionVariableImpactsCalculator() 84 : base() { 85 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median))); 86 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 87 } 99 #endregion 88 100 89 101 //mkommend: annoying name clash with static method, open to better naming suggestions 90 102 public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) { 91 return CalculateImpacts(solution, DataPartition, ReplacementMethod);103 return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition); 92 104 } 93 105 94 106 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 95 107 IClassificationSolution solution, 96 DataPartitionEnum data = DataPartitionEnum.Training, 97 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 108 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 109 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 110 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 111 return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition); 112 } 113 114 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 115 IClassificationModel model, 116 IClassificationProblemData problemData, 117 IEnumerable<double> estimatedValues, 118 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 119 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 120 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 121 IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData); 122 return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod); 123 } 124 125 126 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 127 IClassificationModel model, 128 IClassificationProblemData problemData, 129 IEnumerable<double> estimatedClassValues, 130 IEnumerable<int> rows, 131 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 132 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 133 //Calculate original quality-values (via calculator, default is Accuracy) 134 OnlineCalculatorError error; 135 IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 136 IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedClassValues.ElementAt(v)); 137 var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error); 138 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation."); 139 140 var impacts = new Dictionary<string, double>(); 141 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction)); 142 var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList(); 143 var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable(); 144 145 foreach (var inputVariable in allowedInputVariables) { 146 impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod); 147 } 148 149 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 150 } 151 152 153 public static double CalculateImpact(string variableName, 154 IClassificationModel model, 155 ModifiableDataset modifiableDataset, 156 IEnumerable<int> rows, 157 IEnumerable<double> targetValues, 158 double originalValue, 159 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 98 160 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 99 100 var problemData = solution.ProblemData; 101 var dataset = problemData.Dataset; 102 161 double impact = 0; 162 OnlineCalculatorError error; 163 IRandom random; 164 double replacementValue; 165 IEnumerable<double> newEstimates = null; 166 double newValue = 0; 167 168 if (modifiableDataset.VariableHasType<double>(variableName)) { 169 #region NumericalVariable 170 var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 171 List<double> replacementValues; 172 IRandom rand; 173 174 switch (replacementMethod) { 175 case ReplacementMethodEnum.Median: 176 replacementValue = rows.Select(r => originalValues[r]).Median(); 177 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 178 break; 179 case ReplacementMethodEnum.Average: 180 replacementValue = rows.Select(r => originalValues[r]).Average(); 181 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 182 break; 183 case ReplacementMethodEnum.Shuffle: 184 // new var has same empirical distribution but the relation to y is broken 185 rand = new FastRandom(31415); 186 // prepare a complete column for the dataset 187 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 188 // shuffle only the selected rows 189 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 190 int i = 0; 191 // update column values 192 foreach (var r in rows) { 193 replacementValues[r] = shuffledValues[i++]; 194 } 195 break; 196 case ReplacementMethodEnum.Noise: 197 var avg = rows.Select(r => originalValues[r]).Average(); 198 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 199 rand = new FastRandom(31415); 200 // prepare a complete column for the dataset 201 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 202 // update column values 203 foreach (var r in rows) { 204 replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev); 205 } 206 break; 207 208 default: 209 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 210 } 211 212 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 213 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 214 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 215 216 impact = originalValue - newValue; 217 #endregion 218 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 219 #region FactorVariable 220 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 221 List<string> replacementValues; 222 223 switch (factorReplacementMethod) { 224 case FactorReplacementMethodEnum.Best: 225 // try replacing with all possible values and find the best replacement value 226 var smallestImpact = double.PositiveInfinity; 227 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 228 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 229 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 230 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 231 232 var curImpact = originalValue - newValue; 233 if (curImpact < smallestImpact) smallestImpact = curImpact; 234 } 235 impact = smallestImpact; 236 break; 237 case FactorReplacementMethodEnum.Mode: 238 var mostCommonValue = rows.Select(r => originalValues[r]) 239 .GroupBy(v => v) 240 .OrderByDescending(g => g.Count()) 241 .First().Key; 242 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 243 244 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 245 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 246 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 247 248 impact = originalValue - newValue; 249 break; 250 case FactorReplacementMethodEnum.Shuffle: 251 // new var has same empirical distribution but the relation to y is broken 252 random = new FastRandom(31415); 253 // prepare a complete column for the dataset 254 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 255 // shuffle only the selected rows 256 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 257 int i = 0; 258 // update column values 259 foreach (var r in rows) { 260 replacementValues[r] = shuffledValues[i++]; 261 } 262 263 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 264 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 265 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 266 267 impact = originalValue - newValue; 268 break; 269 default: 270 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 271 } 272 #endregion 273 } else { 274 throw new NotSupportedException("Variable not supported"); 275 } 276 277 return impact; 278 } 279 280 /// <summary> 281 /// Calculates and returns the VariableImpact (calculated via Accuracy). 282 /// </summary> 283 /// <param name="targetValues">The actual values</param> 284 /// <param name="estimatedValues">The calculated/replaced values</param> 285 /// <param name="errorState"></param> 286 /// <returns></returns> 287 public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 288 //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality 289 //as the code below does. But this way we can easily swap the calculator later on, so the user 290 //could choose a Calculator during runtime in future versions. 291 IOnlineCalculator calculator = new OnlineAccuracyCalculator(); 292 IEnumerator<double> firstEnumerator = targetValues.GetEnumerator(); 293 IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator(); 294 295 // always move forward both enumerators (do not use short-circuit evaluation!) 296 while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) { 297 double original = firstEnumerator.Current; 298 double estimated = secondEnumerator.Current; 299 calculator.Add(original, estimated); 300 if (calculator.ErrorState != OnlineCalculatorError.None) break; 301 } 302 303 // check if both enumerators are at the end to make sure both enumerations have the same length 304 if (calculator.ErrorState == OnlineCalculatorError.None && 305 (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) { 306 throw new ArgumentException("Number of elements in first and second enumeration doesn't match."); 307 } else { 308 errorState = calculator.ErrorState; 309 return calculator.Value; 310 } 311 } 312 313 /// <summary> 314 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values 315 /// and changes the value of the model-variables back to the original ones. 316 /// </summary> 317 /// <param name="originalValues"></param> 318 /// <param name="model"></param> 319 /// <param name="variableName"></param> 320 /// <param name="modifiableDataset"></param> 321 /// <param name="rows"></param> 322 /// <param name="replacementValues"></param> 323 /// <returns></returns> 324 private static IEnumerable<double> GetReplacedEstimates( 325 IList originalValues, 326 IClassificationModel model, 327 string variableName, 328 ModifiableDataset modifiableDataset, 329 IEnumerable<int> rows, 330 IList replacementValues) { 331 modifiableDataset.ReplaceVariable(variableName, replacementValues); 332 333 var discModel = model as IDiscriminantFunctionClassificationModel; 334 if (discModel != null) { 335 var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable); 336 discModel.RecalculateModelParameters(problemData, rows); 337 } 338 339 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 340 var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList(); 341 modifiableDataset.ReplaceVariable(variableName, originalValues); 342 343 return estimates; 344 } 345 346 347 /// <summary> 348 /// Returns a collection of the row-indices for a given DataPartition (training or test) 349 /// </summary> 350 /// <param name="dataPartition"></param> 351 /// <param name="problemData"></param> 352 /// <returns></returns> 353 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) { 103 354 IEnumerable<int> rows; 104 IEnumerable<double> targetValues; 105 double originalAccuracy; 106 107 OnlineCalculatorError error; 108 109 switch (data) { 355 356 switch (dataPartition) { 110 357 case DataPartitionEnum.All: 111 358 rows = problemData.AllIndices; 112 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.AllIndices).ToList();113 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedClassValues, out error);114 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");359 break; 360 case DataPartitionEnum.Test: 361 rows = problemData.TestIndices; 115 362 break; 116 363 case DataPartitionEnum.Training: 117 364 rows = problemData.TrainingIndices; 118 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList();119 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTrainingClassValues, out error);120 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");121 break;122 case DataPartitionEnum.Test:123 rows = problemData.TestIndices;124 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToList();125 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTestClassValues, out error);126 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");127 break;128 default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));129 }130 131 var impacts = new Dictionary<string, double>();132 var modifiableDataset = ((Dataset)dataset).ToModifiable();133 134 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));135 var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();136 137 // calculate impacts for double variables138 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) {139 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);140 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);141 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");142 143 impacts[inputVariable] = originalAccuracy - newAccuracy;144 }145 146 // calculate impacts for string variables147 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) {148 if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {149 // try replacing with all possible values and find the best replacement value150 var smallestImpact = double.PositiveInfinity;151 foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) {152 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,153 Enumerable.Repeat(repl, dataset.Rows));154 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);155 if (error != OnlineCalculatorError.None)156 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");157 158 var impact = originalAccuracy - newAccuracy;159 if (impact < smallestImpact) smallestImpact = impact;160 }161 impacts[inputVariable] = smallestImpact;162 } else {163 // for replacement methods shuffle and mode164 // calculate impacts for factor variables165 166 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,167 factorReplacementMethod);168 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);169 if (error != OnlineCalculatorError.None)170 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");171 172 impacts[inputVariable] = originalAccuracy - newAccuracy;173 }174 } // foreach175 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));176 }177 178 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {179 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();180 double replacementValue;181 List<double> replacementValues;182 IRandom rand;183 184 switch (replacement) {185 case ReplacementMethodEnum.Median:186 replacementValue = rows.Select(r => originalValues[r]).Median();187 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();188 break;189 case ReplacementMethodEnum.Average:190 replacementValue = rows.Select(r => originalValues[r]).Average();191 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();192 break;193 case ReplacementMethodEnum.Shuffle:194 // new var has same empirical distribution but the relation to y is broken195 rand = new FastRandom(31415);196 // prepare a complete column for the dataset197 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();198 // shuffle only the selected rows199 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();200 int i = 0;201 // update column values202 foreach (var r in rows) {203 replacementValues[r] = shuffledValues[i++];204 }205 break;206 case ReplacementMethodEnum.Noise:207 var avg = rows.Select(r => originalValues[r]).Average();208 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();209 rand = new FastRandom(31415);210 // prepare a complete column for the dataset211 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();212 // update column values213 foreach (var r in rows) {214 replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);215 }216 break;217 218 default:219 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));220 }221 222 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);223 }224 225 private static IEnumerable<double> EvaluateModelWithReplacedVariable(226 IClassificationModel model, string variable, ModifiableDataset dataset,227 IEnumerable<int> rows,228 FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {229 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();230 List<string> replacementValues;231 IRandom rand;232 233 switch (replacement) {234 case FactorReplacementMethodEnum.Mode:235 var mostCommonValue = rows.Select(r => originalValues[r])236 .GroupBy(v => v)237 .OrderByDescending(g => g.Count())238 .First().Key;239 replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();240 break;241 case FactorReplacementMethodEnum.Shuffle:242 // new var has same empirical distribution but the relation to y is broken243 rand = new FastRandom(31415);244 // prepare a complete column for the dataset245 replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();246 // shuffle only the selected rows247 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();248 int i = 0;249 // update column values250 foreach (var r in rows) {251 replacementValues[r] = shuffledValues[i++];252 }253 365 break; 254 366 default: 255 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement)); 256 } 257 258 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues); 259 } 260 261 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 262 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) { 263 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 264 dataset.ReplaceVariable(variable, replacementValues.ToList()); 265 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 266 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 267 dataset.ReplaceVariable(variable, originalValues); 268 269 return estimates; 270 } 271 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 272 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) { 273 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 274 dataset.ReplaceVariable(variable, replacementValues.ToList()); 275 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 276 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 277 dataset.ReplaceVariable(variable, originalValues); 278 279 return estimates; 280 } 367 throw new NotSupportedException("DataPartition not supported"); 368 } 369 370 return rows; 371 } 372 281 373 } 282 374 } -
branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs
r16035 r16036 100 100 #endregion 101 101 102 #region Public Methods/Wrappers103 102 //mkommend: annoying name clash with static method, open to better naming suggestions 104 103 public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) { … … 159 158 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 160 159 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 161 162 160 double impact = 0; 163 164 // calculate impacts for double variables 161 OnlineCalculatorError error; 162 IRandom random; 163 double replacementValue; 164 IEnumerable<double> newEstimates = null; 165 double newValue = 0; 166 165 167 if (modifiableDataset.VariableHasType<double>(variableName)) { 166 impact = CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod); 168 #region NumericalVariable 169 var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 170 List<double> replacementValues; 171 172 switch (replacementMethod) { 173 case ReplacementMethodEnum.Median: 174 replacementValue = rows.Select(r => originalValues[r]).Median(); 175 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 176 break; 177 case ReplacementMethodEnum.Average: 178 replacementValue = rows.Select(r => originalValues[r]).Average(); 179 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 180 break; 181 case ReplacementMethodEnum.Shuffle: 182 // new var has same empirical distribution but the relation to y is broken 183 random = new FastRandom(31415); 184 // prepare a complete column for the dataset 185 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 186 // shuffle only the selected rows 187 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 188 int i = 0; 189 // update column values 190 foreach (var r in rows) { 191 replacementValues[r] = shuffledValues[i++]; 192 } 193 break; 194 case ReplacementMethodEnum.Noise: 195 var avg = rows.Select(r => originalValues[r]).Average(); 196 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 197 random = new FastRandom(31415); 198 // prepare a complete column for the dataset 199 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 200 // update column values 201 foreach (var r in rows) { 202 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 203 } 204 break; 205 206 default: 207 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 208 } 209 210 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 211 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 212 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 213 214 impact = originalValue - newValue; 215 #endregion 167 216 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 168 impact = CalculateImpactForFactorVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod); 217 #region FactorVariable 218 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 219 List<string> replacementValues; 220 221 switch (factorReplacementMethod) { 222 case FactorReplacementMethodEnum.Best: 223 // try replacing with all possible values and find the best replacement value 224 var smallestImpact = double.PositiveInfinity; 225 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 226 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 227 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 228 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 229 230 var curImpact = originalValue - newValue; 231 if (curImpact < smallestImpact) smallestImpact = curImpact; 232 } 233 impact = smallestImpact; 234 break; 235 case FactorReplacementMethodEnum.Mode: 236 var mostCommonValue = rows.Select(r => originalValues[r]) 237 .GroupBy(v => v) 238 .OrderByDescending(g => g.Count()) 239 .First().Key; 240 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 241 242 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 243 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 244 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 245 246 impact = originalValue - newValue; 247 break; 248 case FactorReplacementMethodEnum.Shuffle: 249 // new var has same empirical distribution but the relation to y is broken 250 random = new FastRandom(31415); 251 // prepare a complete column for the dataset 252 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 253 // shuffle only the selected rows 254 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 255 int i = 0; 256 // update column values 257 foreach (var r in rows) { 258 replacementValues[r] = shuffledValues[i++]; 259 } 260 261 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 262 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 263 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 264 265 impact = originalValue - newValue; 266 break; 267 default: 268 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 269 } 270 #endregion 169 271 } else { 170 272 throw new NotSupportedException("Variable not supported"); 171 273 } 274 172 275 return impact; 173 276 } 174 #endregion 175 176 private static double CalculateImpactForNumericalVariables(string variableName, 177 IRegressionModel model, 178 ModifiableDataset modifiableDataset, 179 IEnumerable<int> rows, 180 IEnumerable<double> targetValues, 181 double originalValue, 182 ReplacementMethodEnum replacementMethod) { 183 OnlineCalculatorError error; 184 var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod); 185 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 186 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 187 return originalValue - newValue; 188 } 189 190 private static double CalculateImpactForFactorVariables(string variableName, 191 IRegressionModel model, 192 ModifiableDataset modifiableDataset, 193 IEnumerable<int> rows, 194 IEnumerable<double> targetValues, 195 double originalValue, 196 FactorReplacementMethodEnum factorReplacementMethod) { 197 198 OnlineCalculatorError error; 199 if (factorReplacementMethod == FactorReplacementMethodEnum.Best) { 200 // try replacing with all possible values and find the best replacement value 201 var smallestImpact = double.PositiveInfinity; 202 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 203 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 204 var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 205 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 206 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 207 208 var curImpact = originalValue - newValue; 209 if (curImpact < smallestImpact) smallestImpact = curImpact; 210 } 211 return smallestImpact; 212 } else { 213 // for replacement methods shuffle and mode 214 // calculate impacts for factor variables 215 var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod); 216 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 217 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 218 219 return originalValue - newValue; 220 } 221 } 222 223 private static IEnumerable<double> GetReplacedValuesForNumericalVariables( 224 IRegressionModel model, 225 string variable, 226 ModifiableDataset dataset, 227 IEnumerable<int> rows, 228 ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) { 229 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 230 double replacementValue; 231 List<double> replacementValues; 232 IRandom rand; 233 234 switch (replacement) { 235 case ReplacementMethodEnum.Median: 236 replacementValue = rows.Select(r => originalValues[r]).Median(); 237 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList(); 238 break; 239 case ReplacementMethodEnum.Average: 240 replacementValue = rows.Select(r => originalValues[r]).Average(); 241 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList(); 242 break; 243 case ReplacementMethodEnum.Shuffle: 244 // new var has same empirical distribution but the relation to y is broken 245 rand = new FastRandom(31415); 246 // prepare a complete column for the dataset 247 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList(); 248 // shuffle only the selected rows 249 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 250 int i = 0; 251 // update column values 252 foreach (var r in rows) { 253 replacementValues[r] = shuffledValues[i++]; 254 } 255 break; 256 case ReplacementMethodEnum.Noise: 257 var avg = rows.Select(r => originalValues[r]).Average(); 258 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 259 rand = new FastRandom(31415); 260 // prepare a complete column for the dataset 261 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList(); 262 // update column values 263 foreach (var r in rows) { 264 replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev); 265 } 266 break; 267 268 default: 269 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement)); 270 } 271 272 return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues); 273 } 274 275 private static IEnumerable<double> GetReplacedValuesForFactorVariables( 276 IRegressionModel model, 277 string variable, 278 ModifiableDataset dataset, 279 IEnumerable<int> rows, 280 FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) { 281 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 282 List<string> replacementValues; 283 IRandom rand; 284 285 switch (replacement) { 286 case FactorReplacementMethodEnum.Mode: 287 var mostCommonValue = rows.Select(r => originalValues[r]) 288 .GroupBy(v => v) 289 .OrderByDescending(g => g.Count()) 290 .First().Key; 291 replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList(); 292 break; 293 case FactorReplacementMethodEnum.Shuffle: 294 // new var has same empirical distribution but the relation to y is broken 295 rand = new FastRandom(31415); 296 // prepare a complete column for the dataset 297 replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList(); 298 // shuffle only the selected rows 299 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 300 int i = 0; 301 // update column values 302 foreach (var r in rows) { 303 replacementValues[r] = shuffledValues[i++]; 304 } 305 break; 306 default: 307 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement)); 308 } 309 310 return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues); 311 } 312 277 278 /// <summary> 279 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values 280 /// and changes the value of the model-variables back to the original ones. 281 /// </summary> 282 /// <param name="originalValues"></param> 283 /// <param name="model"></param> 284 /// <param name="variableName"></param> 285 /// <param name="modifiableDataset"></param> 286 /// <param name="rows"></param> 287 /// <param name="replacementValues"></param> 288 /// <returns></returns> 313 289 private static IEnumerable<double> GetReplacedEstimates( 314 290 IList originalValues, 315 291 IRegressionModel model, 316 string variable ,317 ModifiableDataset dataset,292 string variableName, 293 ModifiableDataset modifiableDataset, 318 294 IEnumerable<int> rows, 319 295 IList replacementValues) { 320 dataset.ReplaceVariable(variable, replacementValues);296 modifiableDataset.ReplaceVariable(variableName, replacementValues); 321 297 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 322 var estimates = model.GetEstimatedValues( dataset, rows).ToList();323 dataset.ReplaceVariable(variable, originalValues);298 var estimates = model.GetEstimatedValues(modifiableDataset, rows).ToList(); 299 modifiableDataset.ReplaceVariable(variableName, originalValues); 324 300 325 301 return estimates; 326 302 } 327 303 328 public static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 329 IEnumerator<double> firstEnumerator = originalValues.GetEnumerator(); 304 /// <summary> 305 /// Calculates and returns the VariableImpact (calculated via Pearsons R²). 306 /// </summary> 307 /// <param name="targetValues">The actual values</param> 308 /// <param name="estimatedValues">The calculated/replaced values</param> 309 /// <param name="errorState"></param> 310 /// <returns></returns> 311 public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 312 //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality 313 //as the code below does. But this way we can easily swap the calculator later on, so the user 314 //could choose a Calculator during runtime in future versions. 315 IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator(); 316 IEnumerator<double> firstEnumerator = targetValues.GetEnumerator(); 330 317 IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator(); 331 var calculator = new OnlinePearsonsRSquaredCalculator();332 318 333 319 // always move forward both enumerators (do not use short-circuit evaluation!) … … 349 335 } 350 336 337 /// <summary> 338 /// Returns a collection of the row-indices for a given DataPartition (training or test) 339 /// </summary> 340 /// <param name="dataPartition"></param> 341 /// <param name="problemData"></param> 342 /// <returns></returns> 351 343 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) { 352 344 IEnumerable<int> rows;
Note: See TracChangeset
for help on using the changeset viewer.