Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp: 08/08/16 11:39:03 (8 years ago)
Author: gkronber
Message: #2650: added support for factor variables to OneR algorithm

Location: branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers
Files: 2 added, 2 edited

Legend:

Unmodified
Added
Removed
  • branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneR.cs

    r14185 r14242  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    6465
    6566    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6) {
     67      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     68      var model1 = FindBestDoubleVariableModel(problemData, minBucketSize);
     69      var model2 = FindBestFactorModel(problemData);
     70
     71      if (model1 == null && model2 == null) throw new InvalidProgramException("Could not create OneR solution");
     72      else if (model1 == null) return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     73      else if (model2 == null) return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     74      else {
     75        var model1EstimatedValues = model1.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     76        var model1NumCorrect = classValues.Zip(model1EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     77
     78        var model2EstimatedValues = model2.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     79        var model2NumCorrect = classValues.Zip(model2EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     80
     81        if (model1NumCorrect > model2NumCorrect) {
     82          return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     83        } else {
     84          return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     85        }
     86      }
     87    }
     88
     89    private static OneRClassificationModel FindBestDoubleVariableModel(IClassificationProblemData problemData, int minBucketSize = 6) {
    6690      var bestClassified = 0;
    6791      List<Split> bestSplits = null;
     
    7094      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    7195
    72       foreach (var variable in problemData.AllowedInputVariables) {
     96      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<double>);
     97
     98      if (!allowedInputVariables.Any()) return null;
     99
     100      foreach (var variable in allowedInputVariables) {
    73101        var inputValues = problemData.Dataset.GetDoubleValues(variable, problemData.TrainingIndices);
    74102        var samples = inputValues.Zip(classValues, (i, v) => new Sample(i, v)).OrderBy(s => s.inputValue);
    75103
    76         var missingValuesDistribution = samples.Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue).ToDictionary(s => s.Key, s => s.Count()).MaxItems(s => s.Value).FirstOrDefault();
     104        var missingValuesDistribution = samples
     105          .Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue)
     106          .ToDictionary(s => s.Key, s => s.Count())
     107          .MaxItems(s => s.Value)
     108          .FirstOrDefault();
    77109
    78110        //calculate class distributions for all distinct inputValues
     
    119151          while (sample.inputValue >= splits[splitIndex].thresholdValue)
    120152            splitIndex++;
    121           correctClassified += sample.classValue == splits[splitIndex].classValue ? 1 : 0;
     153          correctClassified += sample.classValue.IsAlmost(splits[splitIndex].classValue) ? 1 : 0;
    122154        }
    123155        correctClassified += missingValuesDistribution.Value;
     
    133165      //remove neighboring splits with the same class value
    134166      for (int i = 0; i < bestSplits.Count - 1; i++) {
    135         if (bestSplits[i].classValue == bestSplits[i + 1].classValue) {
     167        if (bestSplits[i].classValue.IsAlmost(bestSplits[i + 1].classValue)) {
    136168          bestSplits.Remove(bestSplits[i]);
    137169          i--;
     
    139171      }
    140172
    141       var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable, bestSplits.Select(s => s.thresholdValue).ToArray(), bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
    142       var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    143 
    144       return solution;
     173      var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable,
     174        bestSplits.Select(s => s.thresholdValue).ToArray(),
     175        bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
     176
     177      return model;
     178    }
     179    private static OneFactorClassificationModel FindBestFactorModel(IClassificationProblemData problemData) {
     180      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     181      var defaultClass = FindMostFrequentClassValue(classValues);
     182      // only select string variables
     183      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<string>);
     184
     185      if (!allowedInputVariables.Any()) return null;
     186
     187      OneFactorClassificationModel bestModel = null;
     188      var bestModelNumCorrect = 0;
     189
     190      foreach (var variable in allowedInputVariables) {
     191        var variableValues = problemData.Dataset.GetStringValues(variable, problemData.TrainingIndices);
     192        var groupedClassValues = variableValues
     193          .Zip(classValues, (v, c) => new KeyValuePair<string, double>(v, c))
     194          .GroupBy(kvp => kvp.Key)
     195          .ToDictionary(g => g.Key, g => FindMostFrequentClassValue(g.Select(kvp => kvp.Value)));
     196
     197        var model = new OneFactorClassificationModel(problemData.TargetVariable, variable,
     198          groupedClassValues.Select(kvp => kvp.Key).ToArray(), groupedClassValues.Select(kvp => kvp.Value).ToArray(), defaultClass);
     199
     200        var modelEstimatedValues = model.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     201        var modelNumCorrect = classValues.Zip(modelEstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     202        if (modelNumCorrect > bestModelNumCorrect) {
     203          bestModelNumCorrect = modelNumCorrect;
     204          bestModel = model;
     205        }
     206      }
     207
     208      return bestModel;
     209    }
     210
     211    private static double FindMostFrequentClassValue(IEnumerable<double> classValues) {
     212      return classValues.GroupBy(c => c).OrderByDescending(g => g.Count()).Select(g => g.Key).First();
    145213    }
    146214
  • branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneRClassificationModel.cs

    r14185 r14242  
    6767      this.splits = (double[])original.splits.Clone();
    6868      this.classes = (double[])original.classes.Clone();
     69      this.missingValuesClass = original.missingValuesClass;
    6970    }
    7071    public override IDeepCloneable Clone(Cloner cloner) { return new OneRClassificationModel(this, cloner); }
Note: See TracChangeset for help on using the changeset viewer.