Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/22/10 12:14:38 (14 years ago)
Author:
mkommend
Message:

Added logic to remove the test samples from the training samples (ticket #939).

Location:
branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/Analyzer/ValidationBestSymbolicClassificationSolutionAnalyzer.cs

    r4417 r4469  
    217 217      int count = (int)((validationEnd - validationStart) * RelativeNumberOfEvaluatedSamples.Value);
    218 218      if (count == 0) count = 1;
    219       IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count);
     219      IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count)
     220         .Where(row => row < ClassificationProblemData.TestSamplesStart.Value || ClassificationProblemData.TestSamplesEnd.Value <= row);
    220 221
    221 222      double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity;
     
    244 245      if (newBest) {
    245 246        double alpha, beta;
    246         int trainingStart = ClassificationProblemData.TrainingSamplesStart.Value;
    247         int trainingEnd = ClassificationProblemData.TrainingSamplesEnd.Value;
    248         IEnumerable<int> trainingRows = Enumerable.Range(trainingStart, trainingEnd - trainingStart);
    249 247        SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(SymbolicExpressionTreeInterpreter, bestTree,
    250 248          lowerEstimationLimit, upperEstimationLimit,
    251 249          ClassificationProblemData.Dataset, targetVariable,
    252           trainingRows, out beta, out alpha);
     250          ClassificationProblemData.TrainingIndizes, out beta, out alpha);
    253 251
    254 252        // scale tree for solution
     
    275 273
    276 274      IEnumerable<double> trainingValues = ClassificationProblemData.Dataset.GetEnumeratedVariableValues(
    277         ClassificationProblemData.TargetVariable.Value,
    278         ClassificationProblemData.TrainingSamplesStart.Value,
    279         ClassificationProblemData.TrainingSamplesEnd.Value);
     275        ClassificationProblemData.TargetVariable.Value, ClassificationProblemData.TrainingIndizes);
    280 276      IEnumerable<double> testValues = ClassificationProblemData.Dataset.GetEnumeratedVariableValues(
    281         ClassificationProblemData.TargetVariable.Value,
    282         ClassificationProblemData.TestSamplesStart.Value,
    283         ClassificationProblemData.TestSamplesEnd.Value);
     277        ClassificationProblemData.TargetVariable.Value, ClassificationProblemData.TestIndizes);
    284 278
    285 279      OnlineAccuracyEvaluator accuracyEvaluator = new OnlineAccuracyEvaluator();
  • branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/SymbolicClassificationSolution.cs

    r4417 r4469  
    62 62
    63 63      List<KeyValuePair<double, double>> estimatedTargetValues =
    64          (from row in Enumerable.Range(ProblemData.TrainingSamplesStart.Value, ProblemData.TrainingSamplesEnd.Value - ProblemData.TrainingSamplesStart.Value)
     64         (from row in ProblemData.TrainingIndizes
    65 65          select new KeyValuePair<double, double>(
    66 66            estimatedValues[row],
     
    131 131
    132 132    public IEnumerable<double> EstimatedClassValues {
    133       get {
    134         double[] classValues = ProblemData.SortedClassValues.ToArray();
    135         foreach (double value in EstimatedValues) {
    136           int classIndex = 0;
    137           while (value > actualThresholds[classIndex + 1])
    138             classIndex++;
    139           yield return classValues[classIndex];
    140         }
    141       }
     133      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    142 134    }
    143 135
    144 136    public IEnumerable<double> EstimatedTrainingClassValues {
    145       get {
    146         int start = ProblemData.TrainingSamplesStart.Value;
    147         int n = ProblemData.TrainingSamplesEnd.Value - start;
    148         return EstimatedClassValues.Skip(start).Take(n).ToList();
    149       }
     137      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
    150 138    }
    151 139
    152 140    public IEnumerable<double> EstimatedTestClassValues {
    153       get {
    154         int start = ProblemData.TestSamplesStart.Value;
    155         int n = ProblemData.TestSamplesEnd.Value - start;
    156         return EstimatedClassValues.Skip(start).Take(n).ToList();
     141      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
     142    }
     143
     144    public IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     145      double[] classValues = ProblemData.SortedClassValues.ToArray();
     146      foreach (int row in rows) {
     147        double value = estimatedValues[row];
     148        int classIndex = 0;
     149        while (value > actualThresholds[classIndex + 1])
     150          classIndex++;
     151        yield return classValues[classIndex];
    157 152      }
    158 153    }
Note: See TracChangeset for help on using the changeset viewer.