Changeset 8554


Ignore:
Timestamp:
09/03/12 13:27:40 (7 years ago)
Author:
mkommend
Message:

#1915:

  • Corrected class names and class values caching in ClassificationProblemData
  • Removed caching of classification penalties
  • Corrected AccuracyMaximizationThresholdCalculator (retrieving of penalties was switched)
  • Added asserts for the achieved accuracy to the classification sample unit test
Location:
trunk/sources
Files:
6 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Views

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionEstimatedClassValuesView.cs

    r8139 r8554  
    9696      }
    9797
    98       int classValuesCount = Content.ProblemData.ClassValues.Count;
     98      int classValuesCount = Content.ProblemData.Classes;
    9999      int solutionsCount = Content.ClassificationSolutions.Count();
    100100      string[,] values = new string[indices.Length, 5 + classValuesCount + solutionsCount];
     
    114114            estimatedValuesVector[i].GroupBy(x => x).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
    115115          var estimationCount = groups.Where(g => g.Key != null).Select(g => g.Count).Sum();
    116           values[i, 4] =
    117             (((double)groups.Where(g => g.Key == estimatedClassValues[i]).Single().Count) / estimationCount).ToString();
    118           for (int classIndex = 0; classIndex < Content.ProblemData.ClassValues.Count; classIndex++) {
    119             var group = groups.Where(g => g.Key == Content.ProblemData.ClassValues[classIndex]).SingleOrDefault();
     116          values[i, 4] = (((double)groups.Where(g => g.Key == estimatedClassValues[i]).Single().Count) / estimationCount).ToString();
     117          for (int classIndex = 0; classIndex < Content.ProblemData.Classes; classIndex++) {
     118            var group = groups.Where(g => g.Key == Content.ProblemData.ClassValues.ElementAt(classIndex)).SingleOrDefault();
    120119            if (group == null) values[i, 5 + classIndex] = 0.ToString();
    121120            else values[i, 5 + classIndex] = group.Count.ToString();
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs

    r8528 r8554  
    223223    }
    224224
    225     private List<double> classValues;
    226     public List<double> ClassValues {
     225    private List<double> classValuesCache;
     226    private List<double> ClassValuesCache {
    227227      get {
    228         if (classValues == null) {
    229           classValues = Dataset.GetDoubleValues(TargetVariableParameter.Value.Value).Distinct().ToList();
    230           classValues.Sort();
     228        if (classValuesCache == null) {
     229          classValuesCache = Dataset.GetDoubleValues(TargetVariableParameter.Value.Value).Distinct().OrderBy(x => x).ToList();
    231230        }
    232         return classValues;
     231        return classValuesCache;
    233232      }
    234233    }
    235     IEnumerable<double> IClassificationProblemData.ClassValues {
    236       get { return ClassValues; }
    237     }
    238 
     234    public IEnumerable<double> ClassValues {
     235      get { return ClassValuesCache; }
     236    }
    239237    public int Classes {
    240       get { return ClassValues.Count; }
    241     }
    242 
    243     private List<string> classNames;
    244     public List<string> ClassNames {
     238      get { return ClassValuesCache.Count; }
     239    }
     240
     241    private List<string> classNamesCache;
     242    private List<string> ClassNamesCache {
    245243      get {
    246         if (classNames == null) {
    247           classNames = new List<string>();
     244        if (classNamesCache == null) {
     245          classNamesCache = new List<string>();
    248246          for (int i = 0; i < ClassNamesParameter.Value.Rows; i++)
    249             classNames.Add(ClassNamesParameter.Value[i, 0]);
     247            classNamesCache.Add(ClassNamesParameter.Value[i, 0]);
    250248        }
    251         return classNames;
     249        return classNamesCache;
    252250      }
    253251    }
    254     IEnumerable<string> IClassificationProblemData.ClassNames {
    255       get { return ClassNames; }
    256     }
    257 
    258     private Dictionary<Tuple<double, double>, double> classificationPenaltiesCache = new Dictionary<Tuple<double, double>, double>();
     252    public IEnumerable<string> ClassNames {
     253      get { return ClassNamesCache; }
     254    }
    259255    #endregion
    260256
     
    319315      DeregisterParameterEvents();
    320316
    321       classNames = null;
    322317      ((IStringConvertibleMatrix)ClassNamesParameter.Value).Columns = 1;
    323       ((IStringConvertibleMatrix)ClassNamesParameter.Value).Rows = ClassValues.Count;
     318      ((IStringConvertibleMatrix)ClassNamesParameter.Value).Rows = ClassValuesCache.Count;
    324319      for (int i = 0; i < Classes; i++)
    325         ClassNamesParameter.Value[i, 0] = "Class " + ClassValues[i];
     320        ClassNamesParameter.Value[i, 0] = "Class " + ClassValuesCache[i];
    326321      ClassNamesParameter.Value.ColumnNames = new List<string>() { "ClassNames" };
    327322      ClassNamesParameter.Value.RowNames = ClassValues.Select(s => "ClassValue: " + s);
    328323
    329       classificationPenaltiesCache.Clear();
    330       ((ValueParameter<DoubleMatrix>)ClassificationPenaltiesParameter).ReactOnValueToStringChangedAndValueItemImageChanged = false;
    331324      ((IStringConvertibleMatrix)ClassificationPenaltiesParameter.Value).Rows = Classes;
    332325      ((IStringConvertibleMatrix)ClassificationPenaltiesParameter.Value).Columns = Classes;
     
    339332        }
    340333      }
    341       ((ValueParameter<DoubleMatrix>)ClassificationPenaltiesParameter).ReactOnValueToStringChangedAndValueItemImageChanged = true;
    342334      RegisterParameterEvents();
    343335    }
    344336
    345337    public string GetClassName(double classValue) {
    346       if (!ClassValues.Contains(classValue)) throw new ArgumentException();
    347       int index = ClassValues.IndexOf(classValue);
    348       return ClassNames[index];
     338      if (!ClassValuesCache.Contains(classValue)) throw new ArgumentException();
     339      int index = ClassValuesCache.IndexOf(classValue);
     340      return ClassNamesCache[index];
    349341    }
    350342    public double GetClassValue(string className) {
    351       if (!ClassNames.Contains(className)) throw new ArgumentException();
    352       int index = ClassNames.IndexOf(className);
    353       return ClassValues[index];
     343      if (!ClassNamesCache.Contains(className)) throw new ArgumentException();
     344      int index = ClassNamesCache.IndexOf(className);
     345      return ClassValuesCache[index];
    354346    }
    355347    public void SetClassName(double classValue, string className) {
    356       if (!classValues.Contains(classValue)) throw new ArgumentException();
    357       int index = ClassValues.IndexOf(classValue);
    358       ClassNames[index] = className;
     348      if (!ClassValuesCache.Contains(classValue)) throw new ArgumentException();
     349      int index = ClassValuesCache.IndexOf(classValue);
    359350      ClassNamesParameter.Value[index, 0] = className;
     351      // updating of class names cache is not necessary here as the parameter value fires a changed event which updates the cache
    360352    }
    361353
     
    364356    }
    365357    public double GetClassificationPenalty(double correctClassValue, double estimatedClassValue) {
    366       var key = Tuple.Create(correctClassValue, estimatedClassValue);
    367       if (!classificationPenaltiesCache.ContainsKey(key)) {
    368         int correctClassIndex = ClassValues.IndexOf(correctClassValue);
    369         int estimatedClassIndex = ClassValues.IndexOf(estimatedClassValue);
    370         classificationPenaltiesCache[key] = ClassificationPenaltiesParameter.Value[correctClassIndex, estimatedClassIndex];
    371       }
    372       return classificationPenaltiesCache[key];
     358      int correctClassIndex = ClassValuesCache.IndexOf(correctClassValue);
     359      int estimatedClassIndex = ClassValuesCache.IndexOf(estimatedClassValue);
     360      return ClassificationPenaltiesParameter.Value[correctClassIndex, estimatedClassIndex];
    373361    }
    374362    public void SetClassificationPenalty(string correctClassName, string estimatedClassName, double penalty) {
     
    376364    }
    377365    public void SetClassificationPenalty(double correctClassValue, double estimatedClassValue, double penalty) {
    378       var key = Tuple.Create(correctClassValue, estimatedClassValue);
    379       int correctClassIndex = ClassValues.IndexOf(correctClassValue);
    380       int estimatedClassIndex = ClassValues.IndexOf(estimatedClassValue);
     366      int correctClassIndex = ClassValuesCache.IndexOf(correctClassValue);
     367      int estimatedClassIndex = ClassValuesCache.IndexOf(estimatedClassValue);
    381368
    382369      ClassificationPenaltiesParameter.Value[correctClassIndex, estimatedClassIndex] = penalty;
     
    388375      ClassNamesParameter.Value.Reset += new EventHandler(Parameter_ValueChanged);
    389376      ClassNamesParameter.Value.ItemChanged += new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged);
    390       ClassificationPenaltiesParameter.Value.Reset += new EventHandler(Parameter_ValueChanged);
    391       ClassificationPenaltiesParameter.Value.ItemChanged += new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged);
    392377    }
    393378    private void DeregisterParameterEvents() {
     
    395380      ClassNamesParameter.Value.Reset -= new EventHandler(Parameter_ValueChanged);
    396381      ClassNamesParameter.Value.ItemChanged -= new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged);
    397       ClassificationPenaltiesParameter.Value.Reset -= new EventHandler(Parameter_ValueChanged);
    398       ClassificationPenaltiesParameter.Value.ItemChanged -= new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged);
    399382    }
    400383
    401384    private void TargetVariableParameter_ValueChanged(object sender, EventArgs e) {
    402       classValues = null;
     385      classValuesCache = null;
     386      classNamesCache = null;
    403387      ResetTargetVariableDependentMembers();
    404388      OnChanged();
    405389    }
    406390    private void Parameter_ValueChanged(object sender, EventArgs e) {
     391      classNamesCache = null;
    407392      OnChanged();
    408393    }
    409394    private void MatrixParameter_ItemChanged(object sender, EventArgs<int, int> e) {
     395      classNamesCache = null;
    410396      OnChanged();
    411397    }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/AccuracyMaximizationThresholdCalculator.cs

    r8126 r8554  
    8585            //all positives
    8686            if (pair.TargetClassValue.IsAlmost(classValues[i - 1])) {
    87               if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue < actualThreshold)
     87              if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue <= actualThreshold)
    8888                //true positive
    89                 classificationScore += problemData.GetClassificationPenalty(classValues[i - 1], classValues[i - 1]);
     89                classificationScore += problemData.GetClassificationPenalty(pair.TargetClassValue, pair.TargetClassValue);
    9090              else
    9191                //false negative
    92                 classificationScore += problemData.GetClassificationPenalty(classValues[i], classValues[i - 1]);
     92                classificationScore += problemData.GetClassificationPenalty(pair.TargetClassValue, classValues[i]);
    9393            }
    9494              //all negatives
    9595            else {
    96               if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue < actualThreshold)
     96              if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue <= actualThreshold)
    9797                //false positive
    98                 classificationScore += problemData.GetClassificationPenalty(classValues[i - 1], classValues[i]);
     98                classificationScore += problemData.GetClassificationPenalty(pair.TargetClassValue, classValues[i - 1]);
    9999              else
    100100                //true negative, consider only upper class
    101                 classificationScore += problemData.GetClassificationPenalty(classValues[i], classValues[i]);
     101                classificationScore += problemData.GetClassificationPenalty(pair.TargetClassValue, pair.TargetClassValue);
    102102            }
    103103          }
  • trunk/sources/HeuristicLab.Tests/HeuristicLab-3.3/SamplesTest.cs

    r8482 r8554  
    343343      Assert.AreEqual(100.62175156249987, GetDoubleResult(ga, "CurrentWorstQuality"), 1E-8);
    344344      Assert.AreEqual(100900, GetIntResult(ga, "EvaluatedSolutions"));
     345      var bestTrainingSolution = (IClassificationSolution)ga.Results["Best training solution"].Value;
     346      Assert.AreEqual(0.80625, bestTrainingSolution.TrainingAccuracy, 1E-8);
     347      Assert.AreEqual(0.782608695652174, bestTrainingSolution.TestAccuracy, 1E-8);
    345348    }
    346349
Note: See TracChangeset for help on using the changeset viewer.