Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/27/12 11:02:09 (12 years ago)
Author:
mkommend
Message:

#1763: Merged remaining changes from the TreeSimplifier branch in the trunk and refactored impact values calculators.

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification
Files:
3 edited
1 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification-3.4.csproj

    r8606 r8946  
    120120    <Compile Include="ModelCreators\NormalDistributedThresholdsModelCreator.cs" />
    121121    <Compile Include="MultiObjective\SymbolicClassificationMultiObjectiveValidationBestSolutionAnalyzer.cs" />
     122    <Compile Include="SymbolicClassificationSolutionImpactValuesCalculator.cs" />
    122123    <Compile Include="SymbolicNearestNeighbourClassificationModel.cs" />
    123124    <Compile Include="Plugin.cs" />
     
    258259  -->
    259260  <PropertyGroup>
    260    <PreBuildEvent Condition=" '$(OS)' == 'Windows_NT' ">set Path=%25Path%25;$(ProjectDir);$(SolutionDir)
     261    <PreBuildEvent Condition=" '$(OS)' == 'Windows_NT' ">set Path=%25Path%25;$(ProjectDir);$(SolutionDir)
    261262set ProjectDir=$(ProjectDir)
    262263set SolutionDir=$(SolutionDir)
     
    265266call PreBuildEvent.cmd
    266267</PreBuildEvent>
    267 <PreBuildEvent Condition=" '$(OS)' != 'Windows_NT' ">
     268    <PreBuildEvent Condition=" '$(OS)' != 'Windows_NT' ">
    268269export ProjectDir=$(ProjectDir)
    269270export SolutionDir=$(SolutionDir)
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/Properties

    • Property svn:ignore
      --- 
      +++ 
      
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationSolutionImpactValuesCalculator.cs

    r8942 r8946  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    24 using System.Linq;
    2523using HeuristicLab.Common;
    2624using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2725
    2826namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
    29   public class SymbolicDiscriminantFunctionClassificationSolutionImpactValuesCalculator : SymbolicDataAnalysisSolutionImpactValuesCalculator {
    30     public override IEnumerable<Tuple<ISymbolicExpressionTreeNode, double>> CalculateReplacementValues(ISymbolicExpressionTree tree,
    31                                                                                                        ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    32                                                                                                        IDataAnalysisProblemData problemData) {
    33       return from node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix()
    34              select new Tuple<ISymbolicExpressionTreeNode, double>(node, CalculateReplacementValue(node, tree, interpreter, problemData));
     27  public class SymbolicClassificationSolutionImpactValuesCalculator : SymbolicDataAnalysisSolutionImpactValuesCalculator {
     28    public override double CalculateReplacementValue(ISymbolicDataAnalysisModel model, ISymbolicExpressionTreeNode node, IDataAnalysisProblemData problemData, IEnumerable<int> rows) {
     29      var classificationModel = (ISymbolicClassificationModel)model;
     30      var classificationProblemData = (IClassificationProblemData)problemData;
     31
     32      return CalculateReplacementValue(node, classificationModel.SymbolicExpressionTree, classificationModel.Interpreter, classificationProblemData.Dataset, rows);
    3533    }
    3634
    37     public override IEnumerable<Tuple<ISymbolicExpressionTreeNode, double>> CalculateImpactValues(ISymbolicExpressionTree tree,
    38                                                                                                   ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    39                                                                                                   IDataAnalysisProblemData classificationProblemData,
    40                                                                                                   double lowerEstimationLimit, double upperEstimationLimit) {
    41       var problemData = (IClassificationProblemData)classificationProblemData;
    42       var dataset = problemData.Dataset;
    43       var rows = problemData.TrainingIndices;
    44       string targetVariable = problemData.TargetVariable;
    45       var targetClassValues = dataset.GetDoubleValues(targetVariable, rows);
    46       var originalOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows).LimitToRange(lowerEstimationLimit, upperEstimationLimit).ToArray();
     35    public override double CalculateImpactValue(ISymbolicDataAnalysisModel model, ISymbolicExpressionTreeNode node, IDataAnalysisProblemData problemData, IEnumerable<int> rows, double originalQuality = double.NaN) {
     36      var classificationModel = (ISymbolicClassificationModel)model;
     37      var classificationProblemData = (IClassificationProblemData)problemData;
     38
     39      var dataset = classificationProblemData.Dataset;
     40      var targetClassValues = dataset.GetDoubleValues(classificationProblemData.TargetVariable, rows);
     41
    4742      OnlineCalculatorError errorState;
    48       double originalGini = NormalizedGiniCalculator.Calculate(targetClassValues, originalOutput, out errorState);
    49       if (errorState != OnlineCalculatorError.None) originalGini = 0.0;
     43      if (double.IsNaN(originalQuality)) {
     44        var originalClassValues = classificationModel.GetEstimatedClassValues(dataset, rows);
     45        originalQuality = OnlineAccuracyCalculator.Calculate(targetClassValues, originalClassValues, out errorState);
     46        if (errorState != OnlineCalculatorError.None) originalQuality = 0.0;
     47      }
    5048
    51       return from node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix()
    52              select new Tuple<ISymbolicExpressionTreeNode, double>(node, CalculateImpact(tree, originalGini, node, interpreter, problemData, lowerEstimationLimit, upperEstimationLimit));
     49      var replacementValue = CalculateReplacementValue(classificationModel, node, classificationProblemData, rows);
     50      var constantNode = new ConstantTreeNode(new Constant()) { Value = replacementValue };
     51      var cloner = new Cloner();
     52      cloner.RegisterClonedObject(node, constantNode);
     53      var tempModel = cloner.Clone(classificationModel);
     54      tempModel.RecalculateModelParameters(classificationProblemData, rows);
     55
     56      var estimatedClassValues = tempModel.GetEstimatedClassValues(dataset, rows);
     57      double newQuality = OnlineAccuracyCalculator.Calculate(targetClassValues, estimatedClassValues, out errorState);
     58      if (errorState != OnlineCalculatorError.None) newQuality = 0.0;
     59
     60      return originalQuality - newQuality;
    5361    }
    5462
    55     private static double CalculateImpact(ISymbolicExpressionTree tree, double originalQuality, ISymbolicExpressionTreeNode node,
    56                                           ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IClassificationProblemData problemData,
    57                                           double lowerEstimationLimit, double upperEstimationLimit) {
    58       var dataset = problemData.Dataset;
    59       var rows = problemData.TrainingIndices.ToList();
    60       string targetVariable = problemData.TargetVariable;
    61       var targetValues = dataset.GetDoubleValues(targetVariable, rows).ToList();
    62 
    63       var parent = node.Parent;
    64       var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
    65       constantNode.Value = CalculateReplacementValue(node, tree, interpreter, problemData);
    66       SwitchNode(parent, node, constantNode);
    67       var newOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
    68                                  .LimitToRange(lowerEstimationLimit, upperEstimationLimit)
    69                                  .ToArray();
    70       OnlineCalculatorError errorState;
    71       double quality = NormalizedGiniCalculator.Calculate(targetValues, newOutput, out errorState);
    72       if (errorState != OnlineCalculatorError.None) quality = 0.0;
    73       SwitchNode(parent, constantNode, node);
    74       // impact = 0 if no change
    75       // impact < 0 if new solution is better
    76       // impact > 0 if new solution is worse
    77       return originalQuality - quality;
    78     }
    7963  }
    8064}
Note: See TracChangeset for help on using the changeset viewer.