- Timestamp:
- 08/01/18 14:01:08 (6 years ago)
- Location:
- branches/2904_CalculateImpacts
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs
r15674 r16036 23 23 24 24 using System; 25 using System.Collections; 25 26 using System.Collections.Generic; 26 27 using System.Linq; … … 36 37 [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")] 37 38 public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem { 39 #region Parameters/Properties 38 40 public enum ReplacementMethodEnum { 39 41 Median, … … 54 56 55 57 private const string ReplacementParameterName = "Replacement Method"; 58 private const string FactorReplacementParameterName = "Factor Replacement Method"; 56 59 private const string DataPartitionParameterName = "DataPartition"; 57 60 58 61 public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter { 59 62 get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; } 63 } 64 public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter { 65 get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; } 60 66 } 61 67 public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter { … … 67 73 set { ReplacementParameter.Value.Value = value; } 68 74 } 75 public FactorReplacementMethodEnum FactorReplacementMethod { 76 get { return FactorReplacementParameter.Value.Value; } 77 set { FactorReplacementParameter.Value.Value = value; } 78 } 69 79 public DataPartitionEnum DataPartition { 70 80 get { return DataPartitionParameter.Value.Value; } 71 81 set { DataPartitionParameter.Value.Value = value; } 72 82 } 73 74 83 #endregion 84 85 #region Ctor/Cloner 75 86 [StorableConstructor] 76 87 private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { } 77 88 private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner) 78 89 : base(original, cloner) { } 90 public ClassificationSolutionVariableImpactsCalculator() 91 : base() { 92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle))); 93 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 94 } 95 79 96 public override IDeepCloneable Clone(Cloner cloner) { 80 97 return new ClassificationSolutionVariableImpactsCalculator(this, cloner); 81 98 } 82 83 public ClassificationSolutionVariableImpactsCalculator() 84 : base() { 85 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median))); 86 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 87 } 99 #endregion 88 100 89 101 //mkommend: annoying name clash with static method, open to better naming suggestions 90 102 public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) { 91 return CalculateImpacts(solution, DataPartition, ReplacementMethod);103 return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition); 92 104 } 93 105 94 106 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 95 107 IClassificationSolution solution, 96 DataPartitionEnum data = DataPartitionEnum.Training, 97 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median, 108 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 109 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 110 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 111 return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition); 112 } 113 114 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 115 IClassificationModel model, 116 IClassificationProblemData problemData, 117 IEnumerable<double> estimatedValues, 118 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 119 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best, 120 DataPartitionEnum dataPartition = DataPartitionEnum.Training) { 121 IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData); 122 return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod); 123 } 124 125 126 public static IEnumerable<Tuple<string, double>> CalculateImpacts( 127 IClassificationModel model, 128 IClassificationProblemData problemData, 129 IEnumerable<double> estimatedClassValues, 130 IEnumerable<int> rows, 131 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 132 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 133 //Calculate original quality-values (via calculator, default is Accuracy) 134 OnlineCalculatorError error; 135 IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 136 IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedClassValues.ElementAt(v)); 137 var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error); 138 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation."); 139 140 var impacts = new Dictionary<string, double>(); 141 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction)); 142 var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList(); 143 var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable(); 144 145 foreach (var inputVariable in allowedInputVariables) { 146 impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod); 147 } 148 149 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 150 } 151 152 153 public static double CalculateImpact(string variableName, 154 IClassificationModel model, 155 ModifiableDataset modifiableDataset, 156 IEnumerable<int> rows, 157 IEnumerable<double> targetValues, 158 double originalValue, 159 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 98 160 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 99 100 var problemData = solution.ProblemData; 101 var dataset = problemData.Dataset; 102 161 double impact = 0; 162 OnlineCalculatorError error; 163 IRandom random; 164 double replacementValue; 165 IEnumerable<double> newEstimates = null; 166 double newValue = 0; 167 168 if (modifiableDataset.VariableHasType<double>(variableName)) { 169 #region NumericalVariable 170 var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 171 List<double> replacementValues; 172 IRandom rand; 173 174 switch (replacementMethod) { 175 case ReplacementMethodEnum.Median: 176 replacementValue = rows.Select(r => originalValues[r]).Median(); 177 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 178 break; 179 case ReplacementMethodEnum.Average: 180 replacementValue = rows.Select(r => originalValues[r]).Average(); 181 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 182 break; 183 case ReplacementMethodEnum.Shuffle: 184 // new var has same empirical distribution but the relation to y is broken 185 rand = new FastRandom(31415); 186 // prepare a complete column for the dataset 187 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 188 // shuffle only the selected rows 189 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 190 int i = 0; 191 // update column values 192 foreach (var r in rows) { 193 replacementValues[r] = shuffledValues[i++]; 194 } 195 break; 196 case ReplacementMethodEnum.Noise: 197 var avg = rows.Select(r => originalValues[r]).Average(); 198 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 199 rand = new FastRandom(31415); 200 // prepare a complete column for the dataset 201 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 202 // update column values 203 foreach (var r in rows) { 204 replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev); 205 } 206 break; 207 208 default: 209 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 210 } 211 212 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 213 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 214 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 215 216 impact = originalValue - newValue; 217 #endregion 218 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 219 #region FactorVariable 220 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 221 List<string> replacementValues; 222 223 switch (factorReplacementMethod) { 224 case FactorReplacementMethodEnum.Best: 225 // try replacing with all possible values and find the best replacement value 226 var smallestImpact = double.PositiveInfinity; 227 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 228 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 229 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 230 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 231 232 var curImpact = originalValue - newValue; 233 if (curImpact < smallestImpact) smallestImpact = curImpact; 234 } 235 impact = smallestImpact; 236 break; 237 case FactorReplacementMethodEnum.Mode: 238 var mostCommonValue = rows.Select(r => originalValues[r]) 239 .GroupBy(v => v) 240 .OrderByDescending(g => g.Count()) 241 .First().Key; 242 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 243 244 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 245 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 246 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 247 248 impact = originalValue - newValue; 249 break; 250 case FactorReplacementMethodEnum.Shuffle: 251 // new var has same empirical distribution but the relation to y is broken 252 random = new FastRandom(31415); 253 // prepare a complete column for the dataset 254 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 255 // shuffle only the selected rows 256 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 257 int i = 0; 258 // update column values 259 foreach (var r in rows) { 260 replacementValues[r] = shuffledValues[i++]; 261 } 262 263 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 264 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 265 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 266 267 impact = originalValue - newValue; 268 break; 269 default: 270 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 271 } 272 #endregion 273 } else { 274 throw new NotSupportedException("Variable not supported"); 275 } 276 277 return impact; 278 } 279 280 /// <summary> 281 /// Calculates and returns the VariableImpact (calculated via Accuracy). 282 /// </summary> 283 /// <param name="targetValues">The actual values</param> 284 /// <param name="estimatedValues">The calculated/replaced values</param> 285 /// <param name="errorState"></param> 286 /// <returns></returns> 287 public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 288 //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality 289 //as the code below does. But this way we can easily swap the calculator later on, so the user 290 //could choose a Calculator during runtime in future versions. 291 IOnlineCalculator calculator = new OnlineAccuracyCalculator(); 292 IEnumerator<double> firstEnumerator = targetValues.GetEnumerator(); 293 IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator(); 294 295 // always move forward both enumerators (do not use short-circuit evaluation!) 296 while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) { 297 double original = firstEnumerator.Current; 298 double estimated = secondEnumerator.Current; 299 calculator.Add(original, estimated); 300 if (calculator.ErrorState != OnlineCalculatorError.None) break; 301 } 302 303 // check if both enumerators are at the end to make sure both enumerations have the same length 304 if (calculator.ErrorState == OnlineCalculatorError.None && 305 (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) { 306 throw new ArgumentException("Number of elements in first and second enumeration doesn't match."); 307 } else { 308 errorState = calculator.ErrorState; 309 return calculator.Value; 310 } 311 } 312 313 /// <summary> 314 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values 315 /// and changes the value of the model-variables back to the original ones. 316 /// </summary> 317 /// <param name="originalValues"></param> 318 /// <param name="model"></param> 319 /// <param name="variableName"></param> 320 /// <param name="modifiableDataset"></param> 321 /// <param name="rows"></param> 322 /// <param name="replacementValues"></param> 323 /// <returns></returns> 324 private static IEnumerable<double> GetReplacedEstimates( 325 IList originalValues, 326 IClassificationModel model, 327 string variableName, 328 ModifiableDataset modifiableDataset, 329 IEnumerable<int> rows, 330 IList replacementValues) { 331 modifiableDataset.ReplaceVariable(variableName, replacementValues); 332 333 var discModel = model as IDiscriminantFunctionClassificationModel; 334 if (discModel != null) { 335 var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable); 336 discModel.RecalculateModelParameters(problemData, rows); 337 } 338 339 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 340 var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList(); 341 modifiableDataset.ReplaceVariable(variableName, originalValues); 342 343 return estimates; 344 } 345 346 347 /// <summary> 348 /// Returns a collection of the row-indices for a given DataPartition (training or test) 349 /// </summary> 350 /// <param name="dataPartition"></param> 351 /// <param name="problemData"></param> 352 /// <returns></returns> 353 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) { 103 354 IEnumerable<int> rows; 104 IEnumerable<double> targetValues; 105 double originalAccuracy; 106 107 OnlineCalculatorError error; 108 109 switch (data) { 355 356 switch (dataPartition) { 110 357 case DataPartitionEnum.All: 111 358 rows = problemData.AllIndices; 112 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.AllIndices).ToList();113 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedClassValues, out error);114 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");359 break; 360 case DataPartitionEnum.Test: 361 rows = problemData.TestIndices; 115 362 break; 116 363 case DataPartitionEnum.Training: 117 364 rows = problemData.TrainingIndices; 118 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList();119 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTrainingClassValues, out error);120 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");121 break;122 case DataPartitionEnum.Test:123 rows = problemData.TestIndices;124 targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToList();125 originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTestClassValues, out error);126 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");127 break;128 default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));129 }130 131 var impacts = new Dictionary<string, double>();132 var modifiableDataset = ((Dataset)dataset).ToModifiable();133 134 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));135 var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();136 137 // calculate impacts for double variables138 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) {139 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);140 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);141 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");142 143 impacts[inputVariable] = originalAccuracy - newAccuracy;144 }145 146 // calculate impacts for string variables147 foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) {148 if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {149 // try replacing with all possible values and find the best replacement value150 var smallestImpact = double.PositiveInfinity;151 foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) {152 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,153 Enumerable.Repeat(repl, dataset.Rows));154 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);155 if (error != OnlineCalculatorError.None)156 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");157 158 var impact = originalAccuracy - newAccuracy;159 if (impact < smallestImpact) smallestImpact = impact;160 }161 impacts[inputVariable] = smallestImpact;162 } else {163 // for replacement methods shuffle and mode164 // calculate impacts for factor variables165 166 var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,167 factorReplacementMethod);168 var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);169 if (error != OnlineCalculatorError.None)170 throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");171 172 impacts[inputVariable] = originalAccuracy - newAccuracy;173 }174 } // foreach175 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));176 }177 178 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {179 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();180 double replacementValue;181 List<double> replacementValues;182 IRandom rand;183 184 switch (replacement) {185 case ReplacementMethodEnum.Median:186 replacementValue = rows.Select(r => originalValues[r]).Median();187 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();188 break;189 case ReplacementMethodEnum.Average:190 replacementValue = rows.Select(r => originalValues[r]).Average();191 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();192 break;193 case ReplacementMethodEnum.Shuffle:194 // new var has same empirical distribution but the relation to y is broken195 rand = new FastRandom(31415);196 // prepare a complete column for the dataset197 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();198 // shuffle only the selected rows199 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();200 int i = 0;201 // update column values202 foreach (var r in rows) {203 replacementValues[r] = shuffledValues[i++];204 }205 break;206 case ReplacementMethodEnum.Noise:207 var avg = rows.Select(r => originalValues[r]).Average();208 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();209 rand = new FastRandom(31415);210 // prepare a complete column for the dataset211 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();212 // update column values213 foreach (var r in rows) {214 replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);215 }216 break;217 218 default:219 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));220 }221 222 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);223 }224 225 private static IEnumerable<double> EvaluateModelWithReplacedVariable(226 IClassificationModel model, string variable, ModifiableDataset dataset,227 IEnumerable<int> rows,228 FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {229 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();230 List<string> replacementValues;231 IRandom rand;232 233 switch (replacement) {234 case FactorReplacementMethodEnum.Mode:235 var mostCommonValue = rows.Select(r => originalValues[r])236 .GroupBy(v => v)237 .OrderByDescending(g => g.Count())238 .First().Key;239 replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();240 break;241 case FactorReplacementMethodEnum.Shuffle:242 // new var has same empirical distribution but the relation to y is broken243 rand = new FastRandom(31415);244 // prepare a complete column for the dataset245 replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();246 // shuffle only the selected rows247 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();248 int i = 0;249 // update column values250 foreach (var r in rows) {251 replacementValues[r] = shuffledValues[i++];252 }253 365 break; 254 366 default: 255 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement)); 256 } 257 258 return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues); 259 } 260 261 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 262 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) { 263 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 264 dataset.ReplaceVariable(variable, replacementValues.ToList()); 265 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 266 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 267 dataset.ReplaceVariable(variable, originalValues); 268 269 return estimates; 270 } 271 private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, 272 ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) { 273 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 274 dataset.ReplaceVariable(variable, replacementValues.ToList()); 275 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 276 var estimates = model.GetEstimatedClassValues(dataset, rows).ToList(); 277 dataset.ReplaceVariable(variable, originalValues); 278 279 return estimates; 280 } 367 throw new NotSupportedException("DataPartition not supported"); 368 } 369 370 return rows; 371 } 372 281 373 } 282 374 } -
branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs
r16035 r16036 100 100 #endregion 101 101 102 #region Public Methods/Wrappers103 102 //mkommend: annoying name clash with static method, open to better naming suggestions 104 103 public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) { … … 159 158 ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, 160 159 FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) { 161 162 160 double impact = 0; 163 164 // calculate impacts for double variables 161 OnlineCalculatorError error; 162 IRandom random; 163 double replacementValue; 164 IEnumerable<double> newEstimates = null; 165 double newValue = 0; 166 165 167 if (modifiableDataset.VariableHasType<double>(variableName)) { 166 impact = CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod); 168 #region NumericalVariable 169 var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 170 List<double> replacementValues; 171 172 switch (replacementMethod) { 173 case ReplacementMethodEnum.Median: 174 replacementValue = rows.Select(r => originalValues[r]).Median(); 175 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 176 break; 177 case ReplacementMethodEnum.Average: 178 replacementValue = rows.Select(r => originalValues[r]).Average(); 179 replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList(); 180 break; 181 case ReplacementMethodEnum.Shuffle: 182 // new var has same empirical distribution but the relation to y is broken 183 random = new FastRandom(31415); 184 // prepare a complete column for the dataset 185 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 186 // shuffle only the selected rows 187 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 188 int i = 0; 189 // update column values 190 foreach (var r in rows) { 191 replacementValues[r] = shuffledValues[i++]; 192 } 193 break; 194 case ReplacementMethodEnum.Noise: 195 var avg = rows.Select(r => originalValues[r]).Average(); 196 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 197 random = new FastRandom(31415); 198 // prepare a complete column for the dataset 199 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 200 // update column values 201 foreach (var r in rows) { 202 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 203 } 204 break; 205 206 default: 207 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod)); 208 } 209 210 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 211 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 212 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 213 214 impact = originalValue - newValue; 215 #endregion 167 216 } else if (modifiableDataset.VariableHasType<string>(variableName)) { 168 impact = CalculateImpactForFactorVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod); 217 #region FactorVariable 218 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 219 List<string> replacementValues; 220 221 switch (factorReplacementMethod) { 222 case FactorReplacementMethodEnum.Best: 223 // try replacing with all possible values and find the best replacement value 224 var smallestImpact = double.PositiveInfinity; 225 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 226 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 227 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 228 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 229 230 var curImpact = originalValue - newValue; 231 if (curImpact < smallestImpact) smallestImpact = curImpact; 232 } 233 impact = smallestImpact; 234 break; 235 case FactorReplacementMethodEnum.Mode: 236 var mostCommonValue = rows.Select(r => originalValues[r]) 237 .GroupBy(v => v) 238 .OrderByDescending(g => g.Count()) 239 .First().Key; 240 replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList(); 241 242 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 243 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 244 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 245 246 impact = originalValue - newValue; 247 break; 248 case FactorReplacementMethodEnum.Shuffle: 249 // new var has same empirical distribution but the relation to y is broken 250 random = new FastRandom(31415); 251 // prepare a complete column for the dataset 252 replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList(); 253 // shuffle only the selected rows 254 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 255 int i = 0; 256 // update column values 257 foreach (var r in rows) { 258 replacementValues[r] = shuffledValues[i++]; 259 } 260 261 newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues); 262 newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 263 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 264 265 impact = originalValue - newValue; 266 break; 267 default: 268 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod)); 269 } 270 #endregion 169 271 } else { 170 272 throw new NotSupportedException("Variable not supported"); 171 273 } 274 172 275 return impact; 173 276 } 174 #endregion 175 176 private static double CalculateImpactForNumericalVariables(string variableName, 177 IRegressionModel model, 178 ModifiableDataset modifiableDataset, 179 IEnumerable<int> rows, 180 IEnumerable<double> targetValues, 181 double originalValue, 182 ReplacementMethodEnum replacementMethod) { 183 OnlineCalculatorError error; 184 var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod); 185 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 186 if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); } 187 return originalValue - newValue; 188 } 189 190 private static double CalculateImpactForFactorVariables(string variableName, 191 IRegressionModel model, 192 ModifiableDataset modifiableDataset, 193 IEnumerable<int> rows, 194 IEnumerable<double> targetValues, 195 double originalValue, 196 FactorReplacementMethodEnum factorReplacementMethod) { 197 198 OnlineCalculatorError error; 199 if (factorReplacementMethod == FactorReplacementMethodEnum.Best) { 200 // try replacing with all possible values and find the best replacement value 201 var smallestImpact = double.PositiveInfinity; 202 foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) { 203 var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList(); 204 var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList()); 205 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 206 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 207 208 var curImpact = originalValue - newValue; 209 if (curImpact < smallestImpact) smallestImpact = curImpact; 210 } 211 return smallestImpact; 212 } else { 213 // for replacement methods shuffle and mode 214 // calculate impacts for factor variables 215 var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod); 216 var newValue = CalculateVariableImpact(targetValues, newEstimates, out error); 217 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs."); 218 219 return originalValue - newValue; 220 } 221 } 222 223 private static IEnumerable<double> GetReplacedValuesForNumericalVariables( 224 IRegressionModel model, 225 string variable, 226 ModifiableDataset dataset, 227 IEnumerable<int> rows, 228 ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) { 229 var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList(); 230 double replacementValue; 231 List<double> replacementValues; 232 IRandom rand; 233 234 switch (replacement) { 235 case ReplacementMethodEnum.Median: 236 replacementValue = rows.Select(r => originalValues[r]).Median(); 237 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList(); 238 break; 239 case ReplacementMethodEnum.Average: 240 replacementValue = rows.Select(r => originalValues[r]).Average(); 241 replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList(); 242 break; 243 case ReplacementMethodEnum.Shuffle: 244 // new var has same empirical distribution but the relation to y is broken 245 rand = new FastRandom(31415); 246 // prepare a complete column for the dataset 247 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList(); 248 // shuffle only the selected rows 249 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 250 int i = 0; 251 // update column values 252 foreach (var r in rows) { 253 replacementValues[r] = shuffledValues[i++]; 254 } 255 break; 256 case ReplacementMethodEnum.Noise: 257 var avg = rows.Select(r => originalValues[r]).Average(); 258 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 259 rand = new FastRandom(31415); 260 // prepare a complete column for the dataset 261 replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList(); 262 // update column values 263 foreach (var r in rows) { 264 replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev); 265 } 266 break; 267 268 default: 269 throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement)); 270 } 271 272 return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues); 273 } 274 275 private static IEnumerable<double> GetReplacedValuesForFactorVariables( 276 IRegressionModel model, 277 string variable, 278 ModifiableDataset dataset, 279 IEnumerable<int> rows, 280 FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) { 281 var originalValues = dataset.GetReadOnlyStringValues(variable).ToList(); 282 List<string> replacementValues; 283 IRandom rand; 284 285 switch (replacement) { 286 case FactorReplacementMethodEnum.Mode: 287 var mostCommonValue = rows.Select(r => originalValues[r]) 288 .GroupBy(v => v) 289 .OrderByDescending(g => g.Count()) 290 .First().Key; 291 replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList(); 292 break; 293 case FactorReplacementMethodEnum.Shuffle: 294 // new var has same empirical distribution but the relation to y is broken 295 rand = new FastRandom(31415); 296 // prepare a complete column for the dataset 297 replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList(); 298 // shuffle only the selected rows 299 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList(); 300 int i = 0; 301 // update column values 302 foreach (var r in rows) { 303 replacementValues[r] = shuffledValues[i++]; 304 } 305 break; 306 default: 307 throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement)); 308 } 309 310 return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues); 311 } 312 277 278 /// <summary> 279 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values 280 /// and changes the value of the model-variables back to the original ones. 281 /// </summary> 282 /// <param name="originalValues"></param> 283 /// <param name="model"></param> 284 /// <param name="variableName"></param> 285 /// <param name="modifiableDataset"></param> 286 /// <param name="rows"></param> 287 /// <param name="replacementValues"></param> 288 /// <returns></returns> 313 289 private static IEnumerable<double> GetReplacedEstimates( 314 290 IList originalValues, 315 291 IRegressionModel model, 316 string variable ,317 ModifiableDataset dataset,292 string variableName, 293 ModifiableDataset modifiableDataset, 318 294 IEnumerable<int> rows, 319 295 IList replacementValues) { 320 dataset.ReplaceVariable(variable, replacementValues);296 modifiableDataset.ReplaceVariable(variableName, replacementValues); 321 297 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 322 var estimates = model.GetEstimatedValues( dataset, rows).ToList();323 dataset.ReplaceVariable(variable, originalValues);298 var estimates = model.GetEstimatedValues(modifiableDataset, rows).ToList(); 299 modifiableDataset.ReplaceVariable(variableName, originalValues); 324 300 325 301 return estimates; 326 302 } 327 303 328 public static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 329 IEnumerator<double> firstEnumerator = originalValues.GetEnumerator(); 304 /// <summary> 305 /// Calculates and returns the VariableImpact (calculated via Pearsons R²). 306 /// </summary> 307 /// <param name="targetValues">The actual values</param> 308 /// <param name="estimatedValues">The calculated/replaced values</param> 309 /// <param name="errorState"></param> 310 /// <returns></returns> 311 public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) { 312 //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality 313 //as the code below does. But this way we can easily swap the calculator later on, so the user 314 //could choose a Calculator during runtime in future versions. 315 IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator(); 316 IEnumerator<double> firstEnumerator = targetValues.GetEnumerator(); 330 317 IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator(); 331 var calculator = new OnlinePearsonsRSquaredCalculator();332 318 333 319 // always move forward both enumerators (do not use short-circuit evaluation!) … … 349 335 } 350 336 337 /// <summary> 338 /// Returns a collection of the row-indices for a given DataPartition (training or test) 339 /// </summary> 340 /// <param name="dataPartition"></param> 341 /// <param name="problemData"></param> 342 /// <returns></returns> 351 343 public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) { 352 344 IEnumerable<int> rows; -
branches/2904_CalculateImpacts/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionVariableImpactsView.Designer.cs
r15753 r16036 19 19 */ 20 20 #endregion 21 22 21 23 namespace HeuristicLab.Problems.DataAnalysis.Views { 22 24 partial class ClassificationSolutionVariableImpactsView { … … 44 46 /// </summary> 45 47 private void InitializeComponent() { 46 this.variableImactsArrayView = new HeuristicLab.Data.Views.StringConvertibleArrayView();47 this.dataPartitionComboBox = new System.Windows.Forms.ComboBox();48 this.dataPartitionLabel = new System.Windows.Forms.Label();49 this.numericVarReplacementLabel = new System.Windows.Forms.Label();50 this.replacementComboBox = new System.Windows.Forms.ComboBox();51 this.factorVarReplacementLabel = new System.Windows.Forms.Label();52 this.factorVarReplComboBox = new System.Windows.Forms.ComboBox();53 48 this.ascendingCheckBox = new System.Windows.Forms.CheckBox(); 54 49 this.sortByLabel = new System.Windows.Forms.Label(); 55 50 this.sortByComboBox = new System.Windows.Forms.ComboBox(); 56 this.backgroundWorker = new System.ComponentModel.BackgroundWorker(); 51 this.factorVarReplComboBox = new System.Windows.Forms.ComboBox(); 52 this.factorVarReplacementLabel = new System.Windows.Forms.Label(); 53 this.replacementComboBox = new System.Windows.Forms.ComboBox(); 54 this.numericVarReplacementLabel = new System.Windows.Forms.Label(); 55 this.dataPartitionLabel = new System.Windows.Forms.Label(); 56 this.dataPartitionComboBox = new System.Windows.Forms.ComboBox(); 57 this.variableImactsArrayView = new HeuristicLab.Data.Views.StringConvertibleArrayView(); 57 58 this.SuspendLayout(); 58 59 // 59 // variableImactsArrayView 60 // 61 this.variableImactsArrayView.Anchor = ((System.Windows.Forms.AnchorStyles)((((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Bottom) 62 | System.Windows.Forms.AnchorStyles.Left) 63 | System.Windows.Forms.AnchorStyles.Right))); 64 this.variableImactsArrayView.Caption = "StringConvertibleArray View"; 65 this.variableImactsArrayView.Content = null; 66 this.variableImactsArrayView.Location = new System.Drawing.Point(3, 84); 67 this.variableImactsArrayView.Name = "variableImactsArrayView"; 68 this.variableImactsArrayView.ReadOnly = true; 69 this.variableImactsArrayView.Size = new System.Drawing.Size(662, 278); 70 this.variableImactsArrayView.TabIndex = 2; 71 // 72 // dataPartitionComboBox 73 // 74 this.dataPartitionComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList; 75 this.dataPartitionComboBox.FormattingEnabled = true; 76 this.dataPartitionComboBox.Items.AddRange(new object[] { 77 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Training, 78 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Test, 79 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.All}); 80 this.dataPartitionComboBox.Location = new System.Drawing.Point(197, 3); 81 this.dataPartitionComboBox.Name = "dataPartitionComboBox"; 82 this.dataPartitionComboBox.Size = new System.Drawing.Size(121, 21); 83 this.dataPartitionComboBox.TabIndex = 1; 84 this.dataPartitionComboBox.SelectedIndexChanged += new System.EventHandler(this.dataPartitionComboBox_SelectedIndexChanged); 85 // 86 // dataPartitionLabel 87 // 88 this.dataPartitionLabel.AutoSize = true; 89 this.dataPartitionLabel.Location = new System.Drawing.Point(3, 6); 90 this.dataPartitionLabel.Name = "dataPartitionLabel"; 91 this.dataPartitionLabel.Size = new System.Drawing.Size(73, 13); 92 this.dataPartitionLabel.TabIndex = 0; 93 this.dataPartitionLabel.Text = "Data partition:"; 94 // 95 // numericVarReplacementLabel 96 // 97 this.numericVarReplacementLabel.AutoSize = true; 98 this.numericVarReplacementLabel.Location = new System.Drawing.Point(3, 33); 99 this.numericVarReplacementLabel.Name = "numericVarReplacementLabel"; 100 this.numericVarReplacementLabel.Size = new System.Drawing.Size(173, 13); 101 this.numericVarReplacementLabel.TabIndex = 2; 102 this.numericVarReplacementLabel.Text = "Replacement for numeric variables:"; 60 // ascendingCheckBox 61 // 62 this.ascendingCheckBox.AutoSize = true; 63 this.ascendingCheckBox.Location = new System.Drawing.Point(534, 6); 64 this.ascendingCheckBox.Name = "ascendingCheckBox"; 65 this.ascendingCheckBox.Size = new System.Drawing.Size(76, 17); 66 this.ascendingCheckBox.TabIndex = 7; 67 this.ascendingCheckBox.Text = "Ascending"; 68 this.ascendingCheckBox.UseVisualStyleBackColor = true; 69 this.ascendingCheckBox.CheckedChanged += new System.EventHandler(this.ascendingCheckBox_CheckedChanged); 70 // 71 // sortByLabel 72 // 73 this.sortByLabel.AutoSize = true; 74 this.sortByLabel.Location = new System.Drawing.Point(324, 6); 75 this.sortByLabel.Name = "sortByLabel"; 76 this.sortByLabel.Size = new System.Drawing.Size(77, 13); 77 this.sortByLabel.TabIndex = 4; 78 this.sortByLabel.Text = "Sorting criteria:"; 79 // 80 // sortByComboBox 81 // 82 this.sortByComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList; 83 this.sortByComboBox.FormattingEnabled = true; 84 this.sortByComboBox.Items.AddRange(new object[] { 85 HeuristicLab.Problems.DataAnalysis.Views.ClassificationSolutionVariableImpactsView.SortingCriteria.ImpactValue, 86 HeuristicLab.Problems.DataAnalysis.Views.ClassificationSolutionVariableImpactsView.SortingCriteria.Occurrence, 87 HeuristicLab.Problems.DataAnalysis.Views.ClassificationSolutionVariableImpactsView.SortingCriteria.VariableName}); 88 this.sortByComboBox.Location = new System.Drawing.Point(407, 3); 89 this.sortByComboBox.Name = "sortByComboBox"; 90 this.sortByComboBox.Size = new System.Drawing.Size(121, 21); 91 this.sortByComboBox.TabIndex = 5; 92 this.sortByComboBox.SelectedIndexChanged += new System.EventHandler(this.sortByComboBox_SelectedIndexChanged); 93 // 94 // factorVarReplComboBox 95 // 96 this.factorVarReplComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList; 97 this.factorVarReplComboBox.FormattingEnabled = true; 98 this.factorVarReplComboBox.Items.AddRange(new object[] { 99 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best, 100 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Mode, 101 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle}); 102 this.factorVarReplComboBox.Location = new System.Drawing.Point(197, 57); 103 this.factorVarReplComboBox.Name = "factorVarReplComboBox"; 104 this.factorVarReplComboBox.Size = new System.Drawing.Size(121, 21); 105 this.factorVarReplComboBox.TabIndex = 1; 106 this.factorVarReplComboBox.SelectedIndexChanged += new System.EventHandler(this.replacementComboBox_SelectedIndexChanged); 107 // 108 // factorVarReplacementLabel 109 // 110 this.factorVarReplacementLabel.AutoSize = true; 111 this.factorVarReplacementLabel.Location = new System.Drawing.Point(3, 60); 112 this.factorVarReplacementLabel.Name = "factorVarReplacementLabel"; 113 this.factorVarReplacementLabel.Size = new System.Drawing.Size(188, 13); 114 this.factorVarReplacementLabel.TabIndex = 0; 115 this.factorVarReplacementLabel.Text = "Replacement for categorical variables:"; 103 116 // 104 117 // replacementComboBox … … 117 130 this.replacementComboBox.SelectedIndexChanged += new System.EventHandler(this.replacementComboBox_SelectedIndexChanged); 118 131 // 119 // factorVarReplacementLabel 120 // 121 this.factorVarReplacementLabel.AutoSize = true; 122 this.factorVarReplacementLabel.Location = new System.Drawing.Point(3, 60); 123 this.factorVarReplacementLabel.Name = "factorVarReplacementLabel"; 124 this.factorVarReplacementLabel.Size = new System.Drawing.Size(188, 13); 125 this.factorVarReplacementLabel.TabIndex = 0; 126 this.factorVarReplacementLabel.Text = "Replacement for categorical variables:"; 127 // 128 // factorVarReplComboBox 129 // 130 this.factorVarReplComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList; 131 this.factorVarReplComboBox.FormattingEnabled = true; 132 this.factorVarReplComboBox.Items.AddRange(new object[] { 133 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best, 134 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Mode, 135 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle}); 136 this.factorVarReplComboBox.Location = new System.Drawing.Point(197, 57); 137 this.factorVarReplComboBox.Name = "factorVarReplComboBox"; 138 this.factorVarReplComboBox.Size = new System.Drawing.Size(121, 21); 139 this.factorVarReplComboBox.TabIndex = 1; 140 this.factorVarReplComboBox.SelectedIndexChanged += new System.EventHandler(this.replacementComboBox_SelectedIndexChanged); 141 // 142 // ascendingCheckBox 143 // 144 this.ascendingCheckBox.AutoSize = true; 145 this.ascendingCheckBox.Location = new System.Drawing.Point(534, 6); 146 this.ascendingCheckBox.Name = "ascendingCheckBox"; 147 this.ascendingCheckBox.Size = new System.Drawing.Size(76, 17); 148 this.ascendingCheckBox.TabIndex = 10; 149 this.ascendingCheckBox.Text = "Ascending"; 150 this.ascendingCheckBox.UseVisualStyleBackColor = true; 151 this.ascendingCheckBox.CheckedChanged += new System.EventHandler(this.ascendingCheckBox_CheckedChanged); 152 // 153 // sortByLabel 154 // 155 this.sortByLabel.AutoSize = true; 156 this.sortByLabel.Location = new System.Drawing.Point(324, 6); 157 this.sortByLabel.Name = "sortByLabel"; 158 this.sortByLabel.Size = new System.Drawing.Size(77, 13); 159 this.sortByLabel.TabIndex = 8; 160 this.sortByLabel.Text = "Sorting criteria:"; 161 // 162 // sortByComboBox 163 // 164 this.sortByComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList; 165 this.sortByComboBox.FormattingEnabled = true; 166 this.sortByComboBox.Location = new System.Drawing.Point(407, 3); 167 this.sortByComboBox.Name = "sortByComboBox"; 168 this.sortByComboBox.Size = new System.Drawing.Size(121, 21); 169 this.sortByComboBox.TabIndex = 9; 170 this.sortByComboBox.SelectedIndexChanged += new System.EventHandler(this.sortByComboBox_SelectedIndexChanged); 132 // numericVarReplacementLabel 133 // 134 this.numericVarReplacementLabel.AutoSize = true; 135 this.numericVarReplacementLabel.Location = new System.Drawing.Point(3, 33); 136 this.numericVarReplacementLabel.Name = "numericVarReplacementLabel"; 137 this.numericVarReplacementLabel.Size = new System.Drawing.Size(173, 13); 138 this.numericVarReplacementLabel.TabIndex = 2; 139 this.numericVarReplacementLabel.Text = "Replacement for numeric variables:"; 140 // 141 // dataPartitionLabel 142 // 143 this.dataPartitionLabel.AutoSize = true; 144 this.dataPartitionLabel.Location = new System.Drawing.Point(3, 6); 145 this.dataPartitionLabel.Name = "dataPartitionLabel"; 146 this.dataPartitionLabel.Size = new System.Drawing.Size(73, 13); 147 this.dataPartitionLabel.TabIndex = 0; 148 this.dataPartitionLabel.Text = "Data partition:"; 149 // 150 // dataPartitionComboBox 151 // 152 this.dataPartitionComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList; 153 this.dataPartitionComboBox.FormattingEnabled = true; 154 this.dataPartitionComboBox.Items.AddRange(new object[] { 155 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Training, 156 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Test, 157 HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.All}); 158 this.dataPartitionComboBox.Location = new System.Drawing.Point(197, 3); 159 this.dataPartitionComboBox.Name = "dataPartitionComboBox"; 160 this.dataPartitionComboBox.Size = new System.Drawing.Size(121, 21); 161 this.dataPartitionComboBox.TabIndex = 1; 162 this.dataPartitionComboBox.SelectedIndexChanged += new System.EventHandler(this.dataPartitionComboBox_SelectedIndexChanged); 163 // 164 // variableImactsArrayView 165 // 166 this.variableImactsArrayView.Anchor = ((System.Windows.Forms.AnchorStyles)((((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Bottom) 167 | System.Windows.Forms.AnchorStyles.Left) 168 | System.Windows.Forms.AnchorStyles.Right))); 169 this.variableImactsArrayView.Caption = "StringConvertibleArray View"; 170 this.variableImactsArrayView.Content = null; 171 this.variableImactsArrayView.Location = new System.Drawing.Point(3, 84); 172 this.variableImactsArrayView.Name = "variableImactsArrayView"; 173 this.variableImactsArrayView.ReadOnly = true; 174 this.variableImactsArrayView.Size = new System.Drawing.Size(706, 278); 175 this.variableImactsArrayView.TabIndex = 2; 171 176 // 172 177 // ClassificationSolutionVariableImpactsView … … 185 190 this.Controls.Add(this.variableImactsArrayView); 186 191 this.Name = "ClassificationSolutionVariableImpactsView"; 187 this.Size = new System.Drawing.Size( 668, 365);192 this.Size = new System.Drawing.Size(712, 365); 188 193 this.VisibleChanged += new System.EventHandler(this.ClassificationSolutionVariableImpactsView_VisibleChanged); 189 194 this.ResumeLayout(false); … … 201 206 private System.Windows.Forms.Label factorVarReplacementLabel; 202 207 private System.Windows.Forms.ComboBox factorVarReplComboBox; 203 private System.Windows.Forms.CheckBox ascendingCheckBox;204 208 private System.Windows.Forms.Label sortByLabel; 205 209 private System.Windows.Forms.ComboBox sortByComboBox; 206 private System. ComponentModel.BackgroundWorker backgroundWorker;210 private System.Windows.Forms.CheckBox ascendingCheckBox; 207 211 } 208 212 } -
branches/2904_CalculateImpacts/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionVariableImpactsView.cs
r15753 r16036 33 33 [Content(typeof(IClassificationSolution))] 34 34 public partial class ClassificationSolutionVariableImpactsView : DataAnalysisSolutionEvaluationView { 35 #region Nested Types36 35 private enum SortingCriteria { 37 36 ImpactValue, … … 39 38 VariableName 40 39 } 41 #endregion 42 43 #region Fields 44 private Dictionary<string, double> rawVariableImpacts = new Dictionary<string, double>(); 45 private Thread thread; 46 #endregion 47 48 #region Getter/Setter 40 private CancellationTokenSource cancellationToken = new CancellationTokenSource(); 41 private List<Tuple<string, double>> rawVariableImpacts = new List<Tuple<string, double>>(); 42 49 43 public new IClassificationSolution Content { 50 44 get { return (IClassificationSolution)base.Content; } … … 53 47 } 54 48 } 55 #endregion 56 57 #region Ctor 49 58 50 public ClassificationSolutionVariableImpactsView() 59 51 : base() { 60 52 InitializeComponent(); 61 53 62 //Little workaround. If you fill the ComboBox-Items in the other partial class, the UI-Designer will moan.63 this.sortByComboBox.Items.AddRange(Enum.GetValues(typeof(SortingCriteria)).Cast<object>().ToArray());64 this.sortByComboBox.SelectedItem = SortingCriteria.ImpactValue;65 66 54 //Set the default values 67 55 this.dataPartitionComboBox.SelectedIndex = 0; 68 this.replacementComboBox.SelectedIndex = 0;56 this.replacementComboBox.SelectedIndex = 3; 69 57 this.factorVarReplComboBox.SelectedIndex = 0; 70 } 71 #endregion 72 73 #region Events 58 this.sortByComboBox.SelectedItem = SortingCriteria.ImpactValue; 59 } 60 74 61 protected override void RegisterContentEvents() { 75 62 base.RegisterContentEvents(); … … 77 64 Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged); 78 65 } 79 80 66 protected override void DeregisterContentEvents() { 81 67 base.DeregisterContentEvents(); … … 87 73 OnContentChanged(); 88 74 } 89 90 75 protected virtual void Content_ModelChanged(object sender, EventArgs e) { 91 76 OnContentChanged(); 92 77 } 93 94 78 protected override void OnContentChanged() { 95 79 base.OnContentChanged(); … … 100 84 } 101 85 } 102 103 86 private void ClassificationSolutionVariableImpactsView_VisibleChanged(object sender, EventArgs e) { 104 if (thread == null) { return; } 105 106 if (thread.IsAlive) { thread.Abort(); } 107 thread = null; 108 } 109 87 cancellationToken.Cancel(); 88 } 110 89 111 90 private void dataPartitionComboBox_SelectedIndexChanged(object sender, EventArgs e) { 112 91 UpdateVariableImpact(); 113 92 } 114 115 93 private void replacementComboBox_SelectedIndexChanged(object sender, EventArgs e) { 116 94 UpdateVariableImpact(); 117 95 } 118 119 96 private void sortByComboBox_SelectedIndexChanged(object sender, EventArgs e) { 120 97 //Update the default ordering (asc,desc), but remove the eventHandler beforehand (otherwise the data would be ordered twice) 121 98 ascendingCheckBox.CheckedChanged -= ascendingCheckBox_CheckedChanged; 122 switch ((SortingCriteria)sortByComboBox.SelectedItem) { 123 case SortingCriteria.ImpactValue: 124 ascendingCheckBox.Checked = false; 125 break; 126 case SortingCriteria.Occurrence: 127 ascendingCheckBox.Checked = true; 128 break; 129 case SortingCriteria.VariableName: 130 ascendingCheckBox.Checked = true; 131 break; 132 default: 133 throw new NotImplementedException("Ordering for selected SortingCriteria not implemented"); 134 } 99 ascendingCheckBox.Checked = (SortingCriteria)sortByComboBox.SelectedItem != SortingCriteria.ImpactValue; 135 100 ascendingCheckBox.CheckedChanged += ascendingCheckBox_CheckedChanged; 136 101 137 UpdateDataOrdering(); 138 } 139 102 UpdateOrdering(); 103 } 140 104 private void ascendingCheckBox_CheckedChanged(object sender, EventArgs e) { 141 UpdateDataOrdering(); 142 } 143 144 #endregion 145 146 #region Helper Methods 147 private void UpdateVariableImpact() { 105 UpdateOrdering(); 106 } 107 108 private async void UpdateVariableImpact() { 148 109 //Check if the selection is valid 149 110 if (Content == null) { return; } … … 152 113 if (factorVarReplComboBox.SelectedIndex < 0) { return; } 153 114 115 IProgress progress; 116 154 117 //Prepare arguments 155 118 var mainForm = (MainForm.WindowsForms.MainForm)MainFormManager.MainForm; … … 159 122 160 123 variableImactsArrayView.Caption = Content.Name + " Variable Impacts"; 161 162 mainForm.AddOperationProgressToView(this, "Calculating variable impacts for " + Content.Name); 163 164 Task.Factory.StartNew(() => { 165 thread = Thread.CurrentThread; 166 //Remember the original ordering of the variables 167 var impacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(Content, dataPartition, replMethod, factorReplMethod); 124 progress = mainForm.AddOperationProgressToView(this, "Calculating variable impacts for " + Content.Name); 125 progress.ProgressValue = 0; 126 127 cancellationToken = new CancellationTokenSource(); 128 129 try { 168 130 var problemData = Content.ProblemData; 169 131 var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(Content.Model.VariablesUsedForPrediction)); 170 var originalVariableOrdering = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).Where(problemData.Dataset.VariableHasType<double>).ToList(); 132 //Remember the original ordering of the variables 133 var originalVariableOrdering = problemData.Dataset.VariableNames 134 .Where(v => inputvariables.Contains(v)) 135 .Where(v => problemData.Dataset.VariableHasType<double>(v) || problemData.Dataset.VariableHasType<string>(v)) 136 .ToList(); 137 138 List<Tuple<string, double>> impacts = null; 139 140 await Task.Run(() => { impacts = CalculateVariableImpacts(originalVariableOrdering, (IClassificationModel)Content.Model.Clone(), problemData, Content.EstimatedClassValues, dataPartition, replMethod, factorReplMethod, cancellationToken.Token, progress); }); 141 if (impacts == null) { return; } 171 142 172 143 rawVariableImpacts.Clear(); 173 originalVariableOrdering.ForEach(v => rawVariableImpacts.Add(v, impacts.First(vv => vv.Item1 == v).Item2)); 174 }).ContinueWith((o) => { 175 UpdateDataOrdering(); 176 mainForm.RemoveOperationProgressFromView(this); 177 thread = null; 178 }, TaskScheduler.FromCurrentSynchronizationContext()); 144 originalVariableOrdering.ForEach(v => rawVariableImpacts.Add(new Tuple<string, double>(v, impacts.First(vv => vv.Item1 == v).Item2))); 145 UpdateOrdering(); 146 } 147 finally { 148 ((MainForm.WindowsForms.MainForm)MainFormManager.MainForm).RemoveOperationProgressFromView(this); 149 } 150 } 151 152 private List<Tuple<string, double>> CalculateVariableImpacts(List<string> originalVariableOrdering, 153 IClassificationModel model, 154 IClassificationProblemData problemData, 155 IEnumerable<double> estimatedValues, 156 ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum dataPartition, 157 ClassificationSolutionVariableImpactsCalculator.ReplacementMethodEnum replMethod, 158 ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum factorReplMethod, 159 CancellationToken token, 160 IProgress progress) { 161 List<Tuple<string, double>> impacts = new List<Tuple<string, double>>(); 162 int count = originalVariableOrdering.Count; 163 int i = 0; 164 var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable(); 165 IEnumerable<int> rows = ClassificationSolutionVariableImpactsCalculator.GetPartitionRows(dataPartition, problemData); 166 167 //Calculate original quality-values (via calculator, default is R²) 168 OnlineCalculatorError error; 169 IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows); 170 IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedValues.ElementAt(v)); 171 var originalCalculatorValue = ClassificationSolutionVariableImpactsCalculator.CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error); 172 if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation."); 173 174 foreach (var variableName in originalVariableOrdering) { 175 if (cancellationToken.Token.IsCancellationRequested) { return null; } 176 progress.ProgressValue = (double)++i / count; 177 progress.Status = string.Format("Calculating impact for variable {0} ({1} of {2})", variableName, i, count); 178 179 double impact = ClassificationSolutionVariableImpactsCalculator.CalculateImpact(variableName, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replMethod, factorReplMethod); 180 impacts.Add(new Tuple<string, double>(variableName, impact)); 181 } 182 183 return impacts; 179 184 } 180 185 … … 183 188 /// The default is "Descending" by "VariableImpact" (as in previous versions) 184 189 /// </summary> 185 private void Update DataOrdering() {190 private void UpdateOrdering() { 186 191 //Check if valid sortingCriteria is selected and data exists 187 192 if (sortByComboBox.SelectedIndex == -1) { return; } … … 192 197 bool ascending = ascendingCheckBox.Checked; 193 198 194 IEnumerable< KeyValuePair<string, double>> orderedEntries = null;199 IEnumerable<Tuple<string, double>> orderedEntries = null; 195 200 196 201 //Sort accordingly 197 202 switch (selectedItem) { 198 203 case SortingCriteria.ImpactValue: 199 orderedEntries = rawVariableImpacts.OrderBy(v => v. Value);204 orderedEntries = rawVariableImpacts.OrderBy(v => v.Item2); 200 205 break; 201 206 case SortingCriteria.Occurrence: … … 203 208 break; 204 209 case SortingCriteria.VariableName: 205 orderedEntries = rawVariableImpacts.OrderBy(v => v. Key, new NaturalStringComparer());210 orderedEntries = rawVariableImpacts.OrderBy(v => v.Item1, new NaturalStringComparer()); 206 211 break; 207 212 default: … … 212 217 213 218 //Write the data back 214 var impactArray = new DoubleArray(orderedEntries.Select(i => i. Value).ToArray()) {215 ElementNames = orderedEntries.Select(i => i. Key)219 var impactArray = new DoubleArray(orderedEntries.Select(i => i.Item2).ToArray()) { 220 ElementNames = orderedEntries.Select(i => i.Item1) 216 221 }; 217 222 … … 221 226 } 222 227 } 223 #endregion224 228 } 225 229 }
Note: See TracChangeset
for help on using the changeset viewer.