Ignore:
Timestamp:
04/04/17 17:52:44 (6 months ago)
Author:
gkronber
Message:

#2650: merged the factors branch into trunk

Location:
trunk/sources
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources

  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis

  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneR.cs

    r14523 r14826  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    6566
    6667    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6) {
     68      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     69      var model1 = FindBestDoubleVariableModel(problemData, minBucketSize);
     70      var model2 = FindBestFactorModel(problemData);
     71
     72      if (model1 == null && model2 == null) throw new InvalidProgramException("Could not create OneR solution");
     73      else if (model1 == null) return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     74      else if (model2 == null) return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     75      else {
     76        var model1EstimatedValues = model1.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     77        var model1NumCorrect = classValues.Zip(model1EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     78
     79        var model2EstimatedValues = model2.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     80        var model2NumCorrect = classValues.Zip(model2EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     81
     82        if (model1NumCorrect > model2NumCorrect) {
     83          return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     84        } else {
     85          return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     86        }
     87      }
     88    }
     89
     90    private static OneRClassificationModel FindBestDoubleVariableModel(IClassificationProblemData problemData, int minBucketSize = 6) {
    6791      var bestClassified = 0;
    6892      List<Split> bestSplits = null;
     
    7195      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    7296
    73       foreach (var variable in problemData.AllowedInputVariables) {
     97      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<double>);
     98
     99      if (!allowedInputVariables.Any()) return null;
     100
     101      foreach (var variable in allowedInputVariables) {
    74102        var inputValues = problemData.Dataset.GetDoubleValues(variable, problemData.TrainingIndices);
    75103        var samples = inputValues.Zip(classValues, (i, v) => new Sample(i, v)).OrderBy(s => s.inputValue);
    76104
    77         var missingValuesDistribution = samples.Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue).ToDictionary(s => s.Key, s => s.Count()).MaxItems(s => s.Value).FirstOrDefault();
     105        var missingValuesDistribution = samples
     106          .Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue)
     107          .ToDictionary(s => s.Key, s => s.Count())
     108          .MaxItems(s => s.Value)
     109          .FirstOrDefault();
    78110
    79111        //calculate class distributions for all distinct inputValues
     
    120152          while (sample.inputValue >= splits[splitIndex].thresholdValue)
    121153            splitIndex++;
    122           correctClassified += sample.classValue == splits[splitIndex].classValue ? 1 : 0;
     154          correctClassified += sample.classValue.IsAlmost(splits[splitIndex].classValue) ? 1 : 0;
    123155        }
    124156        correctClassified += missingValuesDistribution.Value;
     
    134166      //remove neighboring splits with the same class value
    135167      for (int i = 0; i < bestSplits.Count - 1; i++) {
    136         if (bestSplits[i].classValue == bestSplits[i + 1].classValue) {
     168        if (bestSplits[i].classValue.IsAlmost(bestSplits[i + 1].classValue)) {
    137169          bestSplits.Remove(bestSplits[i]);
    138170          i--;
     
    140172      }
    141173
    142       var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable, bestSplits.Select(s => s.thresholdValue).ToArray(), bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
    143       var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    144 
    145       return solution;
     174      var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable,
     175        bestSplits.Select(s => s.thresholdValue).ToArray(),
     176        bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
     177
     178      return model;
     179    }
     180    private static OneFactorClassificationModel FindBestFactorModel(IClassificationProblemData problemData) {
     181      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     182      var defaultClass = FindMostFrequentClassValue(classValues);
     183      // only select string variables
     184      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<string>);
     185
     186      if (!allowedInputVariables.Any()) return null;
     187
     188      OneFactorClassificationModel bestModel = null;
     189      var bestModelNumCorrect = 0;
     190
     191      foreach (var variable in allowedInputVariables) {
     192        var variableValues = problemData.Dataset.GetStringValues(variable, problemData.TrainingIndices);
     193        var groupedClassValues = variableValues
     194          .Zip(classValues, (v, c) => new KeyValuePair<string, double>(v, c))
     195          .GroupBy(kvp => kvp.Key)
     196          .ToDictionary(g => g.Key, g => FindMostFrequentClassValue(g.Select(kvp => kvp.Value)));
     197
     198        var model = new OneFactorClassificationModel(problemData.TargetVariable, variable,
     199          groupedClassValues.Select(kvp => kvp.Key).ToArray(), groupedClassValues.Select(kvp => kvp.Value).ToArray(), defaultClass);
     200
     201        var modelEstimatedValues = model.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     202        var modelNumCorrect = classValues.Zip(modelEstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     203        if (modelNumCorrect > bestModelNumCorrect) {
     204          bestModelNumCorrect = modelNumCorrect;
     205          bestModel = model;
     206        }
     207      }
     208
     209      return bestModel;
     210    }
     211
     212    private static double FindMostFrequentClassValue(IEnumerable<double> classValues) {
     213      return classValues.GroupBy(c => c).OrderByDescending(g => g.Count()).Select(g => g.Key).First();
    146214    }
    147215
Note: See TracChangeset for help on using the changeset viewer.