Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs @ 7531

Last change on this file since 7531 was 7531, checked in by sforsten, 12 years ago

#1776:

  • 2 more strategies have been implemented
  • major changes in the inheritance have been made to make it possible to add strategies which don't use a voting strategy with weights
  • ClassificationEnsembleSolutionEstimatedClassValuesView currently doesn't show the confidence (it has been removed for testing purposes)
File size: 4.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
29
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble.
  /// </summary>
  /// <remarks>
  /// Subclasses supply raw weights through <see cref="CalculateWeights"/>. This class clamps and
  /// normalizes them in <see cref="CalculateNormalizedWeights"/> and uses the result for a weighted
  /// vote over the class values estimated by the individual models.
  /// </remarks>
  [StorableClass]
  public abstract class ClassificationWeightCalculator : NamedItem, IClassificationEnsembleSolutionWeightCalculator {
    [StorableConstructor]
    protected ClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected ClassificationWeightCalculator(ClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
      // Copy the weights so that a clone can aggregate without recalculating them first.
      weights = original.weights != null ? (double[])original.weights.Clone() : new double[0];
    }
    public ClassificationWeightCalculator()
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
    }

    // Normalized weights, one per classification solution, in the order returned by CalculateWeights.
    // Empty until CalculateNormalizedWeights has been called.
    // NOTE(review): the field is not marked [Storable]; presumably the weights have to be
    // recalculated after loading a persisted instance - confirm against the callers.
    private double[] weights = new double[0];

    /// <summary>
    /// Calls <see cref="CalculateWeights"/>, replaces negative weights by zero and normalizes the
    /// result so that the weights sum up to one. The normalized weights are stored in this instance
    /// and used by subsequent calls of the aggregation methods.
    /// </summary>
    /// <remarks>
    /// If the sum of the clamped weights is not a positive finite number (e.g. all weights are zero),
    /// every solution gets the same weight instead of dividing by that sum, which would produce NaN.
    /// If <paramref name="classificationSolutions"/> is empty, the stored weights are empty.
    /// </remarks>
    /// <param name="classificationSolutions">The solutions of the ensemble for which weights are calculated.</param>
    public void CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
      List<double> clamped = new List<double>();
      if (classificationSolutions.Count > 0) {
        foreach (var weight in CalculateWeights(classificationSolutions)) {
          // NaN fails the comparison and is therefore replaced by zero as well.
          clamped.Add(weight >= 0 ? weight : 0);
        }
      }

      // Compute the sum once and store concrete values; a deferred query would recompute the sum
      // for every element on every enumeration.
      double sum = clamped.Sum();
      double[] normalized = new double[clamped.Count];
      if (sum > 0 && !double.IsInfinity(sum)) {
        for (int i = 0; i < normalized.Length; i++)
          normalized[i] = clamped[i] / sum;
      } else {
        for (int i = 0; i < normalized.Length; i++)
          normalized[i] = 1.0 / normalized.Length;
      }
      this.weights = normalized;
    }

    /// <summary>
    /// Calculates the raw (not yet clamped or normalized) weights for the given solutions.
    /// </summary>
    /// <param name="classificationSolutions">The solutions of the ensemble; only called when the collection is not empty.</param>
    /// <returns>One weight per solution.</returns>
    protected abstract IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions);

    /// <summary>
    /// Aggregates the class values estimated by all <paramref name="models"/> into one class value
    /// per row by a weighted vote. The result is evaluated lazily.
    /// </summary>
    /// <param name="models">The models of the ensemble; their number must match the number of stored weights.</param>
    /// <param name="dataset">The dataset passed to each model.</param>
    /// <param name="rows">The rows of the dataset passed to each model.</param>
    /// <returns>The aggregated class value for each row.</returns>
    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
      return from xs in ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows)
             select AggregateEstimatedClassValues(xs);
    }

    /// <summary>
    /// Performs a weighted vote for a single row: sums up the stored weights per distinct class value
    /// and returns the class value with the largest weight sum.
    /// </summary>
    /// <param name="estimatedClassValues">The class values estimated by the models, in the same order as the stored weights.</param>
    /// <returns>
    /// The class value with the largest weight sum, or <see cref="double.NaN"/> if there are no
    /// estimated class values. On a tie the first matching entry of the dictionary is returned.
    /// </returns>
    /// <exception cref="ArgumentException">The number of estimated class values differs from the number of stored weights.</exception>
    protected double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
      // Materialize once; the argument may be a lazy sequence that is expensive to index repeatedly.
      double[] classValues = estimatedClassValues.ToArray();
      double[] currentWeights = weights ?? new double[0];
      if (classValues.Length != currentWeights.Length)
        throw new ArgumentException("'estimatedClassValues' has " + classValues.Length + " elements, while 'weights' has " + currentWeights.Length);
      if (classValues.Length == 0)
        return double.NaN;

      var weightSum = new Dictionary<double, double>();
      for (int i = 0; i < classValues.Length; i++) {
        double accumulated;
        // TryGetValue leaves 'accumulated' at 0.0 for a class value that has not been seen yet.
        weightSum.TryGetValue(classValues[i], out accumulated);
        weightSum[classValues[i]] = accumulated + currentWeights[i];
      }

      double maxWeight = weightSum.Values.Max();
      return weightSum.First(x => x.Value.Equals(maxWeight)).Key;
    }

    /// <summary>
    /// Evaluates all <paramref name="models"/> on the given rows and returns the estimates row by row.
    /// </summary>
    /// <param name="models">The models to evaluate.</param>
    /// <param name="dataset">The dataset passed to each model.</param>
    /// <param name="rows">The rows of the dataset passed to each model.</param>
    /// <returns>
    /// For each row a snapshot of the estimated class values of all models, in the order of
    /// <paramref name="models"/>. The sequence ends as soon as one model has no more estimates and is
    /// empty if there are no models.
    /// </returns>
    protected static IEnumerable<IEnumerable<double>> GetEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
      if (models == null) throw new ArgumentNullException("models");
      var estimatedValuesEnumerators = new List<IEnumerator<double>>();
      try {
        foreach (var model in models)
          estimatedValuesEnumerators.Add(model.GetEstimatedClassValues(dataset, rows).GetEnumerator());

        // Without this check the loop below would never terminate, as All() is true for an empty list.
        if (estimatedValuesEnumerators.Count == 0) yield break;

        while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Copy the current values; a lazy view would change with the next MoveNext().
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      }
      finally {
        // Also runs if the consumer stops enumerating early or a model throws.
        foreach (var enumerator in estimatedValuesEnumerators)
          enumerator.Dispose();
      }
    }

    #region Helper
    /// <summary>
    /// Lazily selects the values of <paramref name="targetValues"/> at the given indices.
    /// </summary>
    /// <param name="targetValues">The list the values are taken from.</param>
    /// <param name="indizes">The indices into <paramref name="targetValues"/>.</param>
    /// <returns>The selected values in the order of <paramref name="indizes"/>.</returns>
    protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
      return from i in indizes
             select targetValues[i];
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.