Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/06/17 10:19:37 (7 years ago)
Author:
gkronber
Message:

#2650: merged r14826 from trunk to stable. The only remaining conflict is DataTableControl and ScatterPlotControl which have been renamed within r14982 (-> tree conflict).

Location:
stable
Files:
5 edited
2 copied

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneR.cs

    r15061 r15131  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    6566
    6667    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6) {
     68      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     69      var model1 = FindBestDoubleVariableModel(problemData, minBucketSize);
     70      var model2 = FindBestFactorModel(problemData);
     71
     72      if (model1 == null && model2 == null) throw new InvalidProgramException("Could not create OneR solution");
     73      else if (model1 == null) return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     74      else if (model2 == null) return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     75      else {
     76        var model1EstimatedValues = model1.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     77        var model1NumCorrect = classValues.Zip(model1EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     78
     79        var model2EstimatedValues = model2.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     80        var model2NumCorrect = classValues.Zip(model2EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     81
     82        if (model1NumCorrect > model2NumCorrect) {
     83          return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     84        } else {
     85          return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     86        }
     87      }
     88    }
     89
     90    private static OneRClassificationModel FindBestDoubleVariableModel(IClassificationProblemData problemData, int minBucketSize = 6) {
    6791      var bestClassified = 0;
    6892      List<Split> bestSplits = null;
     
    7195      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    7296
    73       foreach (var variable in problemData.AllowedInputVariables) {
     97      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<double>);
     98
     99      if (!allowedInputVariables.Any()) return null;
     100
     101      foreach (var variable in allowedInputVariables) {
    74102        var inputValues = problemData.Dataset.GetDoubleValues(variable, problemData.TrainingIndices);
    75103        var samples = inputValues.Zip(classValues, (i, v) => new Sample(i, v)).OrderBy(s => s.inputValue);
    76104
    77         var missingValuesDistribution = samples.Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue).ToDictionary(s => s.Key, s => s.Count()).MaxItems(s => s.Value).FirstOrDefault();
     105        var missingValuesDistribution = samples
     106          .Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue)
     107          .ToDictionary(s => s.Key, s => s.Count())
     108          .MaxItems(s => s.Value)
     109          .FirstOrDefault();
    78110
    79111        //calculate class distributions for all distinct inputValues
     
    120152          while (sample.inputValue >= splits[splitIndex].thresholdValue)
    121153            splitIndex++;
    122           correctClassified += sample.classValue == splits[splitIndex].classValue ? 1 : 0;
     154          correctClassified += sample.classValue.IsAlmost(splits[splitIndex].classValue) ? 1 : 0;
    123155        }
    124156        correctClassified += missingValuesDistribution.Value;
     
    134166      //remove neighboring splits with the same class value
    135167      for (int i = 0; i < bestSplits.Count - 1; i++) {
    136         if (bestSplits[i].classValue == bestSplits[i + 1].classValue) {
     168        if (bestSplits[i].classValue.IsAlmost(bestSplits[i + 1].classValue)) {
    137169          bestSplits.Remove(bestSplits[i]);
    138170          i--;
     
    140172      }
    141173
    142       var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable, bestSplits.Select(s => s.thresholdValue).ToArray(), bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
    143       var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    144 
    145       return solution;
     174      var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable,
     175        bestSplits.Select(s => s.thresholdValue).ToArray(),
     176        bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
     177
     178      return model;
     179    }
     180    private static OneFactorClassificationModel FindBestFactorModel(IClassificationProblemData problemData) {
     181      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     182      var defaultClass = FindMostFrequentClassValue(classValues);
     183      // only select string variables
     184      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<string>);
     185
     186      if (!allowedInputVariables.Any()) return null;
     187
     188      OneFactorClassificationModel bestModel = null;
     189      var bestModelNumCorrect = 0;
     190
     191      foreach (var variable in allowedInputVariables) {
     192        var variableValues = problemData.Dataset.GetStringValues(variable, problemData.TrainingIndices);
     193        var groupedClassValues = variableValues
     194          .Zip(classValues, (v, c) => new KeyValuePair<string, double>(v, c))
     195          .GroupBy(kvp => kvp.Key)
     196          .ToDictionary(g => g.Key, g => FindMostFrequentClassValue(g.Select(kvp => kvp.Value)));
     197
     198        var model = new OneFactorClassificationModel(problemData.TargetVariable, variable,
     199          groupedClassValues.Select(kvp => kvp.Key).ToArray(), groupedClassValues.Select(kvp => kvp.Value).ToArray(), defaultClass);
     200
     201        var modelEstimatedValues = model.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     202        var modelNumCorrect = classValues.Zip(modelEstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     203        if (modelNumCorrect > bestModelNumCorrect) {
     204          bestModelNumCorrect = modelNumCorrect;
     205          bestModel = model;
     206        }
     207      }
     208
     209      return bestModel;
     210    }
     211
     212    private static double FindMostFrequentClassValue(IEnumerable<double> classValues) {
     213      return classValues.GroupBy(c => c).OrderByDescending(g => g.Count()).Select(g => g.Key).First();
    146214    }
    147215
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneRClassificationModel.cs

    r14186 r15131  
    3131  [StorableClass]
    3232  [Item("OneR Classification Model", "A model that uses intervals for one variable to determine the class.")]
    33   public class OneRClassificationModel : ClassificationModel {
     33  public sealed class OneRClassificationModel : ClassificationModel {
    3434    public override IEnumerable<string> VariablesUsedForPrediction {
    3535      get { return new[] { Variable }; }
     
    3737
    3838    [Storable]
    39     protected string variable;
     39    private string variable;
    4040    public string Variable {
    4141      get { return variable; }
     
    4343
    4444    [Storable]
    45     protected double[] splits;
     45    private double[] splits;
    4646    public double[] Splits {
    4747      get { return splits; }
     
    4949
    5050    [Storable]
    51     protected double[] classes;
     51    private double[] classes;
    5252    public double[] Classes {
    5353      get { return classes; }
     
    5555
    5656    [Storable]
    57     protected double missingValuesClass;
     57    private double missingValuesClass;
    5858    public double MissingValuesClass {
    5959      get { return missingValuesClass; }
     
    6161
    6262    [StorableConstructor]
    63     protected OneRClassificationModel(bool deserializing) : base(deserializing) { }
    64     protected OneRClassificationModel(OneRClassificationModel original, Cloner cloner)
     63    private OneRClassificationModel(bool deserializing) : base(deserializing) { }
     64    private OneRClassificationModel(OneRClassificationModel original, Cloner cloner)
    6565      : base(original, cloner) {
    6666      this.variable = (string)original.variable;
    6767      this.splits = (double[])original.splits.Clone();
    6868      this.classes = (double[])original.classes.Clone();
     69      this.missingValuesClass = original.missingValuesClass;
    6970    }
    7071    public override IDeepCloneable Clone(Cloner cloner) { return new OneRClassificationModel(this, cloner); }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneRClassificationSolution.cs

    r14186 r15131  
    2828  [StorableClass]
    2929  [Item(Name = "OneR Classification Solution", Description = "Represents a OneR classification solution which uses only a single feature with potentially multiple thresholds for class prediction.")]
    30   public class OneRClassificationSolution : ClassificationSolution {
     30  public sealed class OneRClassificationSolution : ClassificationSolution {
    3131    public new OneRClassificationModel Model {
    3232      get { return (OneRClassificationModel)base.Model; }
     
    3535
    3636    [StorableConstructor]
    37     protected OneRClassificationSolution(bool deserializing) : base(deserializing) { }
    38     protected OneRClassificationSolution(OneRClassificationSolution original, Cloner cloner) : base(original, cloner) { }
     37    private OneRClassificationSolution(bool deserializing) : base(deserializing) { }
     38    private OneRClassificationSolution(OneRClassificationSolution original, Cloner cloner) : base(original, cloner) { }
    3939    public OneRClassificationSolution(OneRClassificationModel model, IClassificationProblemData problemData)
    4040      : base(model, problemData) {
Note: See TracChangeset for help on using the changeset viewer.