using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
  /// <summary>
  /// Impact-value calculator for symbolic discriminant-function classification solutions.
  /// The impact of a node is measured as the change in normalized Gini coefficient that
  /// results from temporarily replacing that node with a constant node.
  /// </summary>
  public class SymbolicDiscriminantFunctionClassificationSolutionImpactValuesCalculator : SymbolicDataAnalysisSolutionImpactValuesCalculator {
    /// <summary>
    /// Computes a replacement value for every node below the tree's
    /// <c>Root.GetSubtree(0).GetSubtree(0)</c> node (visited in prefix order).
    /// </summary>
    /// <param name="tree">The symbolic expression tree whose nodes are inspected.</param>
    /// <param name="interpreter">Interpreter forwarded to <c>CalculateReplacementValue</c>.</param>
    /// <param name="problemData">Problem data forwarded to <c>CalculateReplacementValue</c>.</param>
    /// <returns>A dictionary mapping each visited node to the value returned by <c>CalculateReplacementValue</c>.</returns>
    public override Dictionary<ISymbolicExpressionTreeNode, double> CalculateReplacementValues(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData) {
      var replacementValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
      foreach (ISymbolicExpressionTreeNode node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix()) {
        replacementValues[node] = CalculateReplacementValue(node, tree, interpreter, problemData);
      }
      return replacementValues;
    }

    /// <summary>
    /// Computes, for every node below <c>Root.GetSubtree(0).GetSubtree(0)</c> (postfix order),
    /// the difference between the normalized Gini coefficient of the unmodified tree and that of
    /// the tree with the node swapped for a constant node. Both are evaluated on the training rows.
    /// </summary>
    /// <param name="tree">
    /// The tree to analyse. It is modified temporarily while each node is evaluated; the original
    /// node is switched back afterwards, also when the evaluation throws.
    /// </param>
    /// <param name="interpreter">Interpreter used to evaluate the tree on the dataset.</param>
    /// <param name="classificationProblemData">
    /// Problem data; it is cast to <c>IClassificationProblemData</c>, so any other implementation
    /// causes an <c>InvalidCastException</c>.
    /// </param>
    /// <param name="lowerEstimationLimit">Lower bound passed to <c>LimitToRange</c> for the tree outputs.</param>
    /// <param name="upperEstimationLimit">Upper bound passed to <c>LimitToRange</c> for the tree outputs.</param>
    /// <returns>
    /// A dictionary mapping each node to <c>originalGini - newGini</c>:
    /// 0 if nothing changed, negative if the modified tree scores better, positive if it scores worse.
    /// </returns>
    public override Dictionary<ISymbolicExpressionTreeNode, double> CalculateImpactValues(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData classificationProblemData, double lowerEstimationLimit, double upperEstimationLimit) {
      var problemData = (IClassificationProblemData)classificationProblemData;
      var dataset = problemData.Dataset;
      var rows = problemData.TrainingIndices;
      string targetVariable = problemData.TargetVariable;
      var impactValues = new Dictionary<ISymbolicExpressionTreeNode, double>();

      // Snapshot the node list up front: the tree is modified while iterating.
      List<ISymbolicExpressionTreeNode> nodes = tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix().ToList();

      // Materialized once because the target values are consumed again for every node.
      var targetClassValues = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      var originalOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
        .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
        .ToArray();

      OnlineCalculatorError errorState;
      double originalGini = NormalizedGiniCalculator.Calculate(targetClassValues, originalOutput, out errorState);
      // A calculator error is treated as a Gini coefficient of zero.
      if (errorState != OnlineCalculatorError.None) originalGini = 0.0;

      foreach (ISymbolicExpressionTreeNode node in nodes) {
        var parent = node.Parent;
        var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
        constantNode.Value = CalculateReplacementValue(node, tree, interpreter, classificationProblemData);
        ISymbolicExpressionTreeNode replacementNode = constantNode;

        SwitchNode(parent, node, replacementNode);
        try {
          var newOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
            .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
            .ToArray();
          double newGini = NormalizedGiniCalculator.Calculate(targetClassValues, newOutput, out errorState);
          if (errorState != OnlineCalculatorError.None) newGini = 0.0;

          // impact = 0 if no change
          // impact < 0 if new solution is better
          // impact > 0 if new solution is worse
          impactValues[node] = originalGini - newGini;
        } finally {
          // Always put the original node back so a failing evaluation cannot leave the tree altered.
          SwitchNode(parent, replacementNode, node);
        }
      }
      return impactValues;
    }
  }
}