source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs @ 7549

Last change on this file since 7549 was 7549, checked in by sforsten, 8 years ago

#1776:

  • models can be selected with a check box
  • all strategies are now finished
  • major changes have been made to provide the same behaviour when getting the estimated training or test values of an ensemble
File size: 4.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
26using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
27
28namespace HeuristicLab.Problems.DataAnalysis {
29  /// <summary>
30  /// Base class for weight calculators for classification solutions in an ensemble.
31  /// </summary>
32  [StorableClass]
33  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
34    [StorableConstructor]
35    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
36    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
37      : base(original, cloner) {
38    }
39    public DiscriminantClassificationWeightCalculator()
40      : base() {
41    }
42
43    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
44      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
45        return Enumerable.Repeat<double>(1.0, classificationSolutions.Count());
46
47      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();
48
49      return DiscriminantCalculateWeights(discriminantSolutions);
50    }
51
52    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);
53
54    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
55      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
56        return Enumerable.Repeat<double>(0.0, rows.Count());
57
58      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
59
60      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
61      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);
62
63      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
64             select DiscriminantAggregateEstimatedClassValues(zip.ClassValues, zip.Values);
65    }
66
67    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
68      return AggregateEstimatedClassValues(estimatedClassValues);
69    }
70
71    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
72      if (!solutions.Any()) yield break;
73      var estimatedValuesEnumerators = (from solution in solutions
74                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
75                                        .ToList();
76
77      var rowEnumerator = rows.GetEnumerator();
78      while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
79        yield return (from enumerator in estimatedValuesEnumerators
80                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
81                      select enumerator)
82                      .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
83      }
84    }
85  }
86}
Note: See TracBrowser for help on using the repository browser.