source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs @ 8814

Last change on this file since 8814 was 8814, checked in by sforsten, 7 years ago

#1776:

  • improved performance of confidence calculation
  • fixed bug in median confidence calculation
  • fixed bug in average confidence calculation
  • confidence calculation is now easier for training and test
  • removed obsolete view ClassificationEnsembleSolutionConfidenceAccuracyDependence
File size: 6.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
26
27namespace HeuristicLab.Problems.DataAnalysis {
28  /// <summary>
29  /// Base class for weight calculators for classification solutions in an ensemble.
30  /// </summary>
31  [StorableClass]
32  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
33    [StorableConstructor]
34    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
35    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
36      : base(original, cloner) {
37    }
38    public DiscriminantClassificationWeightCalculator()
39      : base() {
40    }
41
42    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
43      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
44        return Enumerable.Repeat<double>(1.0, classificationSolutions.Count());
45
46      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();
47
48      return DiscriminantCalculateWeights(discriminantSolutions);
49    }
50
51    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);
52
53    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
54      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
55        return Enumerable.Repeat<double>(double.NaN, rows.Count());
56
57      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
58
59      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
60      IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);
61
62      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
63             select DiscriminantAggregateEstimatedClassValues(zip.ClassValues, zip.Values);
64    }
65
66    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IDiscriminantFunctionClassificationSolution, double> estimatedValues) {
67      return base.AggregateEstimatedClassValues(estimatedClassValues);
68    }
69
70    protected IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
71      var estimatedValuesEnumerators = (from solution in solutions
72                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() })
73                                        .ToList();
74
75      var rowEnumerator = rows.GetEnumerator();
76      while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
77        yield return (from enumerator in estimatedValuesEnumerators
78                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
79                      select enumerator)
80                      .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
81      }
82    }
83
84    public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
85      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
86        return double.NaN;
87
88      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
89
90      return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue, handler);
91    }
92
93    protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
94      return base.GetConfidence(solutions, index, estimatedClassValue, handler);
95    }
96
97    public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
98      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
99        return Enumerable.Repeat(double.NaN, indices.Count());
100
101      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
102
103      return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue, handler);
104    }
105
106    public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
107      return base.GetConfidence(solutions, indices, estimatedClassValue, handler);
108    }
109  }
110}
Note: See TracBrowser for help on using the repository browser.