Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs @ 8297

Last change on this file since 8297 was 8297, checked in by sforsten, 12 years ago

#1776:

  • Corrected namespace of IClassificationEnsembleSolutionWeightCalculator interface
  • Corrected calculation of confidence for test and training samples in ClassificationEnsembleSolutionEstimatedClassValuesView
  • Added an overload of GetConfidence to IClassificationEnsembleSolutionWeightCalculator that calculates the confidence for more than one point at a time (additional methods for training and test confidence might improve performance considerably)
  • Added ClassificationEnsembleSolutionConfidenceAccuracyDependence to show how the accuracy behaves if only samples with high confidence are classified
File size: 7.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis.Interfaces;
29
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble.
  /// Derived classes provide raw per-solution weights via <see cref="CalculateWeights"/>;
  /// this class clamps and normalizes them and uses the result for weighted voting
  /// (AggregateEstimatedClassValues) and for confidence values (GetConfidence).
  /// </summary>
  [StorableClass]
  public abstract class ClassificationWeightCalculator : NamedItem, IClassificationEnsembleSolutionWeightCalculator {
    [StorableConstructor]
    protected ClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected ClassificationWeightCalculator(ClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    public ClassificationWeightCalculator()
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
    }

    // Normalized (non-negative, summing to one) weight of every ensemble member.
    // Set by CalculateNormalizedWeights and read by the voting/confidence methods.
    // NOTE(review): this field is neither [Storable] nor copied by the cloning constructor,
    // so it is null after cloning/deserialization until CalculateNormalizedWeights is called again.
    private IDictionary<IClassificationSolution, double> weights;

    /// <summary>
    /// Calls <see cref="CalculateWeights"/>, clamps negative weights to zero and normalizes the
    /// weights so that they sum up to one. If no solution has a positive weight, every solution
    /// receives the same weight (plain majority voting) instead of a NaN weight.
    /// The result is stored and used by the voting and confidence methods.
    /// </summary>
    /// <param name="classificationSolutions">the ensemble members to weight</param>
    public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      // materialize once: the sequence is needed for the weight calculation and for the dictionary
      var solutions = classificationSolutions.ToList();
      var clampedWeights = new List<double>();
      if (solutions.Count > 0) {
        foreach (var weight in CalculateWeights(solutions)) {
          clampedWeights.Add(weight >= 0 ? weight : 0);
        }
      }

      var pairs = solutions.Zip(clampedWeights, (sol, wei) => new { sol, wei }).ToList();
      double sum = pairs.Sum(p => p.wei);
      if (sum > 0) {
        this.weights = pairs.ToDictionary(p => p.sol, p => p.wei / sum);
      } else {
        // all weights are zero (or there are no solutions): dividing by the sum would yield NaN
        // for every solution, so fall back to equal weights
        this.weights = pairs.ToDictionary(p => p.sol, p => 1.0 / pairs.Count);
      }
    }

    /// <summary>
    /// Computes the raw (not yet normalized) weight of every solution, in the same order as
    /// <paramref name="classificationSolutions"/>. Negative values are clamped to zero by the caller.
    /// </summary>
    protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);

    #region delegate CheckPoint
    /// <summary>Returns a predicate that accepts only rows inside the test partition.</summary>
    public CheckPoint GetTestClassDelegate() {
      return PointInTest;
    }
    /// <summary>Returns a predicate that accepts only rows inside the training (and outside the test) partition.</summary>
    public CheckPoint GetTrainingClassDelegate() {
      return PointInTraining;
    }
    /// <summary>Returns a predicate that accepts every row.</summary>
    public CheckPoint GetAllClassDelegate() {
      return AllPoints;
    }
    #endregion

    /// <summary>
    /// Lazily computes the weighted-vote class value for every row in <paramref name="rows"/>.
    /// Only solutions for which <paramref name="handler"/> accepts the row take part in the vote;
    /// a row that no solution votes for yields NaN.
    /// </summary>
    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      return from xs in GetEstimatedClassValues(solutions, dataset, rows, handler)
             select AggregateEstimatedClassValues(xs);
    }

    /// <summary>
    /// Weighted vote over the class values estimated by the individual solutions for a single row.
    /// </summary>
    /// <param name="estimatedClassValues">estimated class value per voting solution</param>
    /// <returns>the class value with the highest summed weight (the first one found on ties), or NaN if nobody voted</returns>
    protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
      var weightSum = new Dictionary<double, double>();
      foreach (var item in estimatedClassValues) {
        double currentSum;
        weightSum.TryGetValue(item.Value, out currentSum); // currentSum is 0.0 if the class is new
        weightSum[item.Value] = currentSum + weights[item.Key];
      }
      if (weightSum.Count == 0)
        return double.NaN;
      double maxWeight = weightSum.Values.Max();
      return weightSum.First(x => x.Value.Equals(maxWeight)).Key;
    }

    /// <summary>
    /// Lazily walks the rows and all solutions' estimated class values in lock-step and yields, per
    /// row, the estimated class value of every solution whose problem data is accepted by
    /// <paramref name="handler"/> for that row. Stops as soon as the rows or any solution's
    /// estimates run out.
    /// </summary>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
                                       .ToList();

      var rowEnumerator = rows.GetEnumerator();
      try {
        // non-short-circuit '&' so the row enumerator and the estimate enumerators advance together
        while (rowEnumerator.MoveNext() & estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
          int row = rowEnumerator.Current;
          yield return estimatedValuesEnumerators
            .Where(x => handler(x.Solution.ProblemData, row))
            .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
        }
      } finally {
        // enumerators are IDisposable; release them even if the consumer stops early
        rowEnumerator.Dispose();
        foreach (var x in estimatedValuesEnumerators) {
          x.EstimatedValuesEnumerator.Dispose();
        }
      }
    }

    /// <summary>
    /// Confidence of <paramref name="estimatedClassValue"/> for a single row: the summed normalized
    /// weight of all solutions whose estimate for row <paramref name="index"/> (almost) equals it.
    /// The dataset of the first solution's problem data is used for every solution.
    /// </summary>
    /// <returns>the confidence, or NaN if <paramref name="solutions"/> is empty</returns>
    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
      var solutionList = solutions.ToList();
      if (solutionList.Count == 0)
        return double.NaN;
      Dataset dataset = solutionList[0].ProblemData.Dataset;
      var row = new[] { index };
      // IsAlmost keeps the comparison consistent with the multi-row overload
      var correctSolutions = solutionList
        .Where(s => DoubleExtensions.IsAlmost(s.Model.GetEstimatedClassValues(dataset, row).First(), estimatedClassValue));
      return correctSolutions.Sum(s => weights[s]);
    }

    /// <summary>
    /// Multi-row version of GetConfidence: for every row in <paramref name="indices"/> returns the
    /// summed normalized weight of all solutions whose estimate (almost) equals the corresponding
    /// entry of <paramref name="estimatedClassValue"/>, which must have at least as many entries as
    /// <paramref name="indices"/>.
    /// </summary>
    /// <returns>one confidence per row, or NaN per row if <paramref name="solutions"/> is empty</returns>
    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
      // materialize once: indices would otherwise be re-enumerated per solution and re-counted per loop iteration
      int[] rows = indices.ToArray();
      var solutionList = solutions.ToList();
      if (solutionList.Count == 0)
        return Enumerable.Repeat(double.NaN, rows.Length);

      Dataset dataset = solutionList[0].ProblemData.Dataset;
      Dictionary<IClassificationSolution, double[]> solValues = solutionList.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, rows).ToArray());
      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
      double[] confidences = new double[rows.Length];

      for (int i = 0; i < rows.Length; i++) {
        double confidence = 0.0;
        foreach (var sol in solValues) {
          if (DoubleExtensions.IsAlmost(sol.Value[i], estimatedClassValueArr[i]))
            confidence += weights[sol.Key];
        }
        confidences[i] = confidence;
      }

      return confidences;
    }

    #region Helper
    /// <summary>Returns the entries of <paramref name="targetValues"/> at the given positions (lazily).</summary>
    protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
      return from i in indizes
             select targetValues[i];
    }
    /// <summary>True if the row lies in the training partition but not in the test partition (overlapping rows count as test).</summary>
    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
      IntRange trainingPartition = problemData.TrainingPartition;
      IntRange testPartition = problemData.TestPartition;
      return (trainingPartition.Start <= point && point < trainingPartition.End)
        && !(testPartition.Start <= point && point < testPartition.End);
    }
    /// <summary>True if the row lies in the half-open test partition [Start, End).</summary>
    protected bool PointInTest(IClassificationProblemData problemData, int point) {
      IntRange testPartition = problemData.TestPartition;
      return testPartition.Start <= point && point < testPartition.End;
    }
    /// <summary>Accepts every row.</summary>
    protected bool AllPoints(IClassificationProblemData problemData, int point) {
      return true;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.