Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.TreeSimplifier/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationSolutionValuesCalculator.cs @ 8388

Last change on this file since 8388 was 8388, checked in by bburlacu, 12 years ago

#1763: Rebranched the TreeSimplifier project to fix merging errors. Added functionality to insert symbols into the tree, fixed bug when deleting tree nodes.

File size: 3.5 KB
Line 
1using System.Collections.Generic;
2using System.Linq;
3using HeuristicLab.Common;
4using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
5
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
  /// <summary>
  /// Replacement-value and impact-value calculator for symbolic discriminant-function
  /// classification solutions. Impacts are measured with the normalized Gini coefficient
  /// on the training rows of the problem data.
  /// </summary>
  public class SymbolicDiscriminantFunctionClassificationSolutionValuesCalculator : SymbolicDataAnalysisSolutionValuesCalculator {
    /// <summary>
    /// Calculates a replacement value for every node of the subtree
    /// <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>, visiting the nodes in prefix order.
    /// </summary>
    /// <param name="tree">The symbolic expression tree whose nodes are evaluated.</param>
    /// <param name="interpreter">Interpreter forwarded to <c>CalculateReplacementValue</c>.</param>
    /// <param name="problemData">Problem data forwarded to <c>CalculateReplacementValue</c>.</param>
    /// <returns>A dictionary mapping each visited node to its replacement value.</returns>
    public override Dictionary<ISymbolicExpressionTreeNode, double> CalculateReplacementValues(ISymbolicExpressionTree tree,
                                                                                      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
                                                                                      IDataAnalysisProblemData problemData) {
      var replacementValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
      foreach (ISymbolicExpressionTreeNode node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix()) {
        replacementValues[node] = CalculateReplacementValue(node, tree, interpreter, problemData);
      }
      return replacementValues;
    }

    /// <summary>
    /// Calculates the impact of every node of the subtree <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>.
    /// The method first computes the normalized Gini coefficient of the unmodified tree. Then, for each
    /// node in postfix order, it:
    /// <list type="bullet">
    /// <item><description>swaps the node for <c>constantNode</c> holding the node's replacement value;</description></item>
    /// <item><description>re-evaluates the tree on the training rows and computes the Gini coefficient again.</description></item>
    /// </list>
    /// </summary>
    /// <param name="tree">
    /// The tree to analyze. The method modifies it in place (via <c>SwitchNode</c>) while each node is
    /// evaluated. The original node is restored afterwards, also when the evaluation throws.
    /// </param>
    /// <param name="interpreter">Interpreter used to evaluate the tree on the training rows.</param>
    /// <param name="classificationProblemData">
    /// Problem data. It is cast directly to <c>IClassificationProblemData</c>, so other
    /// implementations cause an <c>InvalidCastException</c>.
    /// </param>
    /// <param name="lowerEstimationLimit">Lower bound applied to model outputs via <c>LimitToRange</c>.</param>
    /// <param name="upperEstimationLimit">Upper bound applied to model outputs via <c>LimitToRange</c>.</param>
    /// <returns>
    /// A dictionary mapping each node to <c>originalGini - newGini</c>. The value is 0 if the replacement
    /// does not change the Gini coefficient, negative if the modified tree scores better, and positive
    /// if it scores worse. A Gini calculation that reports an error state counts as 0.
    /// </returns>
    public override Dictionary<ISymbolicExpressionTreeNode, double> CalculateImpactValues(ISymbolicExpressionTree tree,
                                                                                 ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
                                                                                 IDataAnalysisProblemData classificationProblemData,
                                                                                 double lowerEstimationLimit, double upperEstimationLimit) {
      var problemData = (IClassificationProblemData)classificationProblemData;
      var dataset = problemData.Dataset;
      // Materialize rows and target values once. Both are reused for every node below, and
      // re-enumerating the lazy sequences would repeat the dataset lookups on each iteration.
      var rows = problemData.TrainingIndices.ToArray();
      string targetVariable = problemData.TargetVariable;
      var targetClassValues = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      var impactValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
      // Snapshot the node list up front because the loop below temporarily rewires the tree.
      List<ISymbolicExpressionTreeNode> nodes = tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix().ToList();

      var originalOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
        .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
        .ToArray();
      OnlineCalculatorError errorState;
      double originalGini = NormalizedGiniCalculator.Calculate(targetClassValues, originalOutput, out errorState);
      // A Gini coefficient that could not be computed is treated as 0.
      if (errorState != OnlineCalculatorError.None) originalGini = 0.0;

      foreach (ISymbolicExpressionTreeNode node in nodes) {
        var parent = node.Parent;
        // NOTE(review): constantNode is not declared in this file; presumably an inherited member of
        // SymbolicDataAnalysisSolutionValuesCalculator that is reused for every node - confirm.
        constantNode.Value = CalculateReplacementValue(node, tree, interpreter, classificationProblemData);
        ISymbolicExpressionTreeNode replacementNode = constantNode;
        SwitchNode(parent, node, replacementNode);
        try {
          var newOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
            .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
            .ToArray();
          double newGini = NormalizedGiniCalculator.Calculate(targetClassValues, newOutput, out errorState);
          if (errorState != OnlineCalculatorError.None) newGini = 0.0;

          // impact = 0 if no change
          // impact < 0 if new solution is better
          // impact > 0 if new solution is worse
          impactValues[node] = originalGini - newGini;
        } finally {
          // Always put the original node back so the caller's tree is not left corrupted
          // when evaluation throws.
          SwitchNode(parent, replacementNode, node);
        }
      }
      return impactValues;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.