source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs @ 7562

Last change on this file since 7562 was 7562, checked in by sforsten, 8 years ago

#1776:

  • bug fix in NeighbourhoodWeightCalculator
  • added GetConfidence method to IClassificationEnsembleSolutionWeightCalculator
  • adjusted the confidence column in ClassificationEnsembleSolutionEstimatedClassValuesView
File size: 5.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
26using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
27
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for ensemble weight calculators that operate on discriminant function classification
  /// solutions (<see cref="IDiscriminantFunctionClassificationSolution"/>).
  /// If any member solution of the ensemble is not a discriminant function solution, the calculator
  /// falls back to neutral results: uniform weights of 1.0, estimated class values of 0.0 and a
  /// confidence of <see cref="double.NaN"/>.
  /// </summary>
  [StorableClass]
  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
    [StorableConstructor]
    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    public DiscriminantClassificationWeightCalculator()
      : base() {
    }

    /// <summary>
    /// Calculates one weight per ensemble member.
    /// </summary>
    /// <param name="classificationSolutions">The ensemble members.</param>
    /// <returns>
    /// The weights computed by <see cref="DiscriminantCalculateWeights"/> when every member is an
    /// <see cref="IDiscriminantFunctionClassificationSolution"/>; otherwise a weight of 1.0 for each member.
    /// </returns>
    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat<double>(1.0, classificationSolutions.Count());

      var discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();
      return DiscriminantCalculateWeights(discriminantSolutions);
    }

    /// <summary>
    /// Calculates one weight per ensemble member. Only called when every member is a discriminant
    /// function solution.
    /// </summary>
    /// <param name="discriminantSolutions">The ensemble members.</param>
    /// <returns>The weights of the members.</returns>
    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);

    /// <summary>
    /// Aggregates the members' estimates into one estimated class value per row.
    /// </summary>
    /// <param name="solutions">The ensemble members.</param>
    /// <param name="dataset">The dataset the members are evaluated on.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <param name="handler">
    /// Predicate evaluated per (problem data, row); a member only contributes to a row if it returns true.
    /// </param>
    /// <returns>
    /// A lazily evaluated sequence with one aggregated value per row. If any member is not a discriminant
    /// function solution, a sequence of 0.0 with one entry per row is returned instead.
    /// </returns>
    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat<double>(0.0, rows.Count());

      var discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();

      var estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
      var estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);

      return estimatedClassValues.Zip(estimatedValues,
                                      (classValues, values) => DiscriminantAggregateEstimatedClassValues(classValues, values));
    }

    /// <summary>
    /// Aggregates the estimates of the participating members for a single row.
    /// The default implementation ignores <paramref name="estimatedValues"/> and delegates to the
    /// base class aggregation of <paramref name="estimatedClassValues"/>.
    /// </summary>
    /// <param name="estimatedClassValues">Estimated class value of each participating member.</param>
    /// <param name="estimatedValues">Discriminant function value of each participating member.</param>
    /// <returns>The aggregated class value for the row.</returns>
    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
      return AggregateEstimatedClassValues(estimatedClassValues);
    }

    /// <summary>
    /// Lazily yields, for each row, the discriminant function values of all members for which
    /// <paramref name="handler"/> accepts that row.
    /// </summary>
    /// <param name="solutions">The discriminant function solutions to evaluate.</param>
    /// <param name="dataset">The dataset the models are evaluated on.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <param name="handler">Predicate deciding per (problem data, row) whether a member participates.</param>
    /// <returns>One dictionary per row, mapping each participating member to its discriminant value.</returns>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      // Use the raw discriminant outputs (Model.GetEstimatedValues), not the thresholded class values:
      // this sequence feeds the "estimatedValues" argument of DiscriminantAggregateEstimatedClassValues.
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() })
                                        .ToList();
      try {
        using (var rowEnumerator = rows.GetEnumerator()) {
          // All model enumerators advance in lock-step with the rows; stop as soon as any sequence ends.
          while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
            int row = rowEnumerator.Current;
            yield return estimatedValuesEnumerators
              .Where(x => handler(x.Solution.ProblemData, row))
              .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
          }
        }
      } finally {
        // Release the per-model enumerators even if the caller stops iterating early.
        foreach (var item in estimatedValuesEnumerators)
          item.EstimatedValuesEnumerator.Dispose();
      }
    }

    /// <summary>
    /// Returns the confidence of the ensemble for <paramref name="estimatedClassValue"/>.
    /// </summary>
    /// <param name="solutions">The ensemble members.</param>
    /// <param name="index">Passed through unchanged; presumably the row index (confirm against ClassificationWeightCalculator).</param>
    /// <param name="estimatedClassValue">The class value whose confidence is requested.</param>
    /// <returns>
    /// <see cref="double.NaN"/> if there are no members or any member is not a discriminant function
    /// solution; otherwise the result of <see cref="GetDiscriminantConfidence"/>.
    /// </returns>
    public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
      if (!solutions.Any() || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return double.NaN;

      var discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
      return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue);
    }

    /// <summary>
    /// Computes the confidence for a non-empty ensemble of discriminant function solutions.
    /// The default implementation delegates to the base class confidence computation.
    /// </summary>
    protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
      return base.GetConfidence(solutions, index, estimatedClassValue);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.