Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs @ 7562

Last changed in revision 7562, checked in by sforsten 12 years ago.

#1776:

  • fixed a bug in NeighbourhoodWeightCalculator
  • added GetConfidence method to IClassificationEnsembleSolutionWeightCalculator
  • adjusted the confidence column in ClassificationEnsembleSolutionEstimatedClassValuesView
File size: 5.4 KB
RevLine 
[7531]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[7549]26using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
[7531]27
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble
  /// whose members are discriminant function classification solutions.
  /// </summary>
  /// <remarks>
  /// Derived classes only see <see cref="IDiscriminantFunctionClassificationSolution"/>s.
  /// If any member of the ensemble is not a discriminant function solution, this class falls back to:
  /// a weight of 1.0 per solution, an aggregated class value of 0.0 per row, and a confidence of
  /// <see cref="double.NaN"/>.
  /// </remarks>
  [StorableClass]
  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
    [StorableConstructor]
    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    public DiscriminantClassificationWeightCalculator()
      : base() {
    }

    /// <summary>
    /// Calculates one weight per solution.
    /// </summary>
    /// <param name="classificationSolutions">The ensemble members.</param>
    /// <returns>
    /// The result of <see cref="DiscriminantCalculateWeights"/> when every solution is an
    /// <see cref="IDiscriminantFunctionClassificationSolution"/>; otherwise a weight of 1.0 for every solution.
    /// </returns>
    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      // Materialize once: the sequence is type-checked, counted and cast below.
      var solutions = classificationSolutions.ToList();
      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat<double>(1.0, solutions.Count);

      return DiscriminantCalculateWeights(solutions.Cast<IDiscriminantFunctionClassificationSolution>());
    }

    /// <summary>
    /// Calculates one weight per discriminant function solution.
    /// </summary>
    /// <param name="discriminantSolutions">The ensemble members, all discriminant function solutions.</param>
    /// <returns>The weights; presumably one per solution in the same order — confirm with implementers.</returns>
    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);

    /// <summary>
    /// Aggregates the class values estimated by the ensemble members into one value per row.
    /// </summary>
    /// <param name="solutions">The ensemble members.</param>
    /// <param name="dataset">The dataset to evaluate the members on.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <param name="handler">
    /// Decides, per member problem data and row, whether that member participates for the row.
    /// </param>
    /// <returns>
    /// One aggregated value per row, computed by <see cref="DiscriminantAggregateEstimatedClassValues"/>,
    /// or 0.0 for every row if any member is not a discriminant function solution.
    /// </returns>
    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      var solutionList = solutions.ToList();
      if (!solutionList.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat<double>(0.0, rows.Count());

      var discriminantSolutions = solutionList.Cast<IDiscriminantFunctionClassificationSolution>();

      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutionList, dataset, rows, handler);
      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);

      // Both sequences yield one dictionary per row, so they are combined pairwise.
      return estimatedClassValues.Zip(estimatedValues,
        (classValues, values) => DiscriminantAggregateEstimatedClassValues(classValues, values));
    }

    /// <summary>
    /// Aggregates the per-solution estimates for a single row.
    /// </summary>
    /// <param name="estimatedClassValues">Estimated class value of each participating solution.</param>
    /// <param name="estimatedValues">Discriminant function value of each participating solution.</param>
    /// <returns>
    /// The aggregated class value. The default ignores <paramref name="estimatedValues"/> and delegates
    /// to the base class's dictionary-based <c>AggregateEstimatedClassValues</c> overload.
    /// </returns>
    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
      return AggregateEstimatedClassValues(estimatedClassValues);
    }

    /// <summary>
    /// Lazily yields, for each row, a dictionary that maps each participating solution to the value of
    /// its discriminant function at that row.
    /// </summary>
    /// <param name="solutions">The discriminant function solutions to evaluate.</param>
    /// <param name="dataset">The dataset to evaluate the solutions on.</param>
    /// <param name="rows">The rows to evaluate; enumerated once per solution plus once here.</param>
    /// <param name="handler">
    /// A solution is only included for a row if this returns true for its problem data and the row.
    /// </param>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      // Use the discriminant function values (GetEstimatedValues), not the class values; the class
      // values are obtained separately via GetEstimatedClassValues.
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() })
                                        .ToList();
      try {
        using (var rowEnumerator = rows.GetEnumerator()) {
          // Every solution's enumerator is advanced for every row, including solutions the handler
          // excludes, so that all enumerators stay aligned with the current row.
          while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
            int row = rowEnumerator.Current;
            yield return estimatedValuesEnumerators
              .Where(x => handler(x.Solution.ProblemData, row))
              .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
          }
        }
      } finally {
        // Release the per-solution enumerators even if the caller stops iterating early.
        foreach (var entry in estimatedValuesEnumerators)
          entry.EstimatedValuesEnumerator.Dispose();
      }
    }

    /// <summary>
    /// Returns the confidence of the ensemble's estimate.
    /// </summary>
    /// <param name="solutions">The ensemble members.</param>
    /// <param name="index">Passed through unchanged; presumably a row index — confirm against callers.</param>
    /// <param name="estimatedClassValue">The class value estimated by the ensemble.</param>
    /// <returns>
    /// <see cref="double.NaN"/> if <paramref name="solutions"/> is empty or contains a solution that is not
    /// a discriminant function solution; otherwise the result of <see cref="GetDiscriminantConfidence"/>.
    /// </returns>
    public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
      var solutionList = solutions.ToList();
      if (!solutionList.Any() || !solutionList.All(x => x is IDiscriminantFunctionClassificationSolution))
        return double.NaN;

      return GetDiscriminantConfidence(solutionList.Cast<IDiscriminantFunctionClassificationSolution>(), index, estimatedClassValue);
    }

    /// <summary>
    /// Computes the confidence for a non-empty ensemble of discriminant function solutions.
    /// The default delegates to the base class's <c>GetConfidence</c> implementation.
    /// </summary>
    protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
      return base.GetConfidence(solutions, index, estimatedClassValue);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.