Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/22/10 12:14:38 (14 years ago)
Author:
mkommend
Message:

Added logic to remove the test samples from the training samples (ticket #939).

Location:
branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3/ConfusionMatrixView.cs

    r4417 r4469  
    105105
    106106        double[,] confusionMatrix = new double[Content.ProblemData.NumberOfClasses, Content.ProblemData.NumberOfClasses];
    107         int start;
    108         int end;
     107        IEnumerable<int> rows;
    109108
    110109        if (cmbSamples.SelectedItem.ToString() == TrainingSamples) {
    111           start = Content.ProblemData.TrainingSamplesStart.Value;
    112           end = Content.ProblemData.TrainingSamplesEnd.Value;
     110          rows = Content.ProblemData.TrainingIndizes;
    113111        } else if (cmbSamples.SelectedItem.ToString() == TestSamples) {
    114           start = Content.ProblemData.TestSamplesStart.Value;
    115           end = Content.ProblemData.TestSamplesEnd.Value;
     112          rows = Content.ProblemData.TestIndizes;
    116113        } else throw new InvalidOperationException();
    117114
     
    123120        }
    124121
    125         double[] targetValues = Content.ProblemData.Dataset.GetVariableValues(Content.ProblemData.TargetVariable.Value, start, end);
    126         double[] predictedValues = Content.EstimatedClassValues.Skip(start).Take(end - start).ToArray();
     122        double[] targetValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable.Value, rows).ToArray();
     123        double[] predictedValues = Content.GetEstimatedClassValues(rows).ToArray();
    127124
    128125        for (int i = 0; i < targetValues.Length; i++) {
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3/RocCurvesView.cs

    r4417 r4469  
    9797
    9898        int slices = 100;
    99         int samplesStart = Content.ProblemData.TrainingSamplesStart.Value;
    100         int samplesEnd = Content.ProblemData.TrainingSamplesEnd.Value;
     99        IEnumerable<int> rows;
    101100
    102101        if (cmbSamples.SelectedItem.ToString() == TrainingSamples) {
    103           samplesStart = Content.ProblemData.TrainingSamplesStart.Value;
    104           samplesEnd = Content.ProblemData.TrainingSamplesEnd.Value;
     102          rows = Content.ProblemData.TrainingIndizes;
    105103        } else if (cmbSamples.SelectedItem.ToString() == TestSamples) {
    106           samplesStart = Content.ProblemData.TestSamplesStart.Value;
    107           samplesEnd = Content.ProblemData.TestSamplesEnd.Value;
     104          rows = Content.ProblemData.TestIndizes;
    108105        } else throw new InvalidOperationException();
    109106
    110         double[] estimatedValues = Content.EstimatedValues.Skip(samplesStart).Take(samplesEnd - samplesStart).ToArray();
    111         double[] targetClassValues = Content.ProblemData.Dataset.GetVariableValues(Content.ProblemData.TargetVariable.Value)
    112           .Skip(samplesStart).Take(samplesEnd - samplesStart).ToArray();
     107        double[] estimatedValues = Content.GetEstimatedValues(rows).ToArray();
     108        double[] targetClassValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable.Value, rows).ToArray();
    113109        double minThreshold = estimatedValues.Min();
    114110        double maxThreshold = estimatedValues.Max();
     
    122118          List<ROCPoint> rocPoints = new List<ROCPoint>();
    123119          int positives = targetClassValues.Where(c => c.IsAlmost(classValue)).Count();
    124           int negatives = samplesEnd - samplesStart - positives;
     120          int negatives = targetClassValues.Length - positives;
    125121
    126122          for (double lowerThreshold = minThreshold; lowerThreshold < maxThreshold; lowerThreshold += thresholdIncrement) {
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3/SymbolicClassificationSolutionView.cs

    r4417 r4469  
    134134
    135135    private void FillSeriesWithDataPoints(Series series) {
    136       int row = Content.ProblemData.TrainingSamplesStart.Value;
    137       foreach (double estimatedValue in Content.EstimatedTrainingValues) {
     136      List<double> estimatedValues = Content.EstimatedValues.ToList();
     137      foreach (int row in Content.ProblemData.TrainingIndizes) {
     138        double estimatedValue = estimatedValues[row];
    138139        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable.Value, row];
    139         if (targetValue == (double)series.Tag) {
     140        if (targetValue.IsAlmost((double)series.Tag)) {
    140141          double jitterValue = random.NextDouble() * 2.0 - 1.0;
    141142          DataPoint point = new DataPoint();
     
    145146          series.Points.Add(point);
    146147        }
    147         row++;
    148       }
    149 
    150       row = Content.ProblemData.TestSamplesStart.Value;
    151       foreach (double estimatedValue in Content.EstimatedTestValues) {
     148      }
     149
     150      foreach (int row in Content.ProblemData.TestIndizes) {
     151        double estimatedValue = estimatedValues[row];
    152152        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable.Value, row];
    153153        if (targetValue == (double)series.Tag) {
     
    159159          series.Points.Add(point);
    160160        }
    161         row++;
    162       }
     161      }
     162
    163163      UpdateCursorInterval();
    164164    }
Note: See TracChangeset for help on using the changeset viewer.