Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs @ 8814

Last change on this file since 8814 was 8814, checked in by sforsten, 12 years ago

#1776:

  • improved performance of confidence calculation
  • fixed bug in median confidence calculation
  • fixed bug in average confidence calculation
  • confidence calculation is now easier for training and test
  • removed obsolete view ClassificationEnsembleSolutionConfidenceAccuracyDependence
File size: 6.4 KB
Rev Line
[7531]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
26
27namespace HeuristicLab.Problems.DataAnalysis {
28  /// <summary>
29  /// Base class for weight calculators for classification solutions in an ensemble.
30  /// </summary>
31  [StorableClass]
32  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
33    [StorableConstructor]
34    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
35    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
36      : base(original, cloner) {
37    }
38    public DiscriminantClassificationWeightCalculator()
39      : base() {
40    }
41
[7549]42    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
[7531]43      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
[7549]44        return Enumerable.Repeat<double>(1.0, classificationSolutions.Count());
[7531]45
[7549]46      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();
[7531]47
48      return DiscriminantCalculateWeights(discriminantSolutions);
49    }
50
[7549]51    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);
[7531]52
[7549]53    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
54      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
[8177]55        return Enumerable.Repeat<double>(double.NaN, rows.Count());
[7531]56
[7549]57      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
[7531]58
[7549]59      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
[7729]60      IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);
[7531]61
62      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
63             select DiscriminantAggregateEstimatedClassValues(zip.ClassValues, zip.Values);
64    }
65
[7729]66    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IDiscriminantFunctionClassificationSolution, double> estimatedValues) {
[8101]67      return base.AggregateEstimatedClassValues(estimatedClassValues);
[7531]68    }
69
[7729]70    protected IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
[7549]71      var estimatedValuesEnumerators = (from solution in solutions
[7729]72                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() })
[7549]73                                        .ToList();
[7531]74
[7549]75      var rowEnumerator = rows.GetEnumerator();
76      while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
77        yield return (from enumerator in estimatedValuesEnumerators
78                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
79                      select enumerator)
[7729]80                      .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
[7531]81      }
82    }
[7562]83
[8814]84    public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
[7562]85      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
86        return double.NaN;
87
88      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
89
[8814]90      return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue, handler);
[7562]91    }
92
[8814]93    protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
94      return base.GetConfidence(solutions, index, estimatedClassValue, handler);
[7562]95    }
[8297]96
[8814]97    public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
[8297]98      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
99        return Enumerable.Repeat(double.NaN, indices.Count());
100
101      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
102
[8814]103      return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue, handler);
[8297]104    }
105
[8814]106    public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
107      return base.GetConfidence(solutions, indices, estimatedClassValue, handler);
[8297]108    }
[7531]109  }
110}
Note: See TracBrowser for help on using the repository browser.