Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs @ 7531

Last change on this file since 7531 was 7531, checked in by sforsten, 12 years ago

#1776:

  • Two more strategies have been implemented.
  • Major changes have been made to the inheritance hierarchy to make it possible to add strategies that do not use a voting strategy with weights.
  • ClassificationEnsembleSolutionEstimatedClassValuesView currently does not show the confidence (it has been removed for testing purposes).
File size: 4.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
27
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble
  /// whose members are discriminant function classification solutions/models.
  /// </summary>
  /// <remarks>
  /// The overrides in this class check that every solution (or model) implements the
  /// discriminant function interface and then forward to the discriminant specific
  /// hooks <see cref="DiscriminantCalculateWeights"/> and
  /// <see cref="DiscriminantAggregateEstimatedClassValues"/>. If at least one element
  /// does not implement the interface a constant fallback result is returned instead.
  /// </remarks>
  [StorableClass]
  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; forwards to the base class cloning constructor.</summary>
    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Default constructor.</summary>
    public DiscriminantClassificationWeightCalculator()
      : base() {
    }

    /// <summary>
    /// Calculates one weight per classification solution.
    /// </summary>
    /// <param name="classificationSolutions">The solutions of the ensemble.</param>
    /// <returns>
    /// The result of <see cref="DiscriminantCalculateWeights"/> if every solution is an
    /// <see cref="IDiscriminantFunctionClassificationSolution"/>; otherwise the weight
    /// 1.0 repeated once per solution.
    /// </returns>
    protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
      ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>();

      // Type check and conversion are done in a single pass over the collection.
      foreach (var solution in classificationSolutions) {
        var discriminantSolution = solution as IDiscriminantFunctionClassificationSolution;
        if (discriminantSolution == null) {
          // At least one solution is not a discriminant function solution:
          // fall back to equal weights for all solutions.
          return Enumerable.Repeat<double>(1.0, classificationSolutions.Count);
        }
        discriminantSolutions.Add(discriminantSolution);
      }

      return DiscriminantCalculateWeights(discriminantSolutions);
    }

    /// <summary>
    /// Calculates the weights for an ensemble that consists only of discriminant
    /// function classification solutions.
    /// </summary>
    /// <param name="discriminantSolutions">The solutions of the ensemble.</param>
    /// <returns>The calculated weights.</returns>
    protected abstract IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions);

    /// <summary>
    /// Aggregates the estimated class values of all models into a sequence of values.
    /// </summary>
    /// <param name="models">The models of the ensemble.</param>
    /// <param name="dataset">The dataset that is passed on to the models.</param>
    /// <param name="rows">The rows that are passed on to the models.</param>
    /// <returns>
    /// A lazily evaluated sequence with one result of
    /// <see cref="DiscriminantAggregateEstimatedClassValues"/> per pair of estimated
    /// class values and estimated values; or 0.0 repeated <c>rows.Count()</c> times if
    /// at least one model is not an <see cref="IDiscriminantFunctionClassificationModel"/>.
    /// </returns>
    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
      // Materialize once so that the type check and both estimation calls below
      // work on the same snapshot and the source is not enumerated repeatedly.
      List<IClassificationModel> allModels = models.ToList();

      if (!allModels.All(x => x is IDiscriminantFunctionClassificationModel))
        return Enumerable.Repeat<double>(0.0, rows.Count());

      IEnumerable<IDiscriminantFunctionClassificationModel> discriminantModels = allModels.Cast<IDiscriminantFunctionClassificationModel>();

      IEnumerable<IEnumerable<double>> estimatedClassValues = ClassificationWeightCalculator.GetEstimatedClassValues(allModels, dataset, rows);
      IEnumerable<IEnumerable<double>> estimatedValues = DiscriminantClassificationWeightCalculator.GetEstimatedValues(discriminantModels, dataset, rows);

      // Both sequences are walked in lockstep; Zip stops at the end of the shorter one.
      return estimatedClassValues.Zip(estimatedValues,
        (classValues, values) => DiscriminantAggregateEstimatedClassValues(classValues, values));
    }

    /// <summary>
    /// Aggregates the values of all models for a single element into one value.
    /// </summary>
    /// <param name="estimatedClassValues">The estimated class values of all models.</param>
    /// <param name="estimatedValues">The estimated values of all models (unused by this default implementation).</param>
    /// <returns>
    /// The result of the single argument <c>AggregateEstimatedClassValues</c> overload
    /// applied to <paramref name="estimatedClassValues"/>.
    /// </returns>
    protected virtual double DiscriminantAggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {
      return AggregateEstimatedClassValues(estimatedClassValues);
    }

    /// <summary>
    /// Iterates the estimated values of all models in lockstep.
    /// </summary>
    /// <param name="models">The models whose estimated values are requested.</param>
    /// <param name="dataset">The dataset that is passed on to the models.</param>
    /// <param name="rows">The rows that are passed on to the models.</param>
    /// <returns>
    /// A lazily produced sequence. Every element contains the current estimated value
    /// of each model, in the order of <paramref name="models"/>. The sequence ends as
    /// soon as the estimated values of any model are exhausted and is empty if there
    /// are no models.
    /// </returns>
    protected static IEnumerable<IEnumerable<double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
      var estimatedValuesEnumerators = new List<IEnumerator<double>>();
      try {
        foreach (var model in models)
          estimatedValuesEnumerators.Add(model.GetEstimatedValues(dataset, rows).GetEnumerator());

        // Without any enumerator All(...) below would always be true (endless loop).
        if (estimatedValuesEnumerators.Count == 0) yield break;

        while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Snapshot the current values: a deferred query over the enumerators would
          // report values of a later position if it is evaluated after the next MoveNext.
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      }
      finally {
        // Also runs when the consumer stops iterating early and disposes the iterator.
        foreach (var enumerator in estimatedValuesEnumerators)
          enumerator.Dispose();
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.