Changeset 16037
- Timestamp: 08/01/18 14:14:49 (6 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs
r16036 r16037 90 90 public ClassificationSolutionVariableImpactsCalculator() 91 91 : base() { 92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle))); 92 Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median))); 93 Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best))); 93 94 Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training))); 94 95 } … … 122 123 return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod); 123 124 } 124 125 125 126 126 public static IEnumerable<Tuple<string, double>> CalculateImpacts( … … 149 149 return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)); 150 150 } 151 152 151 153 152 public static double CalculateImpact(string variableName, … … 170 169 var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList(); 171 170 List<double> replacementValues; 172 IRandom rand;173 171 174 172 switch (replacementMethod) { … … 183 181 case ReplacementMethodEnum.Shuffle: 184 182 // new var has same empirical distribution but the relation to y is broken 185 rand = new FastRandom(31415);183 random = new FastRandom(31415); 186 184 // prepare a complete column for the dataset 187 185 replacementValues = Enumerable.Repeat(double.NaN, 
modifiableDataset.Rows).ToList(); 188 186 // shuffle only the selected rows 189 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand ).ToList();187 var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList(); 190 188 int i = 0; 191 189 // update column values … … 197 195 var avg = rows.Select(r => originalValues[r]).Average(); 198 196 var stdDev = rows.Select(r => originalValues[r]).StandardDeviation(); 199 rand = new FastRandom(31415);197 random = new FastRandom(31415); 200 198 // prepare a complete column for the dataset 201 199 replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList(); 202 200 // update column values 203 201 foreach (var r in rows) { 204 replacementValues[r] = NormalDistributedRandom.NextDouble(rand , avg, stdDev);202 replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev); 205 203 } 206 204 break; … … 276 274 277 275 return impact; 276 } 277 278 /// <summary> 279 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values 280 /// and changes the value of the model-variables back to the original ones. 
281 /// </summary> 282 /// <param name="originalValues"></param> 283 /// <param name="model"></param> 284 /// <param name="variableName"></param> 285 /// <param name="modifiableDataset"></param> 286 /// <param name="rows"></param> 287 /// <param name="replacementValues"></param> 288 /// <returns></returns> 289 private static IEnumerable<double> GetReplacedEstimates( 290 IList originalValues, 291 IClassificationModel model, 292 string variableName, 293 ModifiableDataset modifiableDataset, 294 IEnumerable<int> rows, 295 IList replacementValues) { 296 modifiableDataset.ReplaceVariable(variableName, replacementValues); 297 298 var discModel = model as IDiscriminantFunctionClassificationModel; 299 if (discModel != null) { 300 var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable); 301 discModel.RecalculateModelParameters(problemData, rows); 302 } 303 304 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements 305 var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList(); 306 modifiableDataset.ReplaceVariable(variableName, originalValues); 307 308 return estimates; 278 309 } 279 310 … … 312 343 313 344 /// <summary> 314 /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values315 /// and changes the value of the model-variables back to the original ones.316 /// </summary>317 /// <param name="originalValues"></param>318 /// <param name="model"></param>319 /// <param name="variableName"></param>320 /// <param name="modifiableDataset"></param>321 /// <param name="rows"></param>322 /// <param name="replacementValues"></param>323 /// <returns></returns>324 private static IEnumerable<double> GetReplacedEstimates(325 IList originalValues,326 IClassificationModel model,327 string variableName,328 ModifiableDataset modifiableDataset,329 IEnumerable<int> 
rows,330 IList replacementValues) {331 modifiableDataset.ReplaceVariable(variableName, replacementValues);332 333 var discModel = model as IDiscriminantFunctionClassificationModel;334 if (discModel != null) {335 var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);336 discModel.RecalculateModelParameters(problemData, rows);337 }338 339 //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements340 var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();341 modifiableDataset.ReplaceVariable(variableName, originalValues);342 343 return estimates;344 }345 346 347 /// <summary>348 345 /// Returns a collection of the row-indices for a given DataPartition (training or test) 349 346 /// </summary> … … 370 367 return rows; 371 368 } 372 373 369 } 374 370 }
Note: See TracChangeset for help on using the changeset viewer.