Free cookie consent management tool by TermsFeed Policy Generator

Changeset 5664


Ignore:
Timestamp:
03/10/11 16:27:48 (14 years ago)
Author:
gkronber
Message:

#1418 ported ROC, confusion matrix and discriminant function classification views and fixed bug in threshold calculation.

Location:
branches/DataAnalysis Refactoring
Files:
4 edited
6 copied

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs

    r5658 r5664  
    7979      // change class values into class index
    8080      int targetVariableColumn = inputMatrix.GetLength(1) - 1;
    81       List<double> classValues = problemData.ClassValues.OrderBy(x=>x).ToList();
     81      List<double> classValues = problemData.ClassValues.OrderBy(x => x).ToList();
    8282      for (int row = 0; row < inputMatrix.GetLength(0); row++) {
    8383        inputMatrix[row, targetVariableColumn] = classValues.IndexOf(inputMatrix[row, targetVariableColumn]);
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionConfusionMatrixView.Designer.cs

    r5642 r5664  
    11namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
    2   partial class ConfusionMatrixView {
     2  partial class ClassificationSolutionConfusionMatrixView {
    33    /// <summary>
    44    /// Required designer variable.
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionConfusionMatrixView.cs

    r5642 r5664  
    2828
    2929namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
    30   [View("Confusion Matrix View")]
    31   [Content(typeof(SymbolicClassificationSolution))]
    32   public partial class ConfusionMatrixView : AsynchronousContentView {
     30  [View("Classification solution confusion matrix view")]
     31  [Content(typeof(IClassificationSolution))]
     32  public partial class ClassificationSolutionConfusionMatrixView : AsynchronousContentView {
    3333    private const string TrainingSamples = "Training";
    3434    private const string TestSamples = "Test";
    35     public ConfusionMatrixView() {
     35    public ClassificationSolutionConfusionMatrixView() {
    3636      InitializeComponent();
    3737      cmbSamples.Items.Add(TrainingSamples);
     
    4040    }
    4141
    42     public new SymbolicClassificationSolution Content {
    43       get { return (SymbolicClassificationSolution)base.Content; }
     42    public new IClassificationSolution Content {
     43      get { return (IClassificationSolution)base.Content; }
    4444      set { base.Content = value; }
    4545    }
     
    4747    protected override void RegisterContentEvents() {
    4848      base.RegisterContentEvents();
    49       Content.EstimatedValuesChanged += new EventHandler(Content_EstimatedValuesChanged);
     49      Content.ModelChanged += new EventHandler(Content_ModelChanged);
    5050      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    51       Content.ThresholdsChanged += new EventHandler(Content_ThresholdsChanged);
    5251    }
    5352
     
    5554    protected override void DeregisterContentEvents() {
    5655      base.DeregisterContentEvents();
    57       Content.EstimatedValuesChanged -= new EventHandler(Content_EstimatedValuesChanged);
     56      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
    5857      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    59       Content.ThresholdsChanged -= new EventHandler(Content_ThresholdsChanged);
    6058    }
    6159
    62     private void Content_EstimatedValuesChanged(object sender, EventArgs e) {
     60    private void Content_ModelChanged(object sender, EventArgs e) {
    6361      FillDataGridView();
    6462    }
    6563    private void Content_ProblemDataChanged(object sender, EventArgs e) {
    6664      UpdateDataGridView();
    67     }
    68     private void Content_ThresholdsChanged(object sender, EventArgs e) {
    69       FillDataGridView();
    7065    }
    7166
     
    8277          dataGridView.ColumnCount = 1;
    8378        } else {
    84           dataGridView.ColumnCount = Content.ProblemData.NumberOfClasses;
    85           dataGridView.RowCount = Content.ProblemData.NumberOfClasses;
     79          dataGridView.ColumnCount = Content.ProblemData.Classes;
     80          dataGridView.RowCount = Content.ProblemData.Classes;
    8681
    8782          int i = 0;
     
    10499        if (Content == null) return;
    105100
    106         double[,] confusionMatrix = new double[Content.ProblemData.NumberOfClasses, Content.ProblemData.NumberOfClasses];
     101        double[,] confusionMatrix = new double[Content.ProblemData.Classes, Content.ProblemData.Classes];
    107102        IEnumerable<int> rows;
    108103
     
    115110        Dictionary<double, int> classValueIndexMapping = new Dictionary<double, int>();
    116111        int index = 0;
    117         foreach (double classValue in Content.ProblemData.SortedClassValues) {
     112        foreach (double classValue in Content.ProblemData.ClassValues.OrderBy(x => x)) {
    118113          classValueIndexMapping.Add(classValue, index);
    119114          index++;
    120115        }
    121116
    122         double[] targetValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable.Value, rows).ToArray();
     117        double[] targetValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable, rows).ToArray();
    123118        double[] predictedValues = Content.GetEstimatedClassValues(rows).ToArray();
    124119
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationRocCurvesView.Designer.cs

    r5642 r5664  
    11namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
    2   partial class RocCurvesView {
     2  partial class DiscriminantFunctionClassificationRocCurvesView {
    33    /// <summary>
    44    /// Required designer variable.
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationRocCurvesView.cs

    r5642 r5664  
    3131using HeuristicLab.MainForm.WindowsForms;
    3232namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
    33   [View("ROC Curves View")]
    34   [Content(typeof(SymbolicClassificationSolution))]
    35   public partial class RocCurvesView : AsynchronousContentView {
     33  [View("Discriminant function classification solution ROC curves view")]
     34  [Content(typeof(IDiscriminantFunctionClassificationSolution))]
     35  public partial class DiscriminantFunctionClassificationRocCurvesView : AsynchronousContentView {
    3636    private const string xAxisTitle = "False Positive Rate";
    3737    private const string yAxisTitle = "True Positive Rate";
     
    4040    private Dictionary<string, List<ROCPoint>> cachedRocPoints;
    4141
    42     public RocCurvesView() {
     42    public DiscriminantFunctionClassificationRocCurvesView() {
    4343      InitializeComponent();
    4444
     
    6161    }
    6262
    63     public new SymbolicClassificationSolution Content {
    64       get { return (SymbolicClassificationSolution)base.Content; }
     63    public new IDiscriminantFunctionClassificationSolution Content {
     64      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
    6565      set { base.Content = value; }
    6666    }
     
    6868    protected override void RegisterContentEvents() {
    6969      base.RegisterContentEvents();
    70       Content.EstimatedValuesChanged += new EventHandler(Content_EstimatedValuesChanged);
     70      Content.ModelChanged += new EventHandler(Content_ModelChanged);
    7171      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    7272    }
    7373    protected override void DeregisterContentEvents() {
    7474      base.DeregisterContentEvents();
    75       Content.EstimatedValuesChanged -= new EventHandler(Content_EstimatedValuesChanged);
     75      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
    7676      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    7777    }
    7878
    79     private void Content_EstimatedValuesChanged(object sender, EventArgs e) {
     79    private void Content_ModelChanged(object sender, EventArgs e) {
    8080      UpdateChart();
    8181    }
     
    107107
    108108        double[] estimatedValues = Content.GetEstimatedValues(rows).ToArray();
    109         double[] targetClassValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable.Value, rows).ToArray();
     109        double[] targetClassValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable, rows).ToArray();
    110110        double minThreshold = estimatedValues.Min();
    111111        double maxThreshold = estimatedValues.Max();
     
    114114        maxThreshold += thresholdIncrement;
    115115
    116         List<double> classValues = Content.ProblemData.SortedClassValues.ToList();
     116        List<double> classValues = Content.ProblemData.ClassValues.OrderBy(x => x).ToList();
    117117
    118118        foreach (double classValue in classValues) {
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationSolutionView.Designer.cs

    r5642 r5664  
    11namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
    2   partial class SymbolicClassificationSolutionView {
     2  partial class DiscriminantFunctionClassificationSolutionView {
    33    /// <summary>
    44    /// Required designer variable.
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationSolutionView.cs

    r5642 r5664  
    3131
    3232namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
    33   [View("Symbolic Classification View")]
    34   [Content(typeof(SymbolicClassificationSolution), true)]
    35   public sealed partial class SymbolicClassificationSolutionView : AsynchronousContentView {
     33  [View("Discriminant function classification solution view")]
     34  [Content(typeof(IDiscriminantFunctionClassificationSolution), true)]
     35  public sealed partial class DiscriminantFunctionClassificationSolutionView : AsynchronousContentView {
    3636    private const double TrainingAxisValue = 0.0;
    3737    private const double TestAxisValue = 10.0;
     
    4040    private const string TestLabelText = "Test Samples";
    4141
    42     public new SymbolicClassificationSolution Content {
    43       get { return (SymbolicClassificationSolution)base.Content; }
     42    public new IDiscriminantFunctionClassificationSolution Content {
     43      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
    4444      set { base.Content = value; }
    4545    }
     
    4949    private bool updateInProgress;
    5050
    51     public SymbolicClassificationSolutionView()
     51    public DiscriminantFunctionClassificationSolutionView()
    5252      : base() {
    5353      InitializeComponent();
     
    8585    protected override void RegisterContentEvents() {
    8686      base.RegisterContentEvents();
    87       Content.EstimatedValuesChanged += new EventHandler(Content_EstimatedValuesChanged);
     87      Content.ModelChanged += new EventHandler(Content_ModelChanged);
    8888      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    8989      Content.ThresholdsChanged += new EventHandler(Content_ThresholdsChanged);
     
    9191    protected override void DeregisterContentEvents() {
    9292      base.DeregisterContentEvents();
    93       Content.EstimatedValuesChanged -= new EventHandler(Content_EstimatedValuesChanged);
     93      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
    9494      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    9595      Content.ThresholdsChanged -= new EventHandler(Content_ThresholdsChanged);
     
    9999      UpdateChart();
    100100    }
    101     private void Content_EstimatedValuesChanged(object sender, EventArgs e) {
     101    private void Content_ModelChanged(object sender, EventArgs e) {
    102102      UpdateChart();
    103103    }
     
    118118        if (Content != null) {
    119119          IEnumerator<string> classNameEnumerator = Content.ProblemData.ClassNames.GetEnumerator();
    120           IEnumerator<double> classValueEnumerator = Content.ProblemData.SortedClassValues.GetEnumerator();
     120          IEnumerator<double> classValueEnumerator = Content.ProblemData.ClassValues.OrderBy(x => x).GetEnumerator();
    121121          while (classNameEnumerator.MoveNext() && classValueEnumerator.MoveNext()) {
    122122            Series series = new Series(classNameEnumerator.Current);
     
    138138      foreach (int row in Content.ProblemData.TrainingIndizes) {
    139139        double estimatedValue = estimatedValues[row];
    140         double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable.Value, row];
     140        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable, row];
    141141        if (targetValue.IsAlmost((double)series.Tag)) {
    142142          double jitterValue = random.NextDouble() * 2.0 - 1.0;
     
    151151      foreach (int row in Content.ProblemData.TestIndizes) {
    152152        double estimatedValue = estimatedValues[row];
    153         double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable.Value, row];
     153        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable, row];
    154154        if (targetValue == (double)series.Tag) {
    155155          double jitterValue = random.NextDouble() * 2.0 - 1.0;
     
    235235    private void chart_AnnotationPositionChanging(object sender, AnnotationPositionChangingEventArgs e) {
    236236      int classIndex = (int)e.Annotation.Tag;
    237 
    238       double classValue = Content.ProblemData.SortedClassValues.ElementAt(classIndex);
    239       if (e.NewLocationY >= classValue)
    240         e.NewLocationY = classValue;
    241 
    242       classValue = Content.ProblemData.SortedClassValues.ElementAt(classIndex - 1);
    243       if (e.NewLocationY <= classValue)
    244         e.NewLocationY = classValue;
    245 
    246237      double[] thresholds = Content.Thresholds.ToArray();
     238      double max = thresholds[classIndex + 1];
     239      double min = thresholds[classIndex - 1];
     240
     241      if (e.NewLocationY >= max)
     242        e.NewLocationY = max;
     243
     244      if (e.NewLocationY <= min)
     245        e.NewLocationY = min;
     246
    247247      thresholds[classIndex] = e.NewLocationY;
    248248      Content.Thresholds = thresholds;
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/HeuristicLab.Problems.DataAnalysis.Views-3.4.csproj

    r5663 r5664  
    110110  </ItemGroup>
    111111  <ItemGroup>
     112    <Compile Include="Classification\ClassificationSolutionConfusionMatrixView.cs">
     113      <SubType>UserControl</SubType>
     114    </Compile>
     115    <Compile Include="Classification\ClassificationSolutionConfusionMatrixView.Designer.cs">
     116      <DependentUpon>ClassificationSolutionConfusionMatrixView.cs</DependentUpon>
     117    </Compile>
    112118    <Compile Include="Classification\ClassificationSolutionEstimatedClassValuesView.cs">
    113119      <SubType>UserControl</SubType>
     
    115121    <Compile Include="Classification\ClassificationSolutionEstimatedClassValuesView.Designer.cs">
    116122      <DependentUpon>ClassificationSolutionEstimatedClassValuesView.cs</DependentUpon>
     123    </Compile>
     124    <Compile Include="Classification\DiscriminantFunctionClassificationRocCurvesView.cs">
     125      <SubType>UserControl</SubType>
     126    </Compile>
     127    <Compile Include="Classification\DiscriminantFunctionClassificationRocCurvesView.Designer.cs">
     128      <DependentUpon>DiscriminantFunctionClassificationRocCurvesView.cs</DependentUpon>
     129    </Compile>
     130    <Compile Include="Classification\DiscriminantFunctionClassificationSolutionView.cs">
     131      <SubType>UserControl</SubType>
     132    </Compile>
     133    <Compile Include="Classification\DiscriminantFunctionClassificationSolutionView.Designer.cs">
     134      <DependentUpon>DiscriminantFunctionClassificationSolutionView.cs</DependentUpon>
    117135    </Compile>
    118136    <Compile Include="Regression\RegressionSolutionEstimatedValuesView.cs">
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationSolution.cs

    r5657 r5664  
    7777        return Model.Thresholds;
    7878      }
    79       protected set { Model.Thresholds = value; }
     79      set { Model.Thresholds = new List<double>(value); }
    8080    }
    8181
     
    117117      thresholds[thresholds.Length - 1] = double.PositiveInfinity;
    118118
     119      double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices;
     120
    119121      for (int i = 1; i < thresholds.Length - 1; i++) {
    120122        double lowerThreshold = thresholds[i - 1];
    121         double actualThreshold = minEstimatedValue;
    122         double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices;
    123 
     123        double actualThreshold = Math.Max(lowerThreshold, minEstimatedValue);
    124124        double lowestBestThreshold = double.NaN;
    125125        double highestBestThreshold = double.NaN;
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationSolution.cs

    r5649 r5664  
    3131    IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows);
    3232
    33     IEnumerable<double> Thresholds { get; }
     33    IEnumerable<double> Thresholds { get; set; }
    3434
    3535    event EventHandler ThresholdsChanged;
Note: See TracChangeset for help on using the changeset viewer.