Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/10/12 17:13:46 (12 years ago)
Author:
gkronber
Message:

#1902 implemented LS Gaussian Process classification

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r7259 r8623  
    5151    }
    5252
     53    private IDiscriminantFunctionThresholdCalculator thresholdCalculator;
     54    [Storable]
     55    public IDiscriminantFunctionThresholdCalculator ThresholdCalculator {
     56      get { return thresholdCalculator; }
     57      private set { thresholdCalculator = value; }
     58    }
     59
    5360
    5461    [StorableConstructor]
     
    6168    }
    6269
    63     public DiscriminantFunctionClassificationModel(IRegressionModel model)
     70    public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator)
    6471      : base() {
    6572      this.name = ItemName;
    6673      this.description = ItemDescription;
    6774      this.model = model;
    68       this.classValues = new double[] { 0.0 };
    69       this.thresholds = new double[] { double.NegativeInfinity };
     75      this.classValues = new double[0];
     76      this.thresholds = new double[0];
     77      this.thresholdCalculator = thresholdCalculator;
     78    }
     79
     80    [StorableHook(HookType.AfterDeserialization)]
     81    private void AfterDeserialization() {
     82      if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator();
    7083    }
    7184
     
    8093    }
    8194
     95    public virtual void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
     96      double[] classValues;
     97      double[] thresholds;
     98      var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     99      var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows);
     100      thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);
     101      SetThresholdsAndClassValues(thresholds, classValues);
     102    }
     103
     104
    82105    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    83106      return model.GetEstimatedValues(dataset, rows);
     
    85108
    86109    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     110      if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model.");
    87111      foreach (var x in GetEstimatedValues(dataset, rows)) {
    88112        int classIndex = 0;
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs

    r7259 r8623  
    107107        double maxDensityClassValue = -1;
    108108        foreach (var classValue in originalClasses) {
    109           double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]);
     109          double density = LogNormalDensity(m, classMean[classValue], classStdDev[classValue]);
    110110          if (density > maxDensity) {
    111111            maxDensity = density;
     
    139139    }
    140140
    141     private static double NormalDensity(double x, double mu, double sigma) {
    142       if (sigma.IsAlmost(0.0)) {
    143         if (x.IsAlmost(mu)) return 1.0; else return 0.0;
    144       } else {
    145         return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma));
    146       }
     141    private static double LogNormalDensity(double x, double mu, double sigma) {
     142      return -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - ((x - mu) * (x - mu)) / (2.0 * sigma * sigma);
    147143    }
    148144
Note: See TracChangeset for help on using the changeset viewer.