Free cookie consent management tool by TermsFeed Policy Generator

Changeset 10568


Ignore:
Timestamp:
03/10/14 14:43:37 (10 years ago)
Author:
mkommend
Message:

#1998: Code cleanup in ZeroR classification algorithm.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/ZeroR.cs

    r9074 r10568  
    2020#endregion
    2121
    22 using System.Collections.Generic;
    2322using System.Linq;
    2423using HeuristicLab.Common;
    2524using HeuristicLab.Core;
    26 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2725using HeuristicLab.Optimization;
    2826using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2927using HeuristicLab.Problems.DataAnalysis;
    30 using HeuristicLab.Problems.DataAnalysis.Symbolic;
    31 using HeuristicLab.Problems.DataAnalysis.Symbolic.Classification;
    3228
    3329namespace HeuristicLab.Algorithms.DataAnalysis {
     
    6258      Dataset dataset = problemData.Dataset;
    6359      string target = problemData.TargetVariable;
    64       var classValuesEnumerator = problemData.ClassValues.GetEnumerator();
    65       var classValuesInDatasetEnumerator = dataset.GetDoubleValues(target, problemData.TrainingIndices).GetEnumerator();
     60      var targetValues = dataset.GetDoubleValues(target, problemData.TrainingIndices);
    6661
    67       Dictionary<double, int> classValuesCount = new Dictionary<double, int>(problemData.ClassValues.Count());
     62      var dominantClass = targetValues.GroupBy(x => x).ToDictionary(g => g.Key, g => g.Count())
     63        .MaxItems(kvp => kvp.Value).Select(x => x.Key).First();
    6864
    69       //initialize
    70       while (classValuesEnumerator.MoveNext()) {
    71         classValuesCount[classValuesEnumerator.Current] = 0;
    72       }
    73 
    74       //count occurence of classes
    75       while (classValuesInDatasetEnumerator.MoveNext()) {
    76         classValuesCount[classValuesInDatasetEnumerator.Current] += 1;
    77       }
    78 
    79       classValuesEnumerator.Reset();
    80       double mostOccurences = -1;
    81       double bestClass = double.NaN;
    82       while (classValuesEnumerator.MoveNext()) {
    83         if (classValuesCount[classValuesEnumerator.Current] > mostOccurences) {
    84           mostOccurences = classValuesCount[classValuesEnumerator.Current];
    85           bestClass = classValuesEnumerator.Current;
    86         }
    87       }
    88 
    89       ConstantClassificationModel model = new ConstantClassificationModel(bestClass);
    90       ConstantClassificationSolution solution = new ConstantClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    91 
     65      var model = new ConstantClassificationModel(dominantClass);
     66      var solution = new ConstantClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    9267      return solution;
    93     }
    94 
    95     private static SymbolicDiscriminantFunctionClassificationModel CreateDiscriminantFunctionModel(ISymbolicExpressionTree tree,
    96     ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    97     IClassificationProblemData problemData,
    98     IEnumerable<int> rows,
    99     IEnumerable<double> classValues) {
    100       var model = new SymbolicDiscriminantFunctionClassificationModel(tree, interpreter, new AccuracyMaximizationThresholdCalculator());
    101       IList<double> thresholds = new List<double>();
    102       double last = 0;
    103       foreach (double item in classValues) {
    104         if (thresholds.Count == 0) {
    105           thresholds.Add(double.NegativeInfinity);
    106         } else {
    107           thresholds.Add((last + item) / 2);
    108         }
    109         last = item;
    110       }
    111       model.SetThresholdsAndClassValues(thresholds, classValues);
    112       return model;
    11368    }
    11469  }
Note: See TracChangeset for help on using the changeset viewer.