Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/22/12 15:03:32 (11 years ago)
Author:
bburlacu
Message:

#1763: Bugfixes and refactoring as suggested in the comments above.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.TreeSimplifier/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationSolutionImpactValuesCalculator.cs

    r8409 r8935  
    1 using System.Collections.Generic;
     1#region License Information
     2/* HeuristicLab
     3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
     5 * This file is part of HeuristicLab.
     6 *
     7 * HeuristicLab is free software: you can redistribute it and/or modify
     8 * it under the terms of the GNU General Public License as published by
     9 * the Free Software Foundation, either version 3 of the License, or
     10 * (at your option) any later version.
     11 *
     12 * HeuristicLab is distributed in the hope that it will be useful,
     13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
     14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     15 * GNU General Public License for more details.
     16 *
     17 * You should have received a copy of the GNU General Public License
     18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
     19 */
     20#endregion
     21
     22using System;
     23using System.Collections.Generic;
    224using System.Linq;
    325using HeuristicLab.Common;
     
    628namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
    729  public class SymbolicDiscriminantFunctionClassificationSolutionImpactValuesCalculator : SymbolicDataAnalysisSolutionImpactValuesCalculator {
           // Computes, for the nodes of a symbolic expression tree, (a) a constant replacement value per node and
           // (b) an impact value per node, measured as the change in normalized Gini on the training rows.
           // This listing shows two revisions of each member interleaved: a Dictionary-returning variant and an
           // IEnumerable<Tuple<...>>-returning variant. Presumably the former is the removed code and the latter
           // the added code -- confirm against the changeset viewer's colouring.

           // Dictionary-returning variant: eagerly fills a dictionary that maps every node yielded by
           // tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix() to CalculateReplacementValue(...) for that node.
    8     public override Dictionary<ISymbolicExpressionTreeNode, double> CalculateReplacementValues(ISymbolicExpressionTree tree,
    9                                                                                       ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    10                                                                                       IDataAnalysisProblemData problemData) {
    11       var replacementValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
    12       foreach (ISymbolicExpressionTreeNode node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix()) {
    13         replacementValues[node] = CalculateReplacementValue(node, tree, interpreter, problemData);
    14       }
    15       return replacementValues;
           // Tuple-sequence variant: pairs every node yielded by
           // tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix() with CalculateReplacementValue(...) for it.
           // The result is a query expression; assuming the standard System.Linq operators are bound here, nothing
           // is computed until the caller enumerates it, and it is recomputed on every enumeration.
           // CalculateReplacementValue is not defined in this listing.
     30    public override IEnumerable<Tuple<ISymbolicExpressionTreeNode, double>> CalculateReplacementValues(ISymbolicExpressionTree tree,
     31                                                                                                       ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     32                                                                                                       IDataAnalysisProblemData problemData) {
     33      return from node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix()
     34             select new Tuple<ISymbolicExpressionTreeNode, double>(node, CalculateReplacementValue(node, tree, interpreter, problemData));
    1635    }
           // Dictionary-returning signature of CalculateImpactValues (its body lines are interleaved below).
    17     public override Dictionary<ISymbolicExpressionTreeNode, double> CalculateImpactValues(ISymbolicExpressionTree tree,
    18                                                                                  ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    19                                                                                  IDataAnalysisProblemData classificationProblemData,
    20                                                                                  double lowerEstimationLimit, double upperEstimationLimit) {
     36
           // Tuple-sequence signature of CalculateImpactValues.
           // Evaluates the unmodified tree once on the training rows to obtain a baseline normalized Gini, then
           // yields (node, impact) for every node from IterateNodesPostfix(), where impact comes from CalculateImpact.
           // classificationProblemData must be an IClassificationProblemData; the explicit cast below fails otherwise.
     37    public override IEnumerable<Tuple<ISymbolicExpressionTreeNode, double>> CalculateImpactValues(ISymbolicExpressionTree tree,
     38                                                                                                  ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     39                                                                                                  IDataAnalysisProblemData classificationProblemData,
     40                                                                                                  double lowerEstimationLimit, double upperEstimationLimit) {
    2141      var problemData = (IClassificationProblemData)classificationProblemData;
    2242      var dataset = problemData.Dataset;
    2343      var rows = problemData.TrainingIndices;
    2444      string targetVariable = problemData.TargetVariable;
    25       Dictionary<ISymbolicExpressionTreeNode, double> impactValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
             // The Dictionary variant snapshots the node list with ToList() before any node is swapped.
    26       List<ISymbolicExpressionTreeNode> nodes = tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix().ToList();
    27 
    2845      var targetClassValues = dataset.GetDoubleValues(targetVariable, rows);
    29       var originalOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
    30         .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
    31         .ToArray();
     46      var originalOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows).LimitToRange(lowerEstimationLimit, upperEstimationLimit).ToArray();
    3247      OnlineCalculatorError errorState;
             // Baseline quality of the unmodified tree; treated as 0.0 when the calculator reports an error.
    3348      double originalGini = NormalizedGiniCalculator.Calculate(targetClassValues, originalOutput, out errorState);
    3449      if (errorState != OnlineCalculatorError.None) originalGini = 0.0;
    3550
             // Dictionary variant: swap each node for a constant, re-evaluate, record the Gini difference, swap back.
    36       foreach (ISymbolicExpressionTreeNode node in nodes) {
    37         var parent = node.Parent;
    38         var constantNode = ((ConstantTreeNode)new Constant().CreateTreeNode());
    39         constantNode.Value = CalculateReplacementValue(node, tree, interpreter, classificationProblemData);
    40         ISymbolicExpressionTreeNode replacementNode = constantNode;
    41         SwitchNode(parent, node, replacementNode);
    42         var newOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
    43           .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
    44           .ToArray();
    45         double newGini = NormalizedGiniCalculator.Calculate(targetClassValues, newOutput, out errorState);
    46         if (errorState != OnlineCalculatorError.None) newGini = 0.0;
             // NOTE(review): unlike the ToList()-based loop above, this query is lazy (assuming the standard
             // System.Linq operators). CalculateImpact swaps nodes inside `tree` and swaps them back, and it runs
             // while the caller is still enumerating IterateNodesPostfix(), and again on every re-enumeration.
             // Confirm that IterateNodesPostfix() tolerates this and that callers enumerate the result only once.
     51      return from node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix()
     52             select new Tuple<ISymbolicExpressionTreeNode, double>(node, CalculateImpact(tree, originalGini, node, interpreter, problemData, lowerEstimationLimit, upperEstimationLimit));
     53    }
    4754
    48         // impact = 0 if no change
    49         // impact < 0 if new solution is better
    50         // impact > 0 if new solution is worse
    51         impactValues[node] = originalGini - newGini;
    52         SwitchNode(parent, replacementNode, node);
    53       }
    54       return impactValues;
           // Computes the impact of a single node:
           //   1. replaces `node` under its parent with a new constant node whose value is CalculateReplacementValue(...),
           //   2. evaluates the modified tree on the training rows, passing the outputs through
           //      LimitToRange(lowerEstimationLimit, upperEstimationLimit),
           //   3. computes the normalized Gini against the target values (0.0 if the calculator reports an error),
           //   4. puts the original node back, and
           //   5. returns originalQuality - quality.
           // NOTE(review): there is no try/finally around steps 1-4, so an exception thrown by the interpreter or the
           // calculator would leave the constant node in the tree.
     55    private static double CalculateImpact(ISymbolicExpressionTree tree, double originalQuality, ISymbolicExpressionTreeNode node,
     56                                          ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IClassificationProblemData problemData,
     57                                          double lowerEstimationLimit, double upperEstimationLimit) {
             // NOTE(review): rows and targetValues are rebuilt on every call, i.e. once per node when driven by
             // CalculateImpactValues; they do not depend on `node` and could be computed once by the caller.
     58      var dataset = problemData.Dataset;
     59      var rows = problemData.TrainingIndices.ToList();
     60      string targetVariable = problemData.TargetVariable;
     61      var targetValues = dataset.GetDoubleValues(targetVariable, rows).ToList();
    5562
     63      var parent = node.Parent;
     64      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
     65      constantNode.Value = CalculateReplacementValue(node, tree, interpreter, problemData);
             // Temporarily substitute the constant for the node; undone by the second SwitchNode call below.
     66      SwitchNode(parent, node, constantNode);
     67      var newOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
     68                                 .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
     69                                 .ToArray();
     70      OnlineCalculatorError errorState;
     71      double quality = NormalizedGiniCalculator.Calculate(targetValues, newOutput, out errorState);
     72      if (errorState != OnlineCalculatorError.None) quality = 0.0;
             // Restore the original node so the tree is unchanged after the call.
     73      SwitchNode(parent, constantNode, node);
     74      // impact = 0 if no change
     75      // impact < 0 if new solution is better
     76      // impact > 0 if new solution is worse
     77      return originalQuality - quality;
    5678    }
    5779  }
Note: See TracChangeset for help on using the changeset viewer.