Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/17/11 15:14:45 (14 years ago)
Author:
gkronber
Message:

#1418 implemented linear scaling for classification solutions, fixed bugs interactive simplifier view for classification solutions.

Location:
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs

    r5730 r5736  
    8282    }
    8383
    84     private void RecalculateResults() {
     84    protected void RecalculateResults() {
    8585      double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
    8686      IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r5730 r5736  
    4040    [Storable]
    4141    private IRegressionModel model;
     42
    4243    [Storable]
    4344    private double[] classValues;
    44     // class values are not necessarily sorted in ascending order
    4545    public IEnumerable<double> ClassValues {
    4646      get { return (double[])classValues.Clone(); }
    47       set {
    48         if (value == null) throw new ArgumentException();
    49         double[] newValue = value.ToArray();
    50         if (newValue.Length != classValues.Length) throw new ArgumentException();
    51         classValues = newValue;
    52       }
     47      private set { classValues = value.ToArray(); }
    5348    }
     49
    5450    [Storable]
    5551    private double[] thresholds;
    5652    public IEnumerable<double> Thresholds {
    5753      get { return (IEnumerable<double>)thresholds.Clone(); }
    58       set {
    59         thresholds = value.ToArray();
    60         OnThresholdsChanged(EventArgs.Empty);
    61       }
     54      private set { thresholds = value.ToArray(); }
    6255    }
    6356
     
    7164      thresholds = (double[])original.thresholds.Clone();
    7265    }
    73     public DiscriminantFunctionClassificationModel(IRegressionModel model, IEnumerable<double> classValues, IEnumerable<double> thresholds)
     66
     67    public DiscriminantFunctionClassificationModel(IRegressionModel model)
    7468      : base() {
    7569      this.name = ItemName;
    7670      this.description = ItemDescription;
    7771      this.model = model;
    78       this.classValues = classValues.ToArray();
    79       this.thresholds = thresholds.ToArray();
     72      this.classValues = new double[] { 0.0 };
     73      this.thresholds = new double[] { double.NegativeInfinity };
    8074    }
    8175
    8276    public override IDeepCloneable Clone(Cloner cloner) {
    8377      return new DiscriminantFunctionClassificationModel(this, cloner);
     78    }
     79
     80    public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) {
     81      var classValuesArr = classValues.ToArray();
     82      var thresholdsArr = thresholds.ToArray();
     83      if (thresholdsArr.Length != classValuesArr.Length) throw new ArgumentException();
     84
     85      this.classValues = classValuesArr;
     86      this.thresholds = thresholdsArr;
     87      OnThresholdsChanged(EventArgs.Empty);
    8488    }
    8589
     
    96100          else break;
    97101        }
    98         yield return classValues.ElementAt(classIndex);
     102        yield return classValues.ElementAt(classIndex - 1);
    99103      }
    100104    }
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r5730 r5736  
    4040    public new IDiscriminantFunctionClassificationModel Model {
    4141      get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    42       protected set { base.Model = value; }
     42      protected set {
     43        if (value != null && value != Model) {
     44          if (Model != null) {
     45            Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
     46          }
     47          value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
     48          base.Model = value;
     49        }
     50      }
    4351    }
    4452
     
    4755    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
    4856      : base(original, cloner) {
     57      RegisterEventHandler();
    4958    }
    50     public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData, IEnumerable<double> classValues, IEnumerable<double> thresholds)
    51       : this(new DiscriminantFunctionClassificationModel(model, classValues, thresholds), problemData) {
     59    public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
     60      : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    5261    }
    5362    public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    5463      : base(model, problemData) {
     64      RegisterEventHandler();
     65      SetAccuracyMaximizingThresholds();
     66    }
     67
     68    [StorableHook(HookType.AfterDeserialization)]
     69    private void AfterDeserialization() {
     70      RegisterEventHandler();
     71    }
     72
     73    private void RegisterEventHandler() {
     74      Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
     75    }
     76    private void Model_ThresholdsChanged(object sender, EventArgs e) {
     77      OnModelThresholdsChanged(e);
     78    }
     79
     80    public void SetAccuracyMaximizingThresholds() {
     81      double[] classValues;
     82      double[] thresholds;
     83      var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     84      AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
     85
     86      Model.SetThresholdsAndClassValues(thresholds, classValues);
     87    }
     88
     89    public void SetClassDistibutionCutPointThresholds() {
     90      double[] classValues;
     91      double[] thresholds;
     92      var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     93      NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
     94
     95      Model.SetThresholdsAndClassValues(thresholds, classValues);
     96    }
     97
     98    protected override void OnModelChanged(EventArgs e) {
     99      base.OnModelChanged(e);     
     100      SetAccuracyMaximizingThresholds();
     101    }
     102
     103    protected override void OnProblemDataChanged(EventArgs e) {
     104      base.OnProblemDataChanged(e);
     105      SetAccuracyMaximizingThresholds();
     106    }
     107    protected virtual void OnModelThresholdsChanged(EventArgs e) {
     108      RecalculateResults();
    55109    }
    56110
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/AccuracyMaximizationThresholdCalculator.cs

    r5730 r5736  
    6969      classValues = problemData.ClassValues.OrderBy(x => x).ToArray();
    7070      int nClasses = classValues.Length;
    71       thresholds = new double[nClasses + 1];
     71      thresholds = new double[nClasses];
    7272      thresholds[0] = double.NegativeInfinity;
    73       thresholds[thresholds.Length - 1] = double.PositiveInfinity;
     73      // thresholds[thresholds.Length - 1] = double.PositiveInfinity;
    7474
    7575      // incrementally calculate accuracy of all possible thresholds
    7676      int[,] confusionMatrix = new int[nClasses, nClasses];
    7777
    78       for (int i = 1; i < thresholds.Length - 1; i++) {
     78      for (int i = 1; i < thresholds.Length; i++) {
    7979        double lowerThreshold = thresholds[i - 1];
    8080        double actualThreshold = Math.Max(lowerThreshold, minEstimatedValue);
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs

    r5730 r5736  
    8888      thresholdList.Sort();
    8989      thresholdList.Insert(0, double.NegativeInfinity);
    90       thresholdList.Add(double.PositiveInfinity);
    9190
    9291      // determine class values for each partition separated by a threshold by calculating the density of all class distributions
    9392      // all points in the partition are classified as the class with the maximal density in the parition
    9493      List<double> classValuesList = new List<double>();
    95       for (int i = 0; i < thresholdList.Count - 1; i++) {
     94      for (int i = 0; i < thresholdList.Count; i++) {
    9695        double m;
    9796        if (double.IsNegativeInfinity(thresholdList[i])) {
    9897          m = thresholdList[i + 1] - 1.0; // smaller than the smalles non-infinity threshold
    99         } else if (double.IsPositiveInfinity(thresholdList[i + 1])) {
    100           m = thresholdList[i] + 1.0; // larger than the largest non-infinity threshold
     98        } else if (i == thresholdList.Count - 1) {
     99          // last threshold
     100          m = thresholdList[i] + 1.0; // larger than the last threshold
    101101        } else {
    102102          m = thresholdList[i] + (thresholdList[i + 1] - thresholdList[i]) / 2.0; // middle of partition
     
    135135        }
    136136      }
    137       filteredThresholds.Add(double.PositiveInfinity);
    138137      thresholds = filteredThresholds.ToArray();
    139138      classValues = filteredClassValues.ToArray();
Note: See TracChangeset for help on using the changeset viewer.