Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs @ 7531

Last change on this file since 7531 was 7531, checked in by sforsten, 12 years ago

#1776:

  • 2 more strategies have been implemented
  • major changes in the inheritance have been made to make it possible to add strategies which don't use a voting strategy with weights
  • ClassificationEnsembleSolutionEstimatedClassValuesView currently doesn't show the confidence (it has been removed for testing purposes)
File size: 12.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Collections;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
31
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents classification solutions that contain an ensemble of multiple classification models
  /// </summary>
  /// <remarks>
  /// The ensemble keeps three views of its members in sync: the collection of member solutions
  /// (<see cref="ClassificationSolutions"/>), the models stored in the ensemble <see cref="Model"/>, and the
  /// per-model training/test partitions. Synchronization happens in the collection event handlers registered
  /// by <c>RegisterClassificationSolutionsEventHandler</c>.
  /// </remarks>
  [StorableClass]
  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
  [Creatable("Data Analysis - Ensembles")]
  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
    /// <summary>
    /// Gets the model of the base solution typed as ensemble model.
    /// </summary>
    public new IClassificationEnsembleModel Model {
      get { return (IClassificationEnsembleModel)base.Model; }
    }
    /// <summary>
    /// Gets or sets the problem data of the base solution typed as ensemble problem data.
    /// </summary>
    public new ClassificationEnsembleProblemData ProblemData {
      get { return (ClassificationEnsembleProblemData)base.ProblemData; }
      set { base.ProblemData = value; }
    }

    // Not persisted; it is rebuilt from the stored models and partitions in AfterDeserialization.
    private readonly ItemCollection<IClassificationSolution> classificationSolutions;
    /// <summary>
    /// Gets the collection of member solutions. Changes to this collection are mirrored into the ensemble
    /// model and the partition dictionaries by the registered collection event handlers.
    /// </summary>
    public IItemCollection<IClassificationSolution> ClassificationSolutions {
      get { return classificationSolutions; }
    }

    [Storable]
    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
    [Storable]
    private Dictionary<IClassificationModel, IntRange> testPartitions;

    // NOTE(review): this field carries no [Storable] attribute, so the chosen strategy is presumably lost
    // when the solution is saved and loaded; AfterDeserialization falls back to the default calculator.
    private IClassificationEnsembleSolutionWeightCalculator weightCalculator;

    /// <summary>
    /// Sets the calculator that weights the member solutions and aggregates their estimated class values.
    /// A <c>null</c> value is ignored. Assigning a calculator recalculates the normalized weights and,
    /// unless the problem data is empty, the results of this solution.
    /// </summary>
    public IClassificationEnsembleSolutionWeightCalculator WeightCalculator {
      set {
        if (value == null) return;
        weightCalculator = value;
        weightCalculator.CalculateNormalizedWeights(classificationSolutions);
        if (!ProblemData.IsEmpty)
          RecalculateResults();
      }
    }

    [StorableConstructor]
    private ClassificationEnsembleSolution(bool deserializing)
      : base(deserializing) {
      classificationSolutions = new ItemCollection<IClassificationSolution>();
    }

    /// <summary>
    /// Rebuilds the member solutions from the stored models and their stored partitions.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      foreach (var model in Model.Models) {
        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
        problemData.TrainingPartition.End = trainingPartitions[model].End;
        problemData.TestPartition.Start = testPartitions[model].Start;
        problemData.TestPartition.End = testPartitions[model].End;

        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
      }
      // The handlers are registered only after the collection has been refilled, so the adds above do not
      // try to insert the (already present) models into the ensemble model a second time.
      RegisterClassificationSolutionsEventHandler();

      // Without a calculator every access to the estimated values would fail with a NullReferenceException.
      weightCalculator = new MajorityVoteWeightCalculator();
      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    }

    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
      : base(original, cloner) {
      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
      testPartitions = new Dictionary<IClassificationModel, IntRange>();
      foreach (var pair in original.trainingPartitions) {
        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
      }
      foreach (var pair in original.testPartitions) {
        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
      }

      classificationSolutions = cloner.Clone(original.classificationSolutions);

      // The clone needs a calculator as well. Whether calculators are deep-cloneable cannot be told from
      // this file, so clone it when it is and otherwise share the original's instance.
      var cloneableCalculator = original.weightCalculator as IDeepCloneable;
      if (cloneableCalculator != null)
        weightCalculator = (IClassificationEnsembleSolutionWeightCalculator)cloner.Clone(cloneableCalculator);
      else
        weightCalculator = original.weightCalculator;

      RegisterClassificationSolutionsEventHandler();
    }

    /// <summary>
    /// Creates an empty ensemble on empty problem data that uses a <c>MajorityVoteWeightCalculator</c>.
    /// </summary>
    public ClassificationEnsembleSolution()
      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
      testPartitions = new Dictionary<IClassificationModel, IntRange>();
      classificationSolutions = new ItemCollection<IClassificationSolution>();

      weightCalculator = new MajorityVoteWeightCalculator();

      RegisterClassificationSolutionsEventHandler();
    }

    /// <summary>
    /// Creates an ensemble in which every model uses a copy of the training and test partition of
    /// <paramref name="problemData"/>.
    /// </summary>
    /// <param name="models">The models of the ensemble; the sequence is enumerated more than once.</param>
    /// <param name="problemData">The problem data shared by all models.</param>
    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
      : this(models, problemData,
             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
             models.Select(m => (IntRange)problemData.TestPartition.Clone())
      ) { }

    /// <summary>
    /// Creates an ensemble in which the i-th model uses the i-th training and the i-th test partition.
    /// </summary>
    /// <param name="models">The models of the ensemble.</param>
    /// <param name="problemData">The problem data; each member solution receives its own clone.</param>
    /// <param name="trainingPartitions">One training partition per model.</param>
    /// <param name="testPartitions">One test partition per model.</param>
    /// <exception cref="ArgumentNullException">Thrown if one of the three sequences is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">Thrown if the three sequences differ in length.</exception>
    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
      if (models == null) throw new ArgumentNullException("models");
      if (trainingPartitions == null) throw new ArgumentNullException("trainingPartitions");
      if (testPartitions == null) throw new ArgumentNullException("testPartitions");

      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
      this.classificationSolutions = new ItemCollection<IClassificationSolution>();

      // Must exist before the solutions are added below: the ItemsAdded handler uses the calculator.
      weightCalculator = new MajorityVoteWeightCalculator();

      var solutions = new List<IClassificationSolution>();
      using (var modelEnumerator = models.GetEnumerator())
      using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
      using (var testPartitionEnumerator = testPartitions.GetEnumerator()) {
        while (true) {
          // Advance all three in lock-step and remember each result, so that a sequence which is longer
          // by exactly one element is still detected (re-checking MoveNext afterwards would miss it).
          bool hasModel = modelEnumerator.MoveNext();
          bool hasTrainingPartition = trainingPartitionEnumerator.MoveNext();
          bool hasTestPartition = testPartitionEnumerator.MoveNext();

          if (!(hasModel && hasTrainingPartition && hasTestPartition)) {
            if (hasModel || hasTrainingPartition || hasTestPartition)
              throw new ArgumentException("The number of models, training partitions and test partitions must be equal.");
            break;
          }

          var p = (IClassificationProblemData)problemData.Clone();
          p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
          p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
          p.TestPartition.Start = testPartitionEnumerator.Current.Start;
          p.TestPartition.End = testPartitionEnumerator.Current.End;

          solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
        }
      }

      // The handlers are registered first so that adding the solutions fills the (initially empty)
      // ensemble model and the partition dictionaries.
      RegisterClassificationSolutionsEventHandler();
      classificationSolutions.AddRange(solutions);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ClassificationEnsembleSolution(this, cloner);
    }

    /// <summary>
    /// Subscribes to the add, remove and reset events of the member solution collection.
    /// </summary>
    private void RegisterClassificationSolutionsEventHandler() {
      classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded);
      classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
      classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
    }

    protected override void RecalculateResults() {
      CalculateResults();
    }

    #region Evaluation
    /// <summary>
    /// Gets the class values aggregated by the weight calculator for the training rows of the problem data.
    /// </summary>
    public override IEnumerable<double> EstimatedTrainingClassValues {
      get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); }
    }

    /// <summary>
    /// Gets the class values aggregated by the weight calculator for the test rows of the problem data.
    /// </summary>
    public override IEnumerable<double> EstimatedTestClassValues {
      get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); }
    }

    /// <summary>
    /// Returns whether <paramref name="currentRow"/> lies in the half-open training partition [Start, End)
    /// stored for <paramref name="model"/>; <c>true</c> if no partition is known for the model.
    /// </summary>
    private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) {
      IntRange partition;
      if (trainingPartitions == null || !trainingPartitions.TryGetValue(model, out partition)) return true;
      return partition.Start <= currentRow && currentRow < partition.End;
    }

    /// <summary>
    /// Returns whether <paramref name="currentRow"/> lies in the half-open test partition [Start, End)
    /// stored for <paramref name="model"/>; <c>true</c> if no partition is known for the model.
    /// </summary>
    private bool RowIsTestForModel(int currentRow, IClassificationModel model) {
      IntRange partition;
      if (testPartitions == null || !testPartitions.TryGetValue(model, out partition)) return true;
      return partition.Start <= currentRow && currentRow < partition.End;
    }

    /// <summary>
    /// Returns the class values aggregated by the weight calculator for the given rows of the problem data.
    /// </summary>
    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
      return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows);
    }

    /// <summary>
    /// Returns, for each row, the class values estimated by the individual models (one value per model,
    /// in the order of <c>Model.Models</c>). The sequence is empty if the ensemble contains no models and
    /// ends as soon as one model yields no further value.
    /// </summary>
    /// <param name="dataset">The dataset the models are evaluated on.</param>
    /// <param name="rows">The rows to evaluate; the sequence is handed to every model.</param>
    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
      // Without this guard All() would be vacuously true on an empty list and the loop would never end.
      if (!Model.Models.Any()) yield break;
      var estimatedValuesEnumerators = (from model in Model.Models
                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
                                       .ToList();
      try {
        while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Materialize the row: a lazy query over Current would read the values of a later row
          // if the caller enumerates it only after this iterator has advanced.
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      }
      finally {
        foreach (var enumerator in estimatedValuesEnumerators) enumerator.Dispose();
      }
    }
    #endregion

    /// <summary>
    /// Propagates changed problem data to the member solutions and resets all per-model partitions
    /// to the partitions of the new problem data.
    /// </summary>
    protected override void OnProblemDataChanged() {
      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
                                                                     ProblemData.AllowedInputVariables,
                                                                     ProblemData.TargetVariable);
      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
      problemData.TestPartition.End = ProblemData.TestPartition.End;

      // Nested ensembles receive the ensemble problem data itself; all other members share the plain
      // classification problem data created above.
      foreach (var solution in ClassificationSolutions) {
        if (solution is ClassificationEnsembleSolution)
          solution.ProblemData = ProblemData;
        else
          solution.ProblemData = problemData;
      }
      foreach (var trainingPartition in trainingPartitions.Values) {
        trainingPartition.Start = ProblemData.TrainingPartition.Start;
        trainingPartition.End = ProblemData.TrainingPartition.End;
      }
      foreach (var testPartition in testPartitions.Values) {
        testPartition.Start = ProblemData.TestPartition.Start;
        testPartition.End = ProblemData.TestPartition.End;
      }

      base.OnProblemDataChanged();
    }

    /// <summary>
    /// Adds the given solutions to the ensemble.
    /// </summary>
    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
      classificationSolutions.AddRange(solutions);
    }
    /// <summary>
    /// Removes the given solutions from the ensemble.
    /// </summary>
    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
      classificationSolutions.RemoveRange(solutions);
    }

    // The weights are recalculated once per collection event instead of once per item: the calculator
    // always receives the complete collection, so the per-item calls only repeated the same work.
    private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
      foreach (var solution in e.Items) AddClassificationSolution(solution);
      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
      RecalculateResults();
    }
    private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
      foreach (var solution in e.Items) RemoveClassificationSolution(solution);
      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
      RecalculateResults();
    }
    private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
      foreach (var solution in e.OldItems) RemoveClassificationSolution(solution);
      foreach (var solution in e.Items) AddClassificationSolution(solution);
      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
      RecalculateResults();
    }

    /// <summary>
    /// Registers the model of <paramref name="solution"/> in the ensemble model and remembers the
    /// partitions of the solution's problem data for that model.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown if the model is already part of the ensemble.</exception>
    private void AddClassificationSolution(IClassificationSolution solution) {
      if (Model.Models.Contains(solution.Model))
        throw new ArgumentException("The model of the solution is already contained in the ensemble.", "solution");
      Model.Add(solution.Model);
      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    }

    /// <summary>
    /// Removes the model of <paramref name="solution"/> from the ensemble model and forgets its partitions.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown if the model is not part of the ensemble.</exception>
    private void RemoveClassificationSolution(IClassificationSolution solution) {
      if (!Model.Models.Contains(solution.Model))
        throw new ArgumentException("The model of the solution is not contained in the ensemble.", "solution");
      Model.Remove(solution.Model);
      trainingPartitions.Remove(solution.Model);
      testPartitions.Remove(solution.Model);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.