source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs @ 8814

Last change on this file since 8814 was 8814, checked in by sforsten, 7 years ago

#1776:

  • improved performance of confidence calculation
  • fixed bug in median confidence calculation
  • fixed bug in average confidence calculation
  • confidence calculation is now easier for training and test
  • removed obsolete view ClassificationEnsembleSolutionConfidenceAccuracyDependence
File size: 7.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
27
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble.
  /// </summary>
  /// <remarks>
  /// Derived classes implement <see cref="CalculateWeights"/>. The normalized weights are cached by
  /// <see cref="CalculateNormalizedWeights"/> and are read by the aggregation and confidence methods,
  /// so <see cref="CalculateNormalizedWeights"/> has to be called before any of them is used.
  /// The cached weights are neither marked as storable nor copied by the cloning constructor.
  /// </remarks>
  [StorableClass]
  public abstract class ClassificationWeightCalculator : NamedItem, IClassificationEnsembleSolutionWeightCalculator {
    [StorableConstructor]
    protected ClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected ClassificationWeightCalculator(ClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    public ClassificationWeightCalculator()
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
    }

    // Normalized weight per solution; null until CalculateNormalizedWeights has been called.
    private IDictionary<IClassificationSolution, double> weights;

    /// <summary>
    /// Calls <see cref="CalculateWeights"/>, replaces negative weights by zero and normalizes the
    /// remaining weights so that they sum up to one. The result is cached in this instance.
    /// </summary>
    /// <remarks>
    /// If no weight is positive, every solution receives the same weight. Dividing by a sum of
    /// zero would otherwise turn every weight into NaN.
    /// </remarks>
    /// <param name="classificationSolutions">The solutions of the ensemble for which weights are calculated.</param>
    public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      // Materialize once; the sequence is needed for the weight calculation and for the pairing below.
      List<IClassificationSolution> solutionList = classificationSolutions.ToList();

      List<double> clampedWeights = new List<double>(solutionList.Count);
      if (solutionList.Count > 0) {
        foreach (double weight in CalculateWeights(solutionList)) {
          // The comparison is also false for NaN, so NaN weights are replaced by zero as well.
          clampedWeights.Add(weight >= 0 ? weight : 0);
        }
      }

      double sum = clampedWeights.Sum();
      var pairs = solutionList
        .Zip(clampedWeights, (solution, weight) => new { Solution = solution, Weight = weight })
        .ToList();

      bool hasPositiveSum = sum > 0;
      double uniformWeight = pairs.Count > 0 ? 1.0 / pairs.Count : 0.0;
      this.weights = pairs.ToDictionary(p => p.Solution, p => hasPositiveSum ? p.Weight / sum : uniformWeight);
    }

    /// <summary>
    /// Calculates one raw weight per solution. The weights are clamped and normalized by
    /// <see cref="CalculateNormalizedWeights"/>; they are paired with the solutions by position.
    /// </summary>
    /// <param name="classificationSolutions">The solutions of the ensemble.</param>
    /// <returns>The raw weights in the order of <paramref name="classificationSolutions"/>.</returns>
    protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);

    #region delegate CheckPoint
    /// <summary>Returns a delegate that accepts rows which are test samples of a solution's problem data.</summary>
    public CheckPoint GetTestClassDelegate() {
      return PointInTest;
    }
    /// <summary>Returns a delegate that accepts rows which are training samples of a solution's problem data.</summary>
    public CheckPoint GetTrainingClassDelegate() {
      return PointInTraining;
    }
    /// <summary>Returns a delegate that accepts every row.</summary>
    public CheckPoint GetAllClassDelegate() {
      return AllPoints;
    }
    #endregion

    /// <summary>
    /// Aggregates the estimated class values of all solutions into one class value per row.
    /// </summary>
    /// <param name="solutions">The solutions of the ensemble.</param>
    /// <param name="dataset">The dataset on which the models are evaluated.</param>
    /// <param name="rows">The rows for which class values are estimated.</param>
    /// <param name="handler">Decides for each solution and row whether the solution takes part in the vote.</param>
    /// <returns>One aggregated class value per row; NaN for rows in which no solution takes part.</returns>
    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      return GetEstimatedClassValues(solutions, dataset, rows, handler)
        .Select(votes => AggregateEstimatedClassValues(votes));
    }

    /// <summary>
    /// Performs a weighted vote: the weights of all solutions which estimate the same class value
    /// are summed up and the class value with the largest weight sum is returned.
    /// </summary>
    /// <param name="estimatedClassValues">The estimated class value of each participating solution.</param>
    /// <returns>
    /// The class value with the largest weight sum (the first one found if several share it),
    /// or NaN if <paramref name="estimatedClassValues"/> is empty.
    /// </returns>
    protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
      IDictionary<double, double> weightSum = new Dictionary<double, double>();
      foreach (var vote in estimatedClassValues) {
        double accumulated;
        // TryGetValue leaves 0.0 in 'accumulated' for a class value that has not been seen yet.
        weightSum.TryGetValue(vote.Value, out accumulated);
        weightSum[vote.Value] = accumulated + GetWeight(vote.Key);
      }
      if (weightSum.Count <= 0)
        return double.NaN;

      double maxWeight = weightSum.Max(x => x.Value);
      return weightSum.First(x => x.Value.Equals(maxWeight)).Key;
    }

    /// <summary>
    /// Evaluates the models of all solutions row by row.
    /// </summary>
    /// <param name="solutions">The solutions of the ensemble.</param>
    /// <param name="dataset">The dataset on which the models are evaluated.</param>
    /// <param name="rows">The rows for which class values are estimated.</param>
    /// <param name="handler">Decides for each solution and row whether the solution's estimate is included.</param>
    /// <returns>
    /// For each row a dictionary with the estimated class value of every solution accepted by
    /// <paramref name="handler"/>. The sequence ends as soon as the rows or the estimates of any model run out.
    /// </returns>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      // The rows are consumed by every model and by the loop below, so they are materialized once.
      List<int> rowList = rows.ToList();
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rowList).GetEnumerator() })
                                       .ToList();
      try {
        foreach (int row in rowList) {
          // Advance all models in lockstep with the rows.
          if (!estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext()))
            yield break;

          int currentRow = row;
          yield return estimatedValuesEnumerators
            .Where(x => handler(x.Solution.ProblemData, currentRow))
            .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
        }
      }
      finally {
        // Also runs when the consumer stops enumerating early.
        foreach (var entry in estimatedValuesEnumerators)
          entry.EstimatedValuesEnumerator.Dispose();
      }
    }

    /// <summary>
    /// Calculates the confidence of an estimated class value for a single row as the weight sum of
    /// all solutions which are accepted by <paramref name="handler"/> and estimate the same class value.
    /// </summary>
    /// <param name="solutions">The solutions of the ensemble.</param>
    /// <param name="index">The row for which the confidence is calculated.</param>
    /// <param name="estimatedClassValue">The class value whose confidence is calculated.</param>
    /// <param name="handler">Decides for each solution whether it is considered for the row.</param>
    /// <returns>The confidence, or NaN if <paramref name="solutions"/> is empty.</returns>
    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
      List<IClassificationSolution> solutionList = solutions.ToList();
      if (solutionList.Count < 1)
        return double.NaN;

      // All models are evaluated on the dataset of the first solution.
      Dataset dataset = solutionList[0].ProblemData.Dataset;
      IEnumerable<int> singleRow = Enumerable.Repeat(index, 1);

      // The handler is checked first so that models of excluded solutions are not evaluated.
      // Class values are compared with IsAlmost, like in the overload for several rows.
      return solutionList
        .Where(s => handler(s.ProblemData, index))
        .Where(s => DoubleExtensions.IsAlmost(s.Model.GetEstimatedClassValues(dataset, singleRow).First(), estimatedClassValue))
        .Sum(s => GetWeight(s));
    }

    /// <summary>
    /// Calculates the confidences of estimated class values for several rows. The confidence of a row is
    /// the weight sum of all solutions which are accepted by <paramref name="handler"/> and estimate
    /// the same class value for that row.
    /// </summary>
    /// <param name="solutions">The solutions of the ensemble.</param>
    /// <param name="indices">The rows for which confidences are calculated.</param>
    /// <param name="estimatedClassValue">The class values whose confidences are calculated, one per row in the order of <paramref name="indices"/>.</param>
    /// <param name="handler">Decides for each solution and row whether the solution is considered.</param>
    /// <returns>One confidence per row; NaN for every row if <paramref name="solutions"/> is empty.</returns>
    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
      List<IClassificationSolution> solutionList = solutions.ToList();
      List<int> indicesList = indices.ToList();
      if (solutionList.Count < 1)
        return Enumerable.Repeat(double.NaN, indicesList.Count);

      // All models are evaluated on the dataset of the first solution.
      Dataset dataset = solutionList[0].ProblemData.Dataset;
      Dictionary<IClassificationSolution, double[]> solValues = solutionList.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray());
      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
      double[] confidences = new double[indicesList.Count];

      for (int i = 0; i < indicesList.Count; i++) {
        int row = indicesList[i];
        double expected = estimatedClassValueArr[i];
        double confidence = 0.0;
        foreach (var entry in solValues) {
          if (DoubleExtensions.IsAlmost(entry.Value[i], expected) && handler(entry.Key.ProblemData, row))
            confidence += GetWeight(entry.Key);
        }
        confidences[i] = confidence;
      }

      return confidences;
    }

    #region Helper
    /// <summary>
    /// Returns the cached normalized weight of a solution.
    /// </summary>
    /// <exception cref="System.InvalidOperationException">Thrown if no weights have been calculated yet.</exception>
    private double GetWeight(IClassificationSolution solution) {
      if (weights == null)
        throw new System.InvalidOperationException("The weights have not been calculated yet. Call CalculateNormalizedWeights first.");
      return weights[solution];
    }

    /// <summary>Returns whether <paramref name="point"/> is a training sample of <paramref name="problemData"/>.</summary>
    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
      return problemData.IsTrainingSample(point);
    }
    /// <summary>Returns whether <paramref name="point"/> is a test sample of <paramref name="problemData"/>.</summary>
    protected bool PointInTest(IClassificationProblemData problemData, int point) {
      return problemData.IsTestSample(point);
    }
    /// <summary>Accepts every row regardless of the problem data.</summary>
    protected bool AllPoints(IClassificationProblemData problemData, int point) {
      return true;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.