Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs @ 8297

Last change on this file since 8297 was 8297, checked in by sforsten, 12 years ago

#1776:

  • Corrected namespace of IClassificationEnsembleSolutionWeightCalculator interface
  • Corrected calculation of confidence for test and training samples in ClassificationEnsembleSolutionEstimatedClassValuesView
  • Added a GetConfidence overload to IClassificationEnsembleSolutionWeightCalculator that calculates the confidence for more than one point at a time (additional dedicated methods for training and test confidence might improve performance considerably)
  • Added ClassificationEnsembleSolutionConfidenceAccuracyDependence to show how the accuracy would behave if only samples with high confidence were classified
File size: 13.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Collections;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis.Interfaces;
31
32namespace HeuristicLab.Problems.DataAnalysis {
33  /// <summary>
34  /// Represents classification solutions that contain an ensemble of multiple classification models
35  /// </summary>
36  [StorableClass]
37  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
38  [Creatable("Data Analysis - Ensembles")]
39  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
/// <summary>
/// Gets the base solution's model, cast to <see cref="IClassificationEnsembleModel"/>.
/// </summary>
public new IClassificationEnsembleModel Model {
  get {
    var ensembleModel = (IClassificationEnsembleModel)base.Model;
    return ensembleModel;
  }
}
/// <summary>
/// Gets or sets the base solution's problem data, typed as
/// <see cref="ClassificationEnsembleProblemData"/>.
/// </summary>
public new ClassificationEnsembleProblemData ProblemData {
  get {
    var ensembleProblemData = (ClassificationEnsembleProblemData)base.ProblemData;
    return ensembleProblemData;
  }
  set {
    base.ProblemData = value;
  }
}
47
// Backing collection of the member solutions; assigned once by every constructor.
private readonly CheckedItemCollection<IClassificationSolution> classificationSolutions;

/// <summary>
/// Gets the collection of member solutions through its checked-item interface.
/// </summary>
public ICheckedItemCollection<IClassificationSolution> ClassificationSolutions {
  get {
    return this.classificationSolutions;
  }
}
52
// Strategy object used to weight and aggregate the member solutions.
private IClassificationEnsembleSolutionWeightCalculator weightCalculator;

/// <summary>
/// Gets or sets the weight calculator used to aggregate the checked solutions.
/// Assigning <c>null</c> is ignored. A non-null assignment triggers
/// <see cref="RecalculateResults"/> unless the problem data is empty.
/// </summary>
public IClassificationEnsembleSolutionWeightCalculator WeightCalculator {
  get {
    return weightCalculator;
  }
  set {
    // A null calculator is silently rejected so the current one stays in place.
    if (value == null) return;

    weightCalculator = value;

    // Nothing to recalculate while there is no data.
    if (ProblemData.IsEmpty) return;
    RecalculateResults();
  }
}
65
// Training partition of each member model, keyed by the model instance (persisted).
[Storable]
private Dictionary<IClassificationModel, IntRange> trainingPartitions;

// Test partition of each member model, keyed by the model instance (persisted).
[Storable]
private Dictionary<IClassificationModel, IntRange> testPartitions;

/// <summary>
/// Constructor used by the persistence framework. Only creates the (empty) solution
/// collection; the collection is filled in the after-deserialization hook.
/// </summary>
[StorableConstructor]
private ClassificationEnsembleSolution(bool deserializing)
  : base(deserializing) {
  this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
}
/// <summary>
/// Rebuilds the non-persisted state after deserialization: one classification solution
/// is recreated for every model of the ensemble model, using a clone of the ensemble's
/// problem data with the model's stored training and test partitions applied.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  // The weight calculator field is not marked [Storable] and the storable constructor
  // does not assign it, so fall back to the same default the public constructors use.
  // Otherwise every evaluation member would dereference null after loading.
  if (weightCalculator == null) {
    weightCalculator = new MajorityVoteWeightCalculator();
  }

  foreach (var model in Model.Models) {
    IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
    problemData.TrainingPartition.Start = trainingPartitions[model].Start;
    problemData.TrainingPartition.End = trainingPartitions[model].End;
    problemData.TestPartition.Start = testPartitions[model].Start;
    problemData.TestPartition.End = testPartitions[model].End;

    classificationSolutions.Add(model.CreateClassificationSolution(problemData));
  }

  // Handlers are attached only after the collection is filled, so the additions above
  // do not run the ItemsAdded handler of this class.
  RegisterClassificationSolutionsEventHandler();
}
89
/// <summary>
/// Cloning constructor: copies the partition dictionaries entry by entry through the
/// cloner, then clones the weight calculator and the solution collection and attaches
/// the collection event handlers to the cloned collection.
/// </summary>
private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
  : base(original, cloner) {
  this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
  this.testPartitions = new Dictionary<IClassificationModel, IntRange>();

  foreach (var entry in original.trainingPartitions) {
    var clonedModel = cloner.Clone(entry.Key);
    var clonedRange = cloner.Clone(entry.Value);
    this.trainingPartitions[clonedModel] = clonedRange;
  }

  foreach (var entry in original.testPartitions) {
    var clonedModel = cloner.Clone(entry.Key);
    var clonedRange = cloner.Clone(entry.Value);
    this.testPartitions[clonedModel] = clonedRange;
  }

  this.weightCalculator = cloner.Clone(original.weightCalculator);
  this.classificationSolutions = cloner.Clone(original.classificationSolutions);
  RegisterClassificationSolutionsEventHandler();
}
105
/// <summary>
/// Creates an empty ensemble solution: a new empty ensemble model, the empty problem
/// data, no member solutions, and a <see cref="MajorityVoteWeightCalculator"/>.
/// </summary>
public ClassificationEnsembleSolution()
  : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
  this.weightCalculator = new MajorityVoteWeightCalculator();
  this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
  this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
  this.testPartitions = new Dictionary<IClassificationModel, IntRange>();

  RegisterClassificationSolutionsEventHandler();
}
115
/// <summary>
/// Creates an ensemble solution in which every model receives its own clone of the
/// training and test partition of <paramref name="problemData"/>.
/// </summary>
/// <param name="models">The models that form the ensemble.</param>
/// <param name="problemData">The problem data shared by all models.</param>
public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
  : this(
      models,
      problemData,
      models.Select(model => (IntRange)problemData.TrainingPartition.Clone()),
      models.Select(model => (IntRange)problemData.TestPartition.Clone())) {
}
121
/// <summary>
/// Creates an ensemble solution from models with individual training and test partitions.
/// The three sequences are consumed in lock-step: the i-th model is combined with the
/// i-th training partition and the i-th test partition.
/// </summary>
/// <param name="models">The models that form the ensemble.</param>
/// <param name="problemData">The problem data that is cloned for each model.</param>
/// <param name="trainingPartitions">One training partition per model.</param>
/// <param name="testPartitions">One test partition per model.</param>
/// <exception cref="ArgumentException">
/// Thrown when the three sequences do not contain the same number of elements.
/// </exception>
public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
  : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
  this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
  this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
  this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>();

  List<IClassificationSolution> solutions = new List<IClassificationSolution>();

  // IEnumerator<T> is IDisposable; the using blocks release the enumerators even
  // when the length check below throws.
  using (var modelEnumerator = models.GetEnumerator())
  using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
  using (var testPartitionEnumerator = testPartitions.GetEnumerator()) {
    while (true) {
      // Advance all three enumerators on every pass and remember each result.
      // Calling MoveNext again after the loop (as the previous check did) skips the
      // element already consumed here, so a sequence that is exactly one element
      // longer than the others went undetected.
      bool hasModel = modelEnumerator.MoveNext();
      bool hasTrainingPartition = trainingPartitionEnumerator.MoveNext();
      bool hasTestPartition = testPartitionEnumerator.MoveNext();

      if (!(hasModel && hasTrainingPartition && hasTestPartition)) {
        // At least one sequence ended; if any other still has elements the lengths differ.
        if (hasModel || hasTrainingPartition || hasTestPartition) {
          throw new ArgumentException("The number of models, training partitions and test partitions must be equal.");
        }
        break;
      }

      var p = (IClassificationProblemData)problemData.Clone();
      p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
      p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
      p.TestPartition.Start = testPartitionEnumerator.Current.Start;
      p.TestPartition.End = testPartitionEnumerator.Current.End;

      solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
    }
  }

  // The handlers and the weight calculator must be in place before AddRange, because
  // the ItemsAdded handler of this class registers the models and recalculates results.
  RegisterClassificationSolutionsEventHandler();
  weightCalculator = new MajorityVoteWeightCalculator();
  classificationSolutions.AddRange(solutions);
}
150
/// <summary>
/// Creates a deep clone of this solution by invoking the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var clone = new ClassificationEnsembleSolution(this, cloner);
  return clone;
}
/// <summary>
/// Subscribes this solution to the add, remove, reset and checked-state events of the
/// solution collection.
/// </summary>
private void RegisterClassificationSolutionsEventHandler() {
  // Method-group conversion creates the same CollectionItemsChangedEventHandler delegates.
  classificationSolutions.ItemsAdded += classificationSolutions_ItemsAdded;
  classificationSolutions.ItemsRemoved += classificationSolutions_ItemsRemoved;
  classificationSolutions.CollectionReset += classificationSolutions_CollectionReset;
  classificationSolutions.CheckedItemsChanged += classificationSolutions_CheckedItemsChanged;
}
160
/// <summary>
/// Lets the weight calculator compute normalized weights for the checked solutions
/// and then recalculates the results.
/// </summary>
protected override void RecalculateResults() {
  var checkedSolutions = classificationSolutions.CheckedItems;
  weightCalculator.CalculateNormalizedWeights(checkedSolutions);
  CalculateResults();
}
165
166    #region Evaluation
/// <summary>
/// Gets the class values the weight calculator aggregates from the checked solutions
/// for the training indices, using the calculator's training class delegate.
/// </summary>
public override IEnumerable<double> EstimatedTrainingClassValues {
  get {
    var checkedSolutions = classificationSolutions.CheckedItems;
    var dataset = ProblemData.Dataset;
    var trainingRows = ProblemData.TrainingIndizes;
    var trainingDelegate = weightCalculator.GetTrainingClassDelegate();
    return weightCalculator.AggregateEstimatedClassValues(checkedSolutions, dataset, trainingRows, trainingDelegate);
  }
}
175
/// <summary>
/// Gets the class values the weight calculator aggregates from the checked solutions
/// for the test indices, using the calculator's test class delegate.
/// </summary>
public override IEnumerable<double> EstimatedTestClassValues {
  get {
    var checkedSolutions = classificationSolutions.CheckedItems;
    var dataset = ProblemData.Dataset;
    var testRows = ProblemData.TestIndizes;
    var testDelegate = weightCalculator.GetTestClassDelegate();
    return weightCalculator.AggregateEstimatedClassValues(checkedSolutions, dataset, testRows, testDelegate);
  }
}
184
/// <summary>
/// Returns the class values the weight calculator aggregates from the checked solutions
/// for the given rows, using the calculator's delegate for all samples.
/// </summary>
/// <param name="rows">The dataset rows to estimate.</param>
public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
  var checkedSolutions = classificationSolutions.CheckedItems;
  var dataset = ProblemData.Dataset;
  var allSamplesDelegate = weightCalculator.GetAllClassDelegate();
  return weightCalculator.AggregateEstimatedClassValues(checkedSolutions, dataset, rows, allSamplesDelegate);
}
191
/// <summary>
/// Returns, for each requested row, the vector of class values estimated by the models
/// of all checked solutions (one entry per model, in the order of the checked solutions).
/// Yields nothing when no solution is checked.
/// </summary>
/// <param name="dataset">The dataset the models are evaluated on.</param>
/// <param name="rows">The rows of <paramref name="dataset"/> to evaluate.</param>
/// <returns>One vector of per-model estimates for every row.</returns>
public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
  IEnumerable<IClassificationModel> models = classificationSolutions.CheckedItems.Select(sol => sol.Model);
  if (!models.Any()) yield break;
  var estimatedValuesEnumerators = (from model in models
                                    select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
                                   .ToList();
  try {
    while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
      // Snapshot the current values. A lazy query over the live enumerators would
      // report later values if the caller reads a row after advancing the iteration.
      yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
    }
  }
  finally {
    // Runs on normal completion and when the caller stops enumerating early.
    foreach (var enumerator in estimatedValuesEnumerators) {
      enumerator.Dispose();
    }
  }
}
204    #endregion
205
/// <summary>
/// Propagates changed problem data to the member solutions. Nested ensemble solutions
/// receive the ensemble problem data itself; all other solutions share one newly created
/// <see cref="ClassificationProblemData"/> with the same dataset, allowed inputs, target
/// variable and partitions. The stored per-model partitions are overwritten with the
/// partitions of the new problem data before the base implementation is invoked.
/// </summary>
protected override void OnProblemDataChanged() {
  IClassificationProblemData sharedData = new ClassificationProblemData(
    ProblemData.Dataset,
    ProblemData.AllowedInputVariables,
    ProblemData.TargetVariable);
  sharedData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
  sharedData.TrainingPartition.End = ProblemData.TrainingPartition.End;
  sharedData.TestPartition.Start = ProblemData.TestPartition.Start;
  sharedData.TestPartition.End = ProblemData.TestPartition.End;

  foreach (var member in ClassificationSolutions) {
    if (!(member is ClassificationEnsembleSolution)) {
      member.ProblemData = sharedData;
    } else {
      member.ProblemData = ProblemData;
    }
  }

  foreach (var range in trainingPartitions.Values) {
    range.Start = ProblemData.TrainingPartition.Start;
    range.End = ProblemData.TrainingPartition.End;
  }

  foreach (var range in testPartitions.Values) {
    range.Start = ProblemData.TestPartition.Start;
    range.End = ProblemData.TestPartition.End;
  }

  base.OnProblemDataChanged();
}
232
/// <summary>
/// Adds the given solutions to the solution collection.
/// </summary>
/// <param name="solutions">The solutions to add.</param>
public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
  this.classificationSolutions.AddRange(solutions);
}
/// <summary>
/// Removes the given solutions from the solution collection.
/// </summary>
/// <param name="solutions">The solutions to remove.</param>
public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
  this.classificationSolutions.RemoveRange(solutions);
}
239
// Registers every newly added solution with the ensemble, then recalculates the results.
private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
  foreach (var added in e.Items) {
    AddClassificationSolution(added);
  }
  RecalculateResults();
}
// Unregisters every removed solution from the ensemble, then recalculates the results.
private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
  foreach (var removed in e.Items) {
    RemoveClassificationSolution(removed);
  }
  RecalculateResults();
}
// On a reset, first unregisters all previous solutions, then registers the new ones,
// and finally recalculates the results.
private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
  foreach (var previous in e.OldItems) {
    RemoveClassificationSolution(previous);
  }
  foreach (var current in e.Items) {
    AddClassificationSolution(current);
  }
  RecalculateResults();
}
// A changed checked state changes which solutions are aggregated, so the results
// are recalculated.
private void classificationSolutions_CheckedItemsChanged(
  object sender,
  CollectionItemsChangedEventArgs<IClassificationSolution> e) {
  RecalculateResults();
}
256
/// <summary>
/// Registers the model of <paramref name="solution"/> with the ensemble model and
/// records the training and test partition of the solution's problem data for it.
/// </summary>
/// <param name="solution">The solution whose model is added.</param>
/// <exception cref="ArgumentException">
/// Thrown when the model is already contained in the ensemble model.
/// </exception>
private void AddClassificationSolution(IClassificationSolution solution) {
  if (Model.Models.Contains(solution.Model)) {
    throw new ArgumentException("The model of the solution is already part of the ensemble.", "solution");
  }
  Model.Add(solution.Model);
  // The partition objects of the solution are stored by reference, not copied.
  trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
  testPartitions[solution.Model] = solution.ProblemData.TestPartition;
}
263
/// <summary>
/// Removes the model of <paramref name="solution"/> from the ensemble model and drops
/// the partitions recorded for it.
/// </summary>
/// <param name="solution">The solution whose model is removed.</param>
/// <exception cref="ArgumentException">
/// Thrown when the model is not contained in the ensemble model.
/// </exception>
private void RemoveClassificationSolution(IClassificationSolution solution) {
  if (!Model.Models.Contains(solution.Model)) {
    throw new ArgumentException("The model of the solution is not part of the ensemble.", "solution");
  }
  Model.Remove(solution.Model);
  trainingPartitions.Remove(solution.Model);
  testPartitions.Remove(solution.Model);
}
270  }
271}
Note: See TracBrowser for help on using the repository browser.