source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs @ 7549

Last change on this file since 7549 was 7549, checked in by sforsten, 8 years ago

#1776:

  • models can be selected with a check box
  • all strategies are now finished
  • major changes have been made to provide the same behaviour when getting the estimated training or test values of an ensemble
File size: 5.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
29
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble.
  /// Derived classes supply raw per-solution weights via <see cref="CalculateWeights"/>; this class
  /// clamps and normalizes them and uses them for a weighted vote over the class values estimated
  /// by the individual solutions.
  /// </summary>
  /// <remarks>
  /// The normalized weights are runtime state only: the backing field is not marked [Storable] and is
  /// not copied by the cloning constructor, so <see cref="CalculateNormalizedWeights"/> has to be called
  /// before aggregating on a deserialized or cloned instance.
  /// </remarks>
  [StorableClass]
  public abstract class ClassificationWeightCalculator : NamedItem, IClassificationEnsembleSolutionWeightCalculator {
    [StorableConstructor]
    protected ClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected ClassificationWeightCalculator(ClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
      // NOTE(review): the calculated weights are not copied here; they are keyed by the original
      // (uncloned) solutions, so the clone has to recalculate them.
    }
    /// <summary>
    /// Initializes the name and description of the calculator from ItemName and ItemDescription.
    /// Protected because the class is abstract and can only be instantiated through a derived class.
    /// </summary>
    protected ClassificationWeightCalculator()
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
    }

    /// <summary>
    /// Normalized weight of each solution; null until <see cref="CalculateNormalizedWeights"/> has been called.
    /// </summary>
    private IDictionary<IClassificationSolution, double> weights;

    /// <summary>
    /// Calls <see cref="CalculateWeights"/>, replaces negative (and NaN) weights by zero and normalizes
    /// the weights so that they sum up to one. The result is stored and used by subsequent calls to
    /// <c>AggregateEstimatedClassValues</c>.
    /// If no solution receives a positive weight, every solution is weighted equally (plain majority
    /// voting) instead of dividing by zero.
    /// </summary>
    /// <param name="classificationSolutions">The solutions of the ensemble; enumerated only once.</param>
    public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      // materialize once so that the raw weights and the dictionary keys come from the same sequence
      List<IClassificationSolution> solutions = classificationSolutions.ToList();
      var clampedWeights = new List<double>();
      if (solutions.Count > 0) {
        // "w >= 0" is false for NaN, so NaN weights are clamped to zero as well
        clampedWeights.AddRange(CalculateWeights(solutions).Select(w => w >= 0 ? w : 0.0));
      }
      var pairs = solutions
        .Zip(clampedWeights, (solution, weight) => new { Solution = solution, Weight = weight })
        .ToList();
      double sum = pairs.Sum(p => p.Weight);
      // fallback for sum == 0: equal weights instead of 0/0 = NaN for every solution
      double uniformWeight = pairs.Count > 0 ? 1.0 / pairs.Count : 0.0;
      this.weights = pairs.ToDictionary(p => p.Solution, p => sum > 0 ? p.Weight / sum : uniformWeight);
    }

    /// <summary>
    /// Calculates the raw (unnormalized) weight of each solution, in the order of
    /// <paramref name="classificationSolutions"/>. Negative values are treated as zero by
    /// <see cref="CalculateNormalizedWeights"/>.
    /// </summary>
    protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);

    #region delegate CheckPoint
    /// <summary>Returns a delegate that accepts rows inside the test partition.</summary>
    public CheckPoint GetTestClassDelegate() {
      return PointInTest;
    }
    /// <summary>Returns a delegate that accepts rows inside the training but outside the test partition.</summary>
    public CheckPoint GetTrainingClassDelegate() {
      return PointInTraining;
    }
    /// <summary>Returns a delegate that accepts every row.</summary>
    public CheckPoint GetAllClassDelegate() {
      return AllPoints;
    }
    #endregion

    /// <summary>
    /// Aggregates, row by row, the class values estimated by <paramref name="solutions"/> into a single
    /// class value using the weights calculated by <see cref="CalculateNormalizedWeights"/>. Only solutions
    /// whose problem data is accepted by <paramref name="handler"/> for a row take part in that row's vote.
    /// </summary>
    /// <returns>
    /// One aggregated class value per row (NaN for rows in which no solution takes part); an empty
    /// sequence if <paramref name="solutions"/> is empty.
    /// </returns>
    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      return GetEstimatedClassValues(solutions, dataset, rows, handler)
        .Select(estimations => AggregateEstimatedClassValues(estimations));
    }

    /// <summary>
    /// Weighted vote: sums the normalized weight of each solution per estimated class value and returns
    /// the class value with the largest sum. Ties are resolved in favour of the first entry in the
    /// enumeration order of the intermediate dictionary.
    /// </summary>
    /// <param name="estimatedClassValues">Estimated class value of each participating solution.</param>
    /// <returns>The winning class value, or NaN if <paramref name="estimatedClassValues"/> is empty.</returns>
    /// <exception cref="System.InvalidOperationException">
    /// If <see cref="CalculateNormalizedWeights"/> has not been called yet.
    /// </exception>
    protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
      if (estimatedClassValues.Count == 0)
        return double.NaN;
      if (weights == null)
        throw new System.InvalidOperationException("CalculateNormalizedWeights has to be called before estimated class values can be aggregated.");

      // accumulated weight per estimated class value
      var weightSum = new Dictionary<double, double>();
      foreach (var item in estimatedClassValues) {
        double accumulated;
        weightSum.TryGetValue(item.Value, out accumulated); // 0.0 if the class value is new
        weightSum[item.Value] = accumulated + weights[item.Key];
      }

      // strict ">" keeps the earlier entry on ties
      return weightSum.Aggregate((best, next) => next.Value > best.Value ? next : best).Key;
    }

    /// <summary>
    /// For each row, maps every solution whose problem data is accepted by <paramref name="handler"/>
    /// for that row to the class value its model estimates for the row. Enumeration stops as soon as
    /// the rows or any model's estimations are exhausted.
    /// </summary>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      List<IClassificationSolution> solutionList = solutions.ToList();
      if (solutionList.Count == 0) yield break;
      // materialize the rows once so that every model and the row loop see the same sequence
      List<int> rowList = rows.ToList();
      var estimatedValuesEnumerators = solutionList
        .Select(solution => new {
          Solution = solution,
          EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rowList).GetEnumerator()
        })
        .ToList();

      try {
        foreach (int row in rowList) {
          // advance all models in lock step with the rows; stop when any of them runs out
          if (!estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext()))
            yield break;
          yield return estimatedValuesEnumerators
            .Where(x => handler(x.Solution.ProblemData, row))
            .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
        }
      } finally {
        // the enumerators were created here, so they are released here (also on early termination)
        foreach (var x in estimatedValuesEnumerators)
          x.EstimatedValuesEnumerator.Dispose();
      }
    }

    #region Helper
    /// <summary>Returns the entries of <paramref name="targetValues"/> at the given indices, in index order.</summary>
    protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
      return from i in indizes
             select targetValues[i];
    }
    /// <summary>
    /// True if <paramref name="point"/> lies in the half-open training partition [Start, End)
    /// and not in the half-open test partition.
    /// </summary>
    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
      IntRange trainingPartition = problemData.TrainingPartition;
      IntRange testPartition = problemData.TestPartition;
      return (trainingPartition.Start <= point && point < trainingPartition.End)
        && !(testPartition.Start <= point && point < testPartition.End);
    }
    /// <summary>True if <paramref name="point"/> lies in the half-open test partition [Start, End).</summary>
    protected bool PointInTest(IClassificationProblemData problemData, int point) {
      IntRange testPartition = problemData.TestPartition;
      return testPartition.Start <= point && point < testPartition.End;
    }
    /// <summary>Accepts every point.</summary>
    protected bool AllPoints(IClassificationProblemData problemData, int point) {
      return true;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.