Changeset 15831 for branches/2904_CalculateImpacts/3.4/Implementation
- Timestamp:
- 03/08/18 10:44:51 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs
r15816 r15831 23 23 24 24 using System; 25 using System.Collections; 25 26 using System.Collections.Generic; 26 27 using System.Linq; … … 96 97 } 97 98 98 p rivate static void PrepareData(DataPartitionEnum partition,99 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 99 100 IRegressionSolution solution, 100 out IEnumerable<int> rows, 101 DataPartitionEnum data = DataPartitionEnum.Training, 102 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 103 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 104 Func<double, string, bool> progressCallback = null) { 105 return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, data, replacementMethod, factorReplacementMethod, progressCallback); 106 } 107 108 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 109 IRegressionModel model, 110 IRegressionProblemData problemData, 111 IEnumerable<double> estimatedValues, 112 DataPartitionEnum data = DataPartitionEnum.Training, 113 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 114 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 115 Func<double, string, bool> progressCallback = null, 116 IOnlineCalculator calculator = null) { 117 //PearsonsRSquared is the default calculator 118 if (calculator == null) { calculator = new OnlinePearsonsRSquaredCalculator(); } 119 IEnumerable<int> rows; 120 121 switch (data) { 122 case DataPartitionEnum.All: 123 rows = problemData.AllIndices; 124 break; 125 case DataPartitionEnum.Test: 126 rows = problemData.TestIndices; 127 break; 128 case DataPartitionEnum.Training: 129 rows = problemData.TrainingIndices; 130 break; 131 default: 132 throw new NotSupportedException("DataPartition not supported"); 133 } 134 135 return CalculateImpacts(model, problemData, estimatedValues, rows, calculator, replacementMethod, factorReplacementMethod, progressCallback); 136 } 137 138 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 139 IRegressionModel model, 140 IRegressionProblemData problemData, 141 IEnumerable<double> estimatedValues, 142 IEnumerable<int> rows, 143 IOnlineCalculator calculator, 144 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 145 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 146 Func<double, string, bool> progressCallback = null) { 147 148 IEnumerable<double> targetValues; 149 double originalValue = -1; 150 151 PrepareData(rows, problemData, estimatedValues, out targetValues, out originalValue, calculator); 152 153 var impacts = new Dictionary<string, double>(); 154 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction)); 155 var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList(); 156 157 int curIdx = 0; 158 int count = allowedInputVariables 159 .Where(v => problemData.Dataset.VariableHasType<double>(v) || problemData.Dataset.VariableHasType<string>(v)) 160 .Count(); 161 162 foreach (var inputVariable in allowedInputVariables) { 163 //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped 164 if (progressCallback != null) { 165 curIdx++; 166 if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; } 167 } 168 impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData.Dataset, rows, targetValues, originalValue, calculator, replacementMethod, factorReplacementMethod); 169 } 170 171 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 172 } 173 174 public static double CalculateImpact(string variableName, 175 IRegressionSolution solution, 176 IEnumerable<int> rows, 177 IEnumerable<double> targetValues, 178 double originalValue, 179 IOnlineCalculator calculator, 180 DataPartitionEnum data = DataPartitionEnum.Training, 181 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 182 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 183 return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, calculator, replacementMethod, factorReplacementMethod); 184 } 185 186 public static double CalculateImpact(string variableName, 187 IRegressionModel model, 188 IDataset dataset, 189 IEnumerable<int> rows, 190 IEnumerable<double> targetValues, 191 double originalValue, 192 IOnlineCalculator calculator, 193 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 194 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 195 196 double impact = 0; 197 var modifiableDataset = ((Dataset)dataset).ToModifiable(); 198 199 // calculate impacts for double variables 200 if (dataset.VariableHasType<double>(variableName)) { 201 impact = CalculateImpactForDouble(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod, calculator); 202 } else if (dataset.VariableHasType<string>(variableName)) { 203 impact = CalculateImpactForString(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod, calculator); 204 } else { 205 throw new NotSupportedException("Variable not supported"); 206 } 207 return impact; 208 } 209 210 private static void PrepareData(IEnumerable<int> rows, 211 IRegressionProblemData problemData, 212 IEnumerable<double> estimatedValues, 101 213 out IEnumerable<double> targetValues, 102 out double originalR2) { 214 out double originalValue, 215 IOnlineCalculator calculator) { 103 216 OnlineCalculatorError error; 104 217 105 switch (partition) { 106 case DataPartitionEnum.All: 107 rows = solution.ProblemData.AllIndices; 108 targetValues = solution.ProblemData.TargetVariableValues.ToList(); 109 originalR2 = OnlinePearsonsRCalculator.Calculate(solution.ProblemData.TargetVariableValues, solution.EstimatedValues, out error); 110 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation."); 111 originalR2 = originalR2 * originalR2; 112 break; 113 case DataPartitionEnum.Training: 114 rows = solution.ProblemData.TrainingIndices; 115 targetValues = solution.ProblemData.TargetVariableTrainingValues.ToList(); 116 originalR2 = solution.TrainingRSquared; 117 break; 118 case DataPartitionEnum.Test: 119 rows = solution.ProblemData.TestIndices; 120 targetValues = solution.ProblemData.TargetVariableTestValues.ToList(); 121 originalR2 = solution.TestRSquared; 122 break; 123 default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", partition)); 124 } 218 var targetVariableValueList = problemData.TargetVariableValues.ToList(); 219 targetValues = rows.Select(v => targetVariableValueList.ElementAt(v)); 220 var estimatedValuesPartition = rows.Select(v => estimatedValues.ElementAt(v)); 221 originalValue = calculator.CalculateValue(targetValues, estimatedValuesPartition, out error); 222 223 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation."); 125 224 } 126 225 127 226 private static double CalculateImpactForDouble(string variableName, 128 IRegression Solution solution,227 IRegressionModel model, 129 228 ModifiableDataset modifiableDataset, 130 229 IEnumerable<int> rows, 131 230 IEnumerable<double> targetValues, 132 double originalR2, 133 ReplacementMethodEnum replacementMethod) { 231 double originalValue, 232 ReplacementMethodEnum replacementMethod, 233 IOnlineCalculator calculator) { 134 234 OnlineCalculatorError error; 135 var newEstimates = EvaluateModelWithReplacedVariable( solution.Model, variableName, modifiableDataset, rows, replacementMethod);136 var new R2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);137 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during R²calculation with replaced inputs."); }138 return original R2 - (newR2 * newR2);235 var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, replacementMethod); 236 var newValue = calculator.CalculateValue(targetValues, newEstimates, out error); 237 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 238 return originalValue - newValue; 139 239 } 140 240 141 241 private static double CalculateImpactForString(string variableName, 142 IRegressionSolution solution, 242 IRegressionModel model, 243 IDataset problemData, 143 244 ModifiableDataset modifiableDataset, 144 245 IEnumerable<int> rows, 145 246 IEnumerable<double> targetValues, 146 double originalR2, 147 FactorReplacementMethodEnum factorReplacementMethod) { 247 double originalValue, 248 FactorReplacementMethodEnum factorReplacementMethod, 249 IOnlineCalculator calculator) { 148 250 149 251 OnlineCalculatorError error; … … 151 253 // try replacing with all possible values and find the best replacement value 152 254 var smallestImpact = double.PositiveInfinity; 153 foreach (var repl in solution.ProblemData.Dataset.GetStringValues(variableName, rows).Distinct()) { 154 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, solution.ProblemData.Dataset.Rows)); 155 var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error); 156 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs."); 157 158 var curImpact = originalR2 - (newR2 * newR2); 255 foreach (var repl in problemData.GetStringValues(variableName, rows).Distinct()) { 256 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 257 var newEstimates = EvaluateModelWithReplacedVariable(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, problemData.Rows).ToList()); 258 var newValue = calculator.CalculateValue(targetValues, newEstimates, out error); 259 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 260 261 var curImpact = originalValue - newValue; 159 262 if (curImpact < smallestImpact) smallestImpact = curImpact; 160 263 } … … 163 266 // for replacement methods shuffle and mode 164 267 // calculate impacts for factor variables 165 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, factorReplacementMethod); 166 var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error); 167 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs."); 168 169 return originalR2 - (newR2 * newR2); 170 } 171 } 172 public static double CalculateImpact(string variableName, 173 IRegressionSolution solution, 174 IEnumerable<int> rows, 175 IEnumerable<double> targetValues, 176 double originalR2, 177 DataPartitionEnum data = DataPartitionEnum.Training, 178 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 179 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 180 181 double impact = 0; 182 var modifiableDataset = ((Dataset)solution.ProblemData.Dataset).ToModifiable(); 183 184 // calculate impacts for double variables 185 if (solution.ProblemData.Dataset.VariableHasType<double>(variableName)) { 186 impact = CalculateImpactForDouble(variableName, solution, modifiableDataset, rows, targetValues, originalR2, replacementMethod); 187 } else if (solution.ProblemData.Dataset.VariableHasType<string>(variableName)) { 188 impact = CalculateImpactForString(variableName, solution, modifiableDataset, rows, targetValues, originalR2, factorReplacementMethod); 189 } else { 190 throw new NotSupportedException("Variable not supported"); 191 } 192 return impact; 193 } 194 195 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 196 IRegressionSolution solution, 197 DataPartitionEnum data = DataPartitionEnum.Training, 198 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 199 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 200 Func<double, string, bool> progressCallback = null) { 201 202 IEnumerable<int> rows; 203 IEnumerable<double> targetValues; 204 double originalR2 = -1; 205 206 PrepareData(data, solution, out rows, out targetValues, out originalR2); 207 208 var impacts = new Dictionary<string, double>(); 209 var inputvariables = new HashSet<string>(solution.ProblemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction)); 210 var allowedInputVariables = solution.ProblemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList(); 211 212 int curIdx = 0; 213 int count = allowedInputVariables.Where(solution.ProblemData.Dataset.VariableHasType<double>).Count(); 214 // calculate impacts for double variables 215 foreach (var inputVariable in allowedInputVariables) { 216 //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped 217 if (progressCallback != null) { 218 curIdx++; 219 if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; } 220 } 221 impacts[inputVariable] = CalculateImpact(inputVariable, solution, rows, targetValues, originalR2, data, replacementMethod, factorReplacementMethod); 222 } 223 224 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 268 var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, factorReplacementMethod); 269 var newValue = calculator.CalculateValue(targetValues, newEstimates, out error); 270 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 271 272 return originalValue - newValue; 273 } 225 274 } 226 275 … … 269 318 } 270 319 271 return EvaluateModelWithReplacedVariable( model, variable, dataset, rows, replacementValues);320 return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues); 272 321 } 273 322 … … 305 354 } 306 355 307 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues); 308 } 309 310 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, 311 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) { 312 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 313 dataset.ReplaceVariable(variable, replacementValues.ToList()); 356 return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues); 357 } 358 359 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IList originalValues, IRegressionModel model, string variable, 360 ModifiableDataset dataset, IEnumerable<int> rows, IList replacementValues) { 361 dataset.ReplaceVariable(variable, replacementValues); 314 362 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 315 363 var estimates = model.GetEstimatedValues(dataset, rows).ToList(); … … 318 366 return estimates; 319 367 } 320 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,321 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {322 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();323 dataset.ReplaceVariable(variable, replacementValues.ToList());324 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements325 var estimates = model.GetEstimatedValues(dataset, rows).ToList();326 dataset.ReplaceVariable(variable, originalValues);327 328 return estimates;329 }330 368 } 331 369 }
Note: See TracChangeset
for help on using the changeset viewer.