Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs @ 8811

Last change on this file since 8811 was 8811, checked in by sforsten, 11 years ago

#1776:

File size: 7.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble.
  /// Derived classes supply raw weights via <see cref="CalculateWeights"/>; this class clamps them
  /// to be non-negative, normalizes them and uses them for weighted voting and confidence values.
  /// </summary>
  /// <remarks>
  /// The normalized weights are held in a field that is neither [Storable] nor copied by the cloning
  /// constructor, so <see cref="CalculateNormalizedWeights"/> has to be called (again) before any of the
  /// aggregation or confidence methods are used on a new, cloned or deserialized instance.
  /// </remarks>
  [StorableClass]
  public abstract class ClassificationWeightCalculator : NamedItem, IClassificationEnsembleSolutionWeightCalculator {
    [StorableConstructor]
    protected ClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected ClassificationWeightCalculator(ClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    // protected: the type is abstract, so only derived classes can ever invoke this constructor
    protected ClassificationWeightCalculator()
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
    }

    // normalized weight per ensemble member; null until CalculateNormalizedWeights has been called
    private IDictionary<IClassificationSolution, double> weights;

    /// <summary>
    /// Looks up the normalized weight of <paramref name="solution"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">if the weights have not been calculated yet</exception>
    /// <exception cref="KeyNotFoundException">if no weight was calculated for <paramref name="solution"/></exception>
    private double GetWeight(IClassificationSolution solution) {
      if (weights == null)
        throw new InvalidOperationException("The weights have not been calculated yet. Call CalculateNormalizedWeights first.");
      return weights[solution];
    }

    /// <summary>
    /// Calls <see cref="CalculateWeights"/>, replaces negative weights by zero and normalizes the
    /// result so that the weights sum up to one. The normalized weights are stored and used by the
    /// aggregation and confidence methods of this class.
    /// </summary>
    /// <remarks>
    /// If no solution receives a positive weight, every solution gets the same weight
    /// (1 / number of solutions) instead of the NaN values a division by a zero sum would produce.
    /// </remarks>
    /// <param name="classificationSolutions">the ensemble members to weight; the sequence is enumerated once</param>
    public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      // materialize once: the solutions are needed for the weight calculation and again for the mapping
      List<IClassificationSolution> solutions = classificationSolutions.ToList();
      List<double> clampedWeights = new List<double>();
      if (solutions.Count > 0) {
        foreach (var weight in CalculateWeights(solutions)) {
          // clamp negative (and NaN) weights to zero
          clampedWeights.Add(weight >= 0 ? weight : 0);
        }
      }
      double sum = clampedWeights.Sum();
      // Zip pairs only as many entries as the shorter of both sequences contains
      int count = Math.Min(solutions.Count, clampedWeights.Count);
      this.weights = solutions.Zip(clampedWeights, (sol, wei) => new { sol, wei })
                              .ToDictionary(x => x.sol, x => sum > 0 ? x.wei / sum : 1.0 / count);
    }

    /// <summary>
    /// Calculates one raw weight per solution, in the order of <paramref name="classificationSolutions"/>.
    /// Negative values are clamped to zero by <see cref="CalculateNormalizedWeights"/>.
    /// </summary>
    protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);

    #region delegate CheckPoint
    /// <summary>Returns a <see cref="CheckPoint"/> that admits only rows inside the test partition.</summary>
    public CheckPoint GetTestClassDelegate() {
      return PointInTest;
    }
    /// <summary>Returns a <see cref="CheckPoint"/> that admits only rows inside the training partition but outside the test partition.</summary>
    public CheckPoint GetTrainingClassDelegate() {
      return PointInTraining;
    }
    /// <summary>Returns a <see cref="CheckPoint"/> that admits every row.</summary>
    public CheckPoint GetAllClassDelegate() {
      return AllPoints;
    }
    #endregion

    /// <summary>
    /// Computes the weighted-vote ensemble estimate for each evaluated row.
    /// </summary>
    /// <param name="solutions">the ensemble members that vote</param>
    /// <param name="dataset">the dataset the members' models are evaluated on</param>
    /// <param name="rows">the rows to estimate</param>
    /// <param name="handler">decides per member and row whether the member's estimate is counted</param>
    /// <returns>one aggregated class value per evaluated row; double.NaN for rows where no member voted</returns>
    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      return GetEstimatedClassValues(solutions, dataset, rows, handler)
        .Select(votes => AggregateEstimatedClassValues(votes));
    }

    /// <summary>
    /// Sums the normalized weights of all solutions per estimated class value and returns the class
    /// value with the largest weight sum.
    /// </summary>
    /// <param name="estimatedClassValues">estimated class value per voting solution</param>
    /// <returns>the winning class value, or double.NaN if <paramref name="estimatedClassValues"/> is empty</returns>
    protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
      // accumulated weight per estimated class value
      var weightSum = new Dictionary<double, double>();
      foreach (var vote in estimatedClassValues) {
        double current;
        weightSum.TryGetValue(vote.Value, out current); // current is 0.0 for a class value not seen yet
        weightSum[vote.Value] = current + GetWeight(vote.Key);
      }
      if (weightSum.Count <= 0)
        return double.NaN;
      double maxWeight = weightSum.Values.Max();
      // on ties, the first entry with the maximal weight sum in enumeration order wins
      return weightSum.First(x => x.Value.Equals(maxWeight)).Key;
    }

    /// <summary>
    /// Lazily yields, for each row, the estimated class value of every solution whose vote is admitted
    /// by <paramref name="handler"/> for that row. Enumeration stops as soon as the rows or any model's
    /// estimates are exhausted.
    /// </summary>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
                                       .ToList();

      var rowEnumerator = rows.GetEnumerator();
      try {
        // non-short-circuit '&' so that the row enumerator and the estimate enumerators are always advanced together
        while (rowEnumerator.MoveNext() & estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
          int row = rowEnumerator.Current;
          yield return estimatedValuesEnumerators
            .Where(x => handler(x.Solution.ProblemData, row))
            .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
        }
      } finally {
        // the enumerators are IDisposable; release them even if the caller stops enumerating early
        rowEnumerator.Dispose();
        foreach (var entry in estimatedValuesEnumerators)
          entry.EstimatedValuesEnumerator.Dispose();
      }
    }

    /// <summary>
    /// Sums the normalized weights of all solutions whose estimate for row <paramref name="index"/>
    /// is (almost) equal to <paramref name="estimatedClassValue"/>.
    /// </summary>
    /// <returns>the confidence, or double.NaN if <paramref name="solutions"/> is empty</returns>
    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
      List<IClassificationSolution> solutionList = solutions.ToList();
      if (solutionList.Count < 1)
        return double.NaN;
      // the dataset of the first solution is used for all solutions
      Dataset dataset = solutionList[0].ProblemData.Dataset;
      IEnumerable<int> row = Enumerable.Repeat(index, 1);
      // IsAlmost, consistent with the multi-row overload
      return solutionList
        .Where(s => DoubleExtensions.IsAlmost(s.Model.GetEstimatedClassValues(dataset, row).First(), estimatedClassValue))
        .Sum(s => GetWeight(s));
    }

    /// <summary>
    /// Multi-row variant of <see cref="GetConfidence(IEnumerable{IClassificationSolution}, int, double)"/>:
    /// for each position i, sums the normalized weights of all solutions whose estimate for the i-th
    /// index is (almost) equal to the i-th estimated class value.
    /// </summary>
    /// <returns>one confidence per index; all double.NaN if <paramref name="solutions"/> is empty</returns>
    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
      // materialize once: a lazy sequence must not be recounted on every loop iteration or re-enumerated per model
      int[] rows = indices.ToArray();
      List<IClassificationSolution> solutionList = solutions.ToList();
      if (solutionList.Count < 1)
        return Enumerable.Repeat(double.NaN, rows.Length);

      // the dataset of the first solution is used for all solutions
      Dataset dataset = solutionList[0].ProblemData.Dataset;
      Dictionary<IClassificationSolution, double[]> solValues = solutionList.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, rows).ToArray());
      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
      double[] confidences = new double[rows.Length];

      for (int i = 0; i < rows.Length; i++) {
        var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i]));
        confidences[i] = correctSolutions.Sum(x => GetWeight(x.Key));
      }

      return confidences;
    }

    #region Helper
    /// <summary>Lazily selects the entries of <paramref name="targetValues"/> at the given indices.</summary>
    protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
      return from i in indizes
             select targetValues[i];
    }
    /// <summary>
    /// True if <paramref name="point"/> lies in the training partition [Start, End) and not in the test
    /// partition [Start, End); the check excludes the overlap if the partitions overlap.
    /// </summary>
    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
      IntRange trainingPartition = problemData.TrainingPartition;
      IntRange testPartition = problemData.TestPartition;
      return (trainingPartition.Start <= point && point < trainingPartition.End)
        && !(testPartition.Start <= point && point < testPartition.End);
    }
    /// <summary>True if <paramref name="point"/> lies in the test partition [Start, End).</summary>
    protected bool PointInTest(IClassificationProblemData problemData, int point) {
      IntRange testPartition = problemData.TestPartition;
      return testPartition.Start <= point && point < testPartition.End;
    }
    /// <summary>Admits every point.</summary>
    protected bool AllPoints(IClassificationProblemData problemData, int point) {
      return true;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.