Ignore:
Timestamp:
04/10/17 15:48:20 (2 years ago)
Author:
gkronber
Message:

#2700 merged changesets from trunk to branch

Location:
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis

  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneR.cs

    r14185 r14836  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using System.Threading;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     
    5860    }
    5961
    60     protected override void Run() {
     62    protected override void Run(CancellationToken cancellationToken) {
    6163      var solution = CreateOneRSolution(Problem.ProblemData, MinBucketSizeParameter.Value.Value);
    6264      Results.Add(new Result("OneR solution", "The 1R classifier.", solution));
     
    6466
    6567    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6) {
     68      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     69      var model1 = FindBestDoubleVariableModel(problemData, minBucketSize);
     70      var model2 = FindBestFactorModel(problemData);
     71
     72      if (model1 == null && model2 == null) throw new InvalidProgramException("Could not create OneR solution");
     73      else if (model1 == null) return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     74      else if (model2 == null) return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     75      else {
     76        var model1EstimatedValues = model1.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     77        var model1NumCorrect = classValues.Zip(model1EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     78
     79        var model2EstimatedValues = model2.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     80        var model2NumCorrect = classValues.Zip(model2EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     81
     82        if (model1NumCorrect > model2NumCorrect) {
     83          return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     84        } else {
     85          return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     86        }
     87      }
     88    }
     89
     90    private static OneRClassificationModel FindBestDoubleVariableModel(IClassificationProblemData problemData, int minBucketSize = 6) {
    6691      var bestClassified = 0;
    6792      List<Split> bestSplits = null;
     
    7095      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    7196
    72       foreach (var variable in problemData.AllowedInputVariables) {
     97      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<double>);
     98
     99      if (!allowedInputVariables.Any()) return null;
     100
     101      foreach (var variable in allowedInputVariables) {
    73102        var inputValues = problemData.Dataset.GetDoubleValues(variable, problemData.TrainingIndices);
    74103        var samples = inputValues.Zip(classValues, (i, v) => new Sample(i, v)).OrderBy(s => s.inputValue);
    75104
    76         var missingValuesDistribution = samples.Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue).ToDictionary(s => s.Key, s => s.Count()).MaxItems(s => s.Value).FirstOrDefault();
     105        var missingValuesDistribution = samples
     106          .Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue)
     107          .ToDictionary(s => s.Key, s => s.Count())
     108          .MaxItems(s => s.Value)
     109          .FirstOrDefault();
    77110
    78111        //calculate class distributions for all distinct inputValues
     
    119152          while (sample.inputValue >= splits[splitIndex].thresholdValue)
    120153            splitIndex++;
    121           correctClassified += sample.classValue == splits[splitIndex].classValue ? 1 : 0;
     154          correctClassified += sample.classValue.IsAlmost(splits[splitIndex].classValue) ? 1 : 0;
    122155        }
    123156        correctClassified += missingValuesDistribution.Value;
     
    133166      //remove neighboring splits with the same class value
    134167      for (int i = 0; i < bestSplits.Count - 1; i++) {
    135         if (bestSplits[i].classValue == bestSplits[i + 1].classValue) {
     168        if (bestSplits[i].classValue.IsAlmost(bestSplits[i + 1].classValue)) {
    136169          bestSplits.Remove(bestSplits[i]);
    137170          i--;
     
    139172      }
    140173
    141       var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable, bestSplits.Select(s => s.thresholdValue).ToArray(), bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
    142       var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    143 
    144       return solution;
     174      var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable,
     175        bestSplits.Select(s => s.thresholdValue).ToArray(),
     176        bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
     177
     178      return model;
     179    }
     180    private static OneFactorClassificationModel FindBestFactorModel(IClassificationProblemData problemData) {
     181      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     182      var defaultClass = FindMostFrequentClassValue(classValues);
     183      // only select string variables
     184      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<string>);
     185
     186      if (!allowedInputVariables.Any()) return null;
     187
     188      OneFactorClassificationModel bestModel = null;
     189      var bestModelNumCorrect = 0;
     190
     191      foreach (var variable in allowedInputVariables) {
     192        var variableValues = problemData.Dataset.GetStringValues(variable, problemData.TrainingIndices);
     193        var groupedClassValues = variableValues
     194          .Zip(classValues, (v, c) => new KeyValuePair<string, double>(v, c))
     195          .GroupBy(kvp => kvp.Key)
     196          .ToDictionary(g => g.Key, g => FindMostFrequentClassValue(g.Select(kvp => kvp.Value)));
     197
     198        var model = new OneFactorClassificationModel(problemData.TargetVariable, variable,
     199          groupedClassValues.Select(kvp => kvp.Key).ToArray(), groupedClassValues.Select(kvp => kvp.Value).ToArray(), defaultClass);
     200
     201        var modelEstimatedValues = model.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     202        var modelNumCorrect = classValues.Zip(modelEstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     203        if (modelNumCorrect > bestModelNumCorrect) {
     204          bestModelNumCorrect = modelNumCorrect;
     205          bestModel = model;
     206        }
     207      }
     208
     209      return bestModel;
     210    }
     211
     212    private static double FindMostFrequentClassValue(IEnumerable<double> classValues) {
     213      return classValues.GroupBy(c => c).OrderByDescending(g => g.Count()).Select(g => g.Key).First();
    145214    }
    146215
Note: See TracChangeset for help on using the changeset viewer.