Changeset 4469


Ignore:
Timestamp:
09/22/10 12:14:38 (9 years ago)
Author:
mkommend
Message:

Added logic to remove the test samples from the training samples (ticket #939).

Location:
branches/HeuristicLab.Classification
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3/ConfusionMatrixView.cs

    r4417 r4469  
    105105
    106106        double[,] confusionMatrix = new double[Content.ProblemData.NumberOfClasses, Content.ProblemData.NumberOfClasses];
    107         int start;
    108         int end;
     107        IEnumerable<int> rows;
    109108
    110109        if (cmbSamples.SelectedItem.ToString() == TrainingSamples) {
    111           start = Content.ProblemData.TrainingSamplesStart.Value;
    112           end = Content.ProblemData.TrainingSamplesEnd.Value;
     110          rows = Content.ProblemData.TrainingIndizes;
    113111        } else if (cmbSamples.SelectedItem.ToString() == TestSamples) {
    114           start = Content.ProblemData.TestSamplesStart.Value;
    115           end = Content.ProblemData.TestSamplesEnd.Value;
     112          rows = Content.ProblemData.TestIndizes;
    116113        } else throw new InvalidOperationException();
    117114
     
    123120        }
    124121
    125         double[] targetValues = Content.ProblemData.Dataset.GetVariableValues(Content.ProblemData.TargetVariable.Value, start, end);
    126         double[] predictedValues = Content.EstimatedClassValues.Skip(start).Take(end - start).ToArray();
     122        double[] targetValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable.Value, rows).ToArray();
     123        double[] predictedValues = Content.GetEstimatedClassValues(rows).ToArray();
    127124
    128125        for (int i = 0; i < targetValues.Length; i++) {
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3/RocCurvesView.cs

    r4417 r4469  
    9797
    9898        int slices = 100;
    99         int samplesStart = Content.ProblemData.TrainingSamplesStart.Value;
    100         int samplesEnd = Content.ProblemData.TrainingSamplesEnd.Value;
     99        IEnumerable<int> rows;
    101100
    102101        if (cmbSamples.SelectedItem.ToString() == TrainingSamples) {
    103           samplesStart = Content.ProblemData.TrainingSamplesStart.Value;
    104           samplesEnd = Content.ProblemData.TrainingSamplesEnd.Value;
     102          rows = Content.ProblemData.TrainingIndizes;
    105103        } else if (cmbSamples.SelectedItem.ToString() == TestSamples) {
    106           samplesStart = Content.ProblemData.TestSamplesStart.Value;
    107           samplesEnd = Content.ProblemData.TestSamplesEnd.Value;
     104          rows = Content.ProblemData.TestIndizes;
    108105        } else throw new InvalidOperationException();
    109106
    110         double[] estimatedValues = Content.EstimatedValues.Skip(samplesStart).Take(samplesEnd - samplesStart).ToArray();
    111         double[] targetClassValues = Content.ProblemData.Dataset.GetVariableValues(Content.ProblemData.TargetVariable.Value)
    112           .Skip(samplesStart).Take(samplesEnd - samplesStart).ToArray();
     107        double[] estimatedValues = Content.GetEstimatedValues(rows).ToArray();
     108        double[] targetClassValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable.Value, rows).ToArray();
    113109        double minThreshold = estimatedValues.Min();
    114110        double maxThreshold = estimatedValues.Max();
     
    122118          List<ROCPoint> rocPoints = new List<ROCPoint>();
    123119          int positives = targetClassValues.Where(c => c.IsAlmost(classValue)).Count();
    124           int negatives = samplesEnd - samplesStart - positives;
     120          int negatives = targetClassValues.Length - positives;
    125121
    126122          for (double lowerThreshold = minThreshold; lowerThreshold < maxThreshold; lowerThreshold += thresholdIncrement) {
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3/SymbolicClassificationSolutionView.cs

    r4417 r4469  
    134134
    135135    private void FillSeriesWithDataPoints(Series series) {
    136       int row = Content.ProblemData.TrainingSamplesStart.Value;
    137       foreach (double estimatedValue in Content.EstimatedTrainingValues) {
     136      List<double> estimatedValues = Content.EstimatedValues.ToList();
     137      foreach (int row in Content.ProblemData.TrainingIndizes) {
     138        double estimatedValue = estimatedValues[row];
    138139        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable.Value, row];
    139         if (targetValue == (double)series.Tag) {
     140        if (targetValue.IsAlmost((double)series.Tag)) {
    140141          double jitterValue = random.NextDouble() * 2.0 - 1.0;
    141142          DataPoint point = new DataPoint();
     
    145146          series.Points.Add(point);
    146147        }
    147         row++;
    148       }
    149 
    150       row = Content.ProblemData.TestSamplesStart.Value;
    151       foreach (double estimatedValue in Content.EstimatedTestValues) {
     148      }
     149
     150      foreach (int row in Content.ProblemData.TestIndizes) {
     151        double estimatedValue = estimatedValues[row];
    152152        double targetValue = Content.ProblemData.Dataset[Content.ProblemData.TargetVariable.Value, row];
    153153        if (targetValue == (double)series.Tag) {
     
    159159          series.Points.Add(point);
    160160        }
    161         row++;
    162       }
     161      }
     162
    163163      UpdateCursorInterval();
    164164    }
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/Analyzer/ValidationBestSymbolicClassificationSolutionAnalyzer.cs

    r4417 r4469  
    217217      int count = (int)((validationEnd - validationStart) * RelativeNumberOfEvaluatedSamples.Value);
    218218      if (count == 0) count = 1;
    219       IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count);
     219      IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count)
     220         .Where(row => row < ClassificationProblemData.TestSamplesStart.Value || ClassificationProblemData.TestSamplesEnd.Value <= row);
    220221
    221222      double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity;
     
    244245      if (newBest) {
    245246        double alpha, beta;
    246         int trainingStart = ClassificationProblemData.TrainingSamplesStart.Value;
    247         int trainingEnd = ClassificationProblemData.TrainingSamplesEnd.Value;
    248         IEnumerable<int> trainingRows = Enumerable.Range(trainingStart, trainingEnd - trainingStart);
    249247        SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(SymbolicExpressionTreeInterpreter, bestTree,
    250248          lowerEstimationLimit, upperEstimationLimit,
    251249          ClassificationProblemData.Dataset, targetVariable,
    252           trainingRows, out beta, out alpha);
     250          ClassificationProblemData.TrainingIndizes, out beta, out alpha);
    253251
    254252        // scale tree for solution
     
    275273
    276274      IEnumerable<double> trainingValues = ClassificationProblemData.Dataset.GetEnumeratedVariableValues(
    277         ClassificationProblemData.TargetVariable.Value,
    278         ClassificationProblemData.TrainingSamplesStart.Value,
    279         ClassificationProblemData.TrainingSamplesEnd.Value);
     275        ClassificationProblemData.TargetVariable.Value, ClassificationProblemData.TrainingIndizes);
    280276      IEnumerable<double> testValues = ClassificationProblemData.Dataset.GetEnumeratedVariableValues(
    281         ClassificationProblemData.TargetVariable.Value,
    282         ClassificationProblemData.TestSamplesStart.Value,
    283         ClassificationProblemData.TestSamplesEnd.Value);
     277        ClassificationProblemData.TargetVariable.Value, ClassificationProblemData.TestIndizes);
    284278
    285279      OnlineAccuracyEvaluator accuracyEvaluator = new OnlineAccuracyEvaluator();
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/SymbolicClassificationSolution.cs

    r4417 r4469  
    6262
    6363      List<KeyValuePair<double, double>> estimatedTargetValues =
    64          (from row in Enumerable.Range(ProblemData.TrainingSamplesStart.Value, ProblemData.TrainingSamplesEnd.Value - ProblemData.TrainingSamplesStart.Value)
     64         (from row in ProblemData.TrainingIndizes
    6565          select new KeyValuePair<double, double>(
    6666            estimatedValues[row],
     
    131131
    132132    public IEnumerable<double> EstimatedClassValues {
    133       get {
    134         double[] classValues = ProblemData.SortedClassValues.ToArray();
    135         foreach (double value in EstimatedValues) {
    136           int classIndex = 0;
    137           while (value > actualThresholds[classIndex + 1])
    138             classIndex++;
    139           yield return classValues[classIndex];
    140         }
    141       }
     133      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    142134    }
    143135
    144136    public IEnumerable<double> EstimatedTrainingClassValues {
    145       get {
    146         int start = ProblemData.TrainingSamplesStart.Value;
    147         int n = ProblemData.TrainingSamplesEnd.Value - start;
    148         return EstimatedClassValues.Skip(start).Take(n).ToList();
    149       }
     137      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
    150138    }
    151139
    152140    public IEnumerable<double> EstimatedTestClassValues {
    153       get {
    154         int start = ProblemData.TestSamplesStart.Value;
    155         int n = ProblemData.TestSamplesEnd.Value - start;
    156         return EstimatedClassValues.Skip(start).Take(n).ToList();
     141      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
     142    }
     143
     144    public IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     145      double[] classValues = ProblemData.SortedClassValues.ToArray();
     146      foreach (int row in rows) {
     147        double value = estimatedValues[row];
     148        int classIndex = 0;
     149        while (value > actualThresholds[classIndex + 1])
     150          classIndex++;
     151        yield return classValues[classIndex];
    157152      }
    158153    }
Note: See TracChangeset for help on using the changeset viewer.