Changeset 8636


Ignore:
Timestamp:
09/12/12 16:20:18 (7 years ago)
Author:
mkommend
Message:

#1924:

  • Changed the accuracy threshold calculator to eliminate the necessity the the class values are ordered.
  • Adapted the symbolic classification simplifier to work with all ISymbolicClassificationModels.
  • Corrected ROCCurvesView to also work if the class values are not sorted.
Location:
trunk/sources
Files:
2 added
4 edited
4 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views-3.4.csproj

    r8600 r8636  
    110110  </ItemGroup>
    111111  <ItemGroup>
     112    <Compile Include="InteractiveSymbolicClassificationSolutionSimplifierView.cs">
     113      <SubType>UserControl</SubType>
     114    </Compile>
     115    <Compile Include="InteractiveSymbolicClassificationSolutionSimplifierView.Designer.cs">
     116      <DependentUpon>InteractiveSymbolicClassificationSolutionSimplifierView.cs</DependentUpon>
     117    </Compile>
     118    <Compile Include="SymbolicClassificationSolutionView.cs">
     119      <SubType>UserControl</SubType>
     120    </Compile>
     121    <Compile Include="SymbolicClassificationSolutionView.Designer.cs">
     122      <DependentUpon>SymbolicClassificationSolutionView.cs</DependentUpon>
     123    </Compile>
     124    <Compile Include="InteractiveSymbolicClassificationSolutionSimplifierViewBase.cs">
     125      <SubType>UserControl</SubType>
     126    </Compile>
     127    <Compile Include="InteractiveSymbolicClassificationSolutionSimplifierViewBase.Designer.cs">
     128      <DependentUpon>InteractiveSymbolicClassificationSolutionSimplifierViewBase.cs</DependentUpon>
     129    </Compile>
    112130    <Compile Include="Plugin.cs" />
    113131    <Compile Include="SymbolicDiscriminantFunctionClassificationSolutionView.cs">
     
    246264  -->
    247265  <PropertyGroup>
    248    <PreBuildEvent Condition=" '$(OS)' == 'Windows_NT' ">set Path=%25Path%25;$(ProjectDir);$(SolutionDir)
     266    <PreBuildEvent Condition=" '$(OS)' == 'Windows_NT' ">set Path=%25Path%25;$(ProjectDir);$(SolutionDir)
    249267set ProjectDir=$(ProjectDir)
    250268set SolutionDir=$(SolutionDir)
     
    253271call PreBuildEvent.cmd
    254272</PreBuildEvent>
    255 <PreBuildEvent Condition=" '$(OS)' != 'Windows_NT' ">
     273    <PreBuildEvent Condition=" '$(OS)' != 'Windows_NT' ">
    256274export ProjectDir=$(ProjectDir)
    257275export SolutionDir=$(SolutionDir)
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views/3.4/InteractiveSymbolicClassificationSolutionSimplifierView.Designer.cs

    r8633 r8636  
    2121
    2222namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views {
    23   partial class InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView {
     23  partial class InteractiveSymbolicClassificationSolutionSimplifierView {
    2424    /// <summary>
    2525    /// Required designer variable.
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views/3.4/InteractiveSymbolicClassificationSolutionSimplifierView.cs

    r8633 r8636  
    2020#endregion
    2121
    22 using System;
    23 using System.Collections.Generic;
    24 using System.Linq;
    25 using HeuristicLab.Common;
    2622using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    27 using HeuristicLab.Problems.DataAnalysis.Symbolic.Views;
    2823
    2924namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views {
    30   public partial class InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView : InteractiveSymbolicDataAnalysisSolutionSimplifierView {
    31     private readonly ConstantTreeNode constantNode;
    32     private readonly SymbolicExpressionTree tempTree;
     25  public partial class InteractiveSymbolicClassificationSolutionSimplifierView : InteractiveSymbolicClassificationSolutionSimplifierViewBase {
    3326
    34     public new SymbolicDiscriminantFunctionClassificationSolution Content {
    35       get { return (SymbolicDiscriminantFunctionClassificationSolution)base.Content; }
     27    public new SymbolicClassificationSolution Content {
     28      get { return (SymbolicClassificationSolution)base.Content; }
    3629      set { base.Content = value; }
    3730    }
    3831
    39     public InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView()
    40       : base() {
    41       InitializeComponent();
    42       this.Caption = "Interactive Classification Solution Simplifier";
    43 
    44       constantNode = ((ConstantTreeNode)new Constant().CreateTreeNode());
    45       ISymbolicExpressionTreeNode root = new ProgramRootSymbol().CreateTreeNode();
    46       ISymbolicExpressionTreeNode start = new StartSymbol().CreateTreeNode();
    47       root.AddSubtree(start);
    48       tempTree = new SymbolicExpressionTree(root);
    49     }
     32    public InteractiveSymbolicClassificationSolutionSimplifierView() : base() { }
    5033
    5134    protected override void UpdateModel(ISymbolicExpressionTree tree) {
    52       var model = new SymbolicDiscriminantFunctionClassificationModel(tree, Content.Model.Interpreter, Content.Model.ThresholdCalculator, Content.Model.LowerEstimationLimit, Content.Model.UpperEstimationLimit);
    53       model.RecalculateModelParameters(Content.ProblemData, Content.ProblemData.TrainingIndices);
    54       Content.Model = model;
    55     }
    56 
    57     protected override Dictionary<ISymbolicExpressionTreeNode, double> CalculateReplacementValues(ISymbolicExpressionTree tree) {
    58       Dictionary<ISymbolicExpressionTreeNode, double> replacementValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
    59       foreach (ISymbolicExpressionTreeNode node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix()) {
    60         replacementValues[node] = CalculateReplacementValue(node, tree);
    61       }
    62       return replacementValues;
    63     }
    64 
    65     protected override Dictionary<ISymbolicExpressionTreeNode, double> CalculateImpactValues(ISymbolicExpressionTree tree) {
    66       var interpreter = Content.Model.Interpreter;
    67       var dataset = Content.ProblemData.Dataset;
    68       var rows = Content.ProblemData.TrainingIndices;
    69       string targetVariable = Content.ProblemData.TargetVariable;
    70       Dictionary<ISymbolicExpressionTreeNode, double> impactValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
    71       List<ISymbolicExpressionTreeNode> nodes = tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix().ToList();
    72 
    73       var targetClassValues = dataset.GetDoubleValues(targetVariable, rows);
    74       var originalOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
    75         .LimitToRange(Content.Model.LowerEstimationLimit, Content.Model.UpperEstimationLimit)
    76         .ToArray();
    77       OnlineCalculatorError errorState;
    78       double originalGini = NormalizedGiniCalculator.Calculate(targetClassValues, originalOutput, out errorState);
    79       if (errorState != OnlineCalculatorError.None) originalGini = 0.0;
    80 
    81       foreach (ISymbolicExpressionTreeNode node in nodes) {
    82         var parent = node.Parent;
    83         constantNode.Value = CalculateReplacementValue(node, tree);
    84         ISymbolicExpressionTreeNode replacementNode = constantNode;
    85         SwitchNode(parent, node, replacementNode);
    86         var newOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
    87           .LimitToRange(Content.Model.LowerEstimationLimit, Content.Model.UpperEstimationLimit)
    88           .ToArray();
    89         double newGini = NormalizedGiniCalculator.Calculate(targetClassValues, newOutput, out errorState);
    90         if (errorState != OnlineCalculatorError.None) newGini = 0.0;
    91 
    92         // impact = 0 if no change
    93         // impact < 0 if new solution is better
    94         // impact > 0 if new solution is worse
    95         impactValues[node] = originalGini - newGini;
    96         SwitchNode(parent, replacementNode, node);
    97       }
    98       return impactValues;
    99     }
    100 
    101     private double CalculateReplacementValue(ISymbolicExpressionTreeNode node, ISymbolicExpressionTree sourceTree) {
    102       // remove old ADFs
    103       while (tempTree.Root.SubtreeCount > 1) tempTree.Root.RemoveSubtree(1);
    104       // clone ADFs of source tree
    105       for (int i = 1; i < sourceTree.Root.SubtreeCount; i++) {
    106         tempTree.Root.AddSubtree((ISymbolicExpressionTreeNode)sourceTree.Root.GetSubtree(i).Clone());
    107       }
    108       var start = tempTree.Root.GetSubtree(0);
    109       while (start.SubtreeCount > 0) start.RemoveSubtree(0);
    110       start.AddSubtree((ISymbolicExpressionTreeNode)node.Clone());
    111       var interpreter = Content.Model.Interpreter;
    112       var rows = Content.ProblemData.TrainingIndices;
    113       return interpreter.GetSymbolicExpressionTreeValues(tempTree, Content.ProblemData.Dataset, rows).Median();
    114     }
    115 
    116 
    117     private void SwitchNode(ISymbolicExpressionTreeNode root, ISymbolicExpressionTreeNode oldBranch, ISymbolicExpressionTreeNode newBranch) {
    118       for (int i = 0; i < root.SubtreeCount; i++) {
    119         if (root.GetSubtree(i) == oldBranch) {
    120           root.RemoveSubtree(i);
    121           root.InsertSubtree(i, newBranch);
    122           return;
    123         }
    124       }
    125     }
    126 
    127     protected override void btnOptimizeConstants_Click(object sender, EventArgs e) {
    128 
     35      Content.Model = CreateModel(tree);
    12936    }
    13037  }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views/3.4/InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView.cs

    r8594 r8636  
    2020#endregion
    2121
    22 using System;
    23 using System.Collections.Generic;
    24 using System.Linq;
    25 using HeuristicLab.Common;
    2622using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    27 using HeuristicLab.Problems.DataAnalysis.Symbolic.Views;
    2823
    2924namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views {
    30   public partial class InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView : InteractiveSymbolicDataAnalysisSolutionSimplifierView {
    31     private readonly ConstantTreeNode constantNode;
    32     private readonly SymbolicExpressionTree tempTree;
     25  public partial class InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView : InteractiveSymbolicClassificationSolutionSimplifierViewBase {
    3326
    3427    public new SymbolicDiscriminantFunctionClassificationSolution Content {
     
    3730    }
    3831
    39     public InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView()
    40       : base() {
    41       InitializeComponent();
    42       this.Caption = "Interactive Classification Solution Simplifier";
    43 
    44       constantNode = ((ConstantTreeNode)new Constant().CreateTreeNode());
    45       ISymbolicExpressionTreeNode root = new ProgramRootSymbol().CreateTreeNode();
    46       ISymbolicExpressionTreeNode start = new StartSymbol().CreateTreeNode();
    47       root.AddSubtree(start);
    48       tempTree = new SymbolicExpressionTree(root);
    49     }
     32    public InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView() : base() { }
    5033
    5134    protected override void UpdateModel(ISymbolicExpressionTree tree) {
    52       var model = new SymbolicDiscriminantFunctionClassificationModel(tree, Content.Model.Interpreter, Content.Model.ThresholdCalculator, Content.Model.LowerEstimationLimit, Content.Model.UpperEstimationLimit);
     35      var model = CreateModel(tree);
    5336      model.RecalculateModelParameters(Content.ProblemData, Content.ProblemData.TrainingIndices);
    54       Content.Model = model;
    55     }
    56 
    57     protected override Dictionary<ISymbolicExpressionTreeNode, double> CalculateReplacementValues(ISymbolicExpressionTree tree) {
    58       Dictionary<ISymbolicExpressionTreeNode, double> replacementValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
    59       foreach (ISymbolicExpressionTreeNode node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix()) {
    60         replacementValues[node] = CalculateReplacementValue(node, tree);
    61       }
    62       return replacementValues;
    63     }
    64 
    65     protected override Dictionary<ISymbolicExpressionTreeNode, double> CalculateImpactValues(ISymbolicExpressionTree tree) {
    66       var interpreter = Content.Model.Interpreter;
    67       var dataset = Content.ProblemData.Dataset;
    68       var rows = Content.ProblemData.TrainingIndices;
    69       string targetVariable = Content.ProblemData.TargetVariable;
    70       Dictionary<ISymbolicExpressionTreeNode, double> impactValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
    71       List<ISymbolicExpressionTreeNode> nodes = tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPostfix().ToList();
    72 
    73       var targetClassValues = dataset.GetDoubleValues(targetVariable, rows);
    74       var originalOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
    75         .LimitToRange(Content.Model.LowerEstimationLimit, Content.Model.UpperEstimationLimit)
    76         .ToArray();
    77       OnlineCalculatorError errorState;
    78       double originalGini = NormalizedGiniCalculator.Calculate(targetClassValues, originalOutput, out errorState);
    79       if (errorState != OnlineCalculatorError.None) originalGini = 0.0;
    80 
    81       foreach (ISymbolicExpressionTreeNode node in nodes) {
    82         var parent = node.Parent;
    83         constantNode.Value = CalculateReplacementValue(node, tree);
    84         ISymbolicExpressionTreeNode replacementNode = constantNode;
    85         SwitchNode(parent, node, replacementNode);
    86         var newOutput = interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows)
    87           .LimitToRange(Content.Model.LowerEstimationLimit, Content.Model.UpperEstimationLimit)
    88           .ToArray();
    89         double newGini = NormalizedGiniCalculator.Calculate(targetClassValues, newOutput, out errorState);
    90         if (errorState != OnlineCalculatorError.None) newGini = 0.0;
    91 
    92         // impact = 0 if no change
    93         // impact < 0 if new solution is better
    94         // impact > 0 if new solution is worse
    95         impactValues[node] = originalGini - newGini;
    96         SwitchNode(parent, replacementNode, node);
    97       }
    98       return impactValues;
    99     }
    100 
    101     private double CalculateReplacementValue(ISymbolicExpressionTreeNode node, ISymbolicExpressionTree sourceTree) {
    102       // remove old ADFs
    103       while (tempTree.Root.SubtreeCount > 1) tempTree.Root.RemoveSubtree(1);
    104       // clone ADFs of source tree
    105       for (int i = 1; i < sourceTree.Root.SubtreeCount; i++) {
    106         tempTree.Root.AddSubtree((ISymbolicExpressionTreeNode)sourceTree.Root.GetSubtree(i).Clone());
    107       }
    108       var start = tempTree.Root.GetSubtree(0);
    109       while (start.SubtreeCount > 0) start.RemoveSubtree(0);
    110       start.AddSubtree((ISymbolicExpressionTreeNode)node.Clone());
    111       var interpreter = Content.Model.Interpreter;
    112       var rows = Content.ProblemData.TrainingIndices;
    113       return interpreter.GetSymbolicExpressionTreeValues(tempTree, Content.ProblemData.Dataset, rows).Median();
    114     }
    115 
    116 
    117     private void SwitchNode(ISymbolicExpressionTreeNode root, ISymbolicExpressionTreeNode oldBranch, ISymbolicExpressionTreeNode newBranch) {
    118       for (int i = 0; i < root.SubtreeCount; i++) {
    119         if (root.GetSubtree(i) == oldBranch) {
    120           root.RemoveSubtree(i);
    121           root.InsertSubtree(i, newBranch);
    122           return;
    123         }
    124       }
    125     }
    126 
    127     protected override void btnOptimizeConstants_Click(object sender, EventArgs e) {
    128 
     37      Content.Model = (ISymbolicDiscriminantFunctionClassificationModel)model;
    12938    }
    13039  }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views/3.4/SymbolicClassificationSolutionView.Designer.cs

    r8633 r8636  
    2121
    2222namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views {
    23   partial class SymbolicDiscriminantFunctionClassificationSolutionView {
     23  partial class SymbolicClassificationSolutionView {
    2424    /// <summary>
    2525    /// Required designer variable.
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views/3.4/SymbolicClassificationSolutionView.cs

    r8633 r8636  
    2626
    2727namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views {
    28   [Content(typeof(SymbolicDiscriminantFunctionClassificationSolution), false)]
     28  [Content(typeof(SymbolicClassificationSolution), false)]
    2929  [View("SymbolicDiscriminantFunctionClassificationSolution View")]
    30   public partial class SymbolicDiscriminantFunctionClassificationSolutionView : DiscriminantFunctionClassificationSolutionView {
    31     public SymbolicDiscriminantFunctionClassificationSolutionView() {
     30  public partial class SymbolicClassificationSolutionView : ClassificationSolutionView {
     31    public SymbolicClassificationSolutionView() {
    3232      InitializeComponent();
    3333    }
    3434
    35     protected new SymbolicDiscriminantFunctionClassificationSolution Content {
    36       get { return (SymbolicDiscriminantFunctionClassificationSolution)base.Content; }
     35    protected new SymbolicClassificationSolution Content {
     36      get { return (SymbolicClassificationSolution)base.Content; }
    3737      set { base.Content = value; }
    3838    }
    3939
    4040    private void btn_SimplifyModel_Click(object sender, EventArgs e) {
    41       var view = new InteractiveSymbolicDiscriminantFunctionClassificationSolutionSimplifierView();
    42       view.Content = (SymbolicDiscriminantFunctionClassificationSolution)this.Content.Clone();
     41      var view = new InteractiveSymbolicClassificationSolutionSimplifierView();
     42      view.Content = (SymbolicClassificationSolution)this.Content.Clone();
    4343      view.Show();
    4444    }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationRocCurvesView.cs

    r8139 r8636  
    114114        maxThreshold += thresholdIncrement;
    115115
    116         List<double> classValues = Content.ProblemData.ClassValues.OrderBy(x => x).ToList();
     116        List<double> classValues = Content.Model.ClassValues.ToList();
    117117
    118118        foreach (double classValue in classValues) {
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/AccuracyMaximizationThresholdCalculator.cs

    r8573 r8636  
    5353
    5454    public static void CalculateThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
    55       int slices = 100;
    56       double minThresholdInc = 10e-5; // necessary to prevent infinite loop when maxEstimated - minEstimated is effectively zero (constant model)
     55      const int slices = 100;
     56      const double minThresholdInc = 10e-5; // necessary to prevent infinite loop when maxEstimated - minEstimated is effectively zero (constant model)
    5757      List<double> estimatedValuesList = estimatedValues.ToList();
    5858      double maxEstimatedValue = estimatedValuesList.Max();
     
    6161      var estimatedAndTargetValuePairs =
    6262        estimatedValuesList.Zip(targetClassValues, (x, y) => new { EstimatedValue = x, TargetClassValue = y })
    63         .OrderBy(x => x.EstimatedValue)
    64         .ToList();
     63        .OrderBy(x => x.EstimatedValue).ToList();
    6564
    66       classValues = problemData.ClassValues.OrderBy(x => x).ToArray();
     65      classValues = estimatedAndTargetValuePairs.GroupBy(x => x.TargetClassValue)
     66        .Select(x => new { Median = x.Select(y => y.EstimatedValue).Median(), Class = x.Key })
     67        .OrderBy(x => x.Median).Select(x => x.Class).ToArray();
     68
    6769      int nClasses = classValues.Length;
    6870      thresholds = new double[nClasses];
    6971      thresholds[0] = double.NegativeInfinity;
    70       // thresholds[thresholds.Length - 1] = double.PositiveInfinity;
    7172
    7273      // incrementally calculate accuracy of all possible thresholds
Note: See TracChangeset for help on using the changeset viewer.