Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs @ 8177

Last change on this file since 8177 was 8177, checked in by sforsten, 12 years ago

#1776: minor changes

File size: 13.3 KB
RevLine 
[5816]1#region License Information
2/* HeuristicLab
[7259]3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5816]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[6589]22using System;
[5816]23using System.Collections.Generic;
24using System.Linq;
[6613]25using HeuristicLab.Collections;
[5816]26using HeuristicLab.Common;
27using HeuristicLab.Core;
[6589]28using HeuristicLab.Data;
[5816]29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[7459]30using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
[5816]31
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents classification solutions that contain an ensemble of multiple classification models.
  /// </summary>
  /// <remarks>
  /// The member solutions live in <see cref="ClassificationSolutions"/>. Adding a solution there adds its model to
  /// <see cref="Model"/> and records the solution's training and test partition; removing it undoes both.
  /// The ensemble's estimated class values are aggregated over the checked member solutions by
  /// <see cref="WeightCalculator"/>.
  /// </remarks>
  [StorableClass]
  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
  [Creatable("Data Analysis - Ensembles")]
  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
    /// <summary>
    /// Gets the ensemble model that holds the models of the member solutions.
    /// </summary>
    public new IClassificationEnsembleModel Model {
      get { return (IClassificationEnsembleModel)base.Model; }
    }

    /// <summary>
    /// Gets or sets the problem data of the ensemble.
    /// </summary>
    public new ClassificationEnsembleProblemData ProblemData {
      get { return (ClassificationEnsembleProblemData)base.ProblemData; }
      set { base.ProblemData = value; }
    }

    // Not [Storable]: the member solutions are rebuilt from Model.Models and the stored partitions
    // in AfterDeserialization.
    private readonly CheckedItemCollection<IClassificationSolution> classificationSolutions;

    /// <summary>
    /// Gets the member solutions of the ensemble. Only checked solutions contribute to the ensemble's estimates.
    /// </summary>
    public ICheckedItemCollection<IClassificationSolution> ClassificationSolutions {
      get { return classificationSolutions; }
    }

    // Not [Storable]: a default calculator is assigned in every constructor, including the storable one.
    private IClassificationEnsembleSolutionWeightCalculator weightCalculator;

    /// <summary>
    /// Gets or sets the calculator that weights the checked member solutions and aggregates their estimated
    /// class values. Assigning <c>null</c> is ignored. Assigning a calculator recalculates the results unless
    /// the problem data is empty.
    /// </summary>
    public IClassificationEnsembleSolutionWeightCalculator WeightCalculator {
      set {
        if (value != null) {
          weightCalculator = value;
          if (!ProblemData.IsEmpty) {
            RecalculateResults();
          }
        }
      }
      get { return weightCalculator; }
    }

    // Training/test partition recorded for each member model. These are persisted so that
    // AfterDeserialization can rebuild every member solution with its own partitions.
    [Storable]
    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
    [Storable]
    private Dictionary<IClassificationModel, IntRange> testPartitions;

    [StorableConstructor]
    private ClassificationEnsembleSolution(bool deserializing)
      : base(deserializing) {
      classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
      // The calculator is not persisted. Without this default, RecalculateResults and every
      // Estimated*ClassValues call on a deserialized ensemble would dereference null.
      weightCalculator = new MajorityVoteWeightCalculator();
    }

    /// <summary>
    /// Rebuilds the member solutions from the deserialized ensemble model and the recorded partitions.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // A file that does not contain the partition dictionaries leaves these fields null.
      if (trainingPartitions == null) trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
      if (testPartitions == null) testPartitions = new Dictionary<IClassificationModel, IntRange>();

      foreach (var model in Model.Models) {
        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();

        IntRange trainingPartition;
        if (trainingPartitions.TryGetValue(model, out trainingPartition)) {
          problemData.TrainingPartition.Start = trainingPartition.Start;
          problemData.TrainingPartition.End = trainingPartition.End;
        } else {
          // No recorded partition: keep the ensemble's partition and record it, as AddClassificationSolution would.
          trainingPartitions[model] = problemData.TrainingPartition;
        }

        IntRange testPartition;
        if (testPartitions.TryGetValue(model, out testPartition)) {
          problemData.TestPartition.Start = testPartition.Start;
          problemData.TestPartition.End = testPartition.End;
        } else {
          testPartitions[model] = problemData.TestPartition;
        }

        // The collection handlers are not registered yet, so this does not touch Model or the partition maps.
        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
      }
      RegisterClassificationSolutionsEventHandler();
      // Prepare the calculator's weights for the rebuilt solutions, as RecalculateResults does,
      // without calling CalculateResults.
      weightCalculator.CalculateNormalizedWeights(classificationSolutions.CheckedItems);
    }

    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
      : base(original, cloner) {
      trainingPartitions = ClonePartitions(original.trainingPartitions, cloner);
      testPartitions = ClonePartitions(original.testPartitions, cloner);

      weightCalculator = cloner.Clone(original.weightCalculator);
      classificationSolutions = cloner.Clone(original.classificationSolutions);
      RegisterClassificationSolutionsEventHandler();
    }

    // Clones a model -> partition map, cloning both keys and values with the given cloner. Presumably the
    // shared cloner makes the keys the same model clones the cloned ensemble model holds - verify.
    private static Dictionary<IClassificationModel, IntRange> ClonePartitions(Dictionary<IClassificationModel, IntRange> source, Cloner cloner) {
      var clone = new Dictionary<IClassificationModel, IntRange>();
      foreach (var pair in source) {
        clone[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
      }
      return clone;
    }

    /// <summary>
    /// Creates an empty ensemble with an empty ensemble model, the empty ensemble problem data
    /// and a majority-vote weight calculator.
    /// </summary>
    public ClassificationEnsembleSolution()
      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
      testPartitions = new Dictionary<IClassificationModel, IntRange>();
      classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
      weightCalculator = new MajorityVoteWeightCalculator();

      RegisterClassificationSolutionsEventHandler();
    }

    /// <summary>
    /// Creates an ensemble in which every model uses the training and test partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="models">The member models.</param>
    /// <param name="problemData">The problem data; cloned for each member solution.</param>
    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
      : this(models, problemData,
             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
             models.Select(m => (IntRange)problemData.TestPartition.Clone())
      ) { }

    /// <summary>
    /// Creates an ensemble whose i-th model uses the i-th training and the i-th test partition.
    /// </summary>
    /// <param name="models">The member models.</param>
    /// <param name="problemData">The problem data; cloned for each member solution.</param>
    /// <param name="trainingPartitions">One training partition per model, in the same order.</param>
    /// <param name="testPartitions">One test partition per model, in the same order.</param>
    /// <exception cref="ArgumentException">
    /// The three sequences differ in length, or two member solutions share the same model.
    /// </exception>
    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
      this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
      // Must be set before AddRange below: ItemsAdded -> RecalculateResults uses the calculator.
      this.weightCalculator = new MajorityVoteWeightCalculator();

      List<IClassificationSolution> solutions = new List<IClassificationSolution>();
      using (var modelEnumerator = models.GetEnumerator())
      using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
      using (var testPartitionEnumerator = testPartitions.GetEnumerator()) {
        bool hasModel, hasTrainingPartition, hasTestPartition;
        while (true) {
          // Advance all three sequences every round; the results of the final round reveal any length mismatch.
          hasModel = modelEnumerator.MoveNext();
          hasTrainingPartition = trainingPartitionEnumerator.MoveNext();
          hasTestPartition = testPartitionEnumerator.MoveNext();
          if (!(hasModel && hasTrainingPartition && hasTestPartition)) break;

          var p = (IClassificationProblemData)problemData.Clone();
          p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
          p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
          p.TestPartition.Start = testPartitionEnumerator.Current.Start;
          p.TestPartition.End = testPartitionEnumerator.Current.End;

          solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
        }
        // Equal lengths end with all three exhausted in the same round; otherwise some sequence still had an element.
        if (hasModel || hasTrainingPartition || hasTestPartition) {
          throw new ArgumentException("The number of models, training partitions and test partitions must be equal.");
        }
      }

      RegisterClassificationSolutionsEventHandler();
      classificationSolutions.AddRange(solutions);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ClassificationEnsembleSolution(this, cloner);
    }

    // Keeps Model, the partition maps and the results in sync with changes to the member collection.
    private void RegisterClassificationSolutionsEventHandler() {
      classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded);
      classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
      classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
      classificationSolutions.CheckedItemsChanged += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CheckedItemsChanged);
    }

    /// <summary>
    /// Recomputes the weights of the checked member solutions and then the solution's results.
    /// </summary>
    protected override void RecalculateResults() {
      weightCalculator.CalculateNormalizedWeights(classificationSolutions.CheckedItems);
      CalculateResults();
    }

    #region Evaluation
    /// <summary>
    /// Gets the aggregated estimates of the checked member solutions for the training rows of the ensemble's problem data.
    /// </summary>
    public override IEnumerable<double> EstimatedTrainingClassValues {
      get {
        return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
                                                              ProblemData.Dataset,
                                                              ProblemData.TrainingIndizes,
                                                              weightCalculator.GetTrainingClassDelegate());
      }
    }

    /// <summary>
    /// Gets the aggregated estimates of the checked member solutions for the test rows of the ensemble's problem data.
    /// </summary>
    public override IEnumerable<double> EstimatedTestClassValues {
      get {
        return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
                                                              ProblemData.Dataset,
                                                              ProblemData.TestIndizes,
                                                              weightCalculator.GetTestClassDelegate());
      }
    }

    /// <summary>
    /// Gets the aggregated estimates of the checked member solutions for the given rows of the ensemble's dataset.
    /// </summary>
    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
      return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
                                                            ProblemData.Dataset,
                                                            rows,
                                                            weightCalculator.GetAllClassDelegate());
    }

    /// <summary>
    /// Yields, for each row, the estimated class values of every checked member model, in the order of the checked items.
    /// </summary>
    /// <remarks>
    /// Each yielded row is an independent array, so it remains valid after enumeration continues.
    /// Enumeration stops as soon as any model's estimates run out.
    /// </remarks>
    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
      List<IClassificationModel> models = classificationSolutions.CheckedItems.Select(sol => sol.Model).ToList();
      if (models.Count == 0) yield break;

      List<IEnumerator<double>> estimatedValuesEnumerators = models
        .Select(model => model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
        .ToList();
      try {
        while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
          // Materialize the row now: a lazy projection would read Current only when the caller enumerates it,
          // by which time the enumerators may already have advanced to later rows.
          yield return estimatedValuesEnumerators.Select(en => en.Current).ToArray();
        }
      } finally {
        foreach (var enumerator in estimatedValuesEnumerators) enumerator.Dispose();
      }
    }
    #endregion

    /// <summary>
    /// Propagates new problem data to the member solutions and the recorded partitions.
    /// </summary>
    /// <remarks>
    /// Nested ensemble solutions receive the ensemble's problem data itself. All other member solutions share one
    /// new <see cref="ClassificationProblemData"/> built from the same dataset, input variables, target variable and
    /// partitions. The recorded partitions of every member are overwritten with the ensemble's partitions.
    /// </remarks>
    protected override void OnProblemDataChanged() {
      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
                                                                     ProblemData.AllowedInputVariables,
                                                                     ProblemData.TargetVariable);
      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
      problemData.TestPartition.End = ProblemData.TestPartition.End;

      foreach (var solution in ClassificationSolutions) {
        if (solution is ClassificationEnsembleSolution)
          solution.ProblemData = ProblemData;
        else
          solution.ProblemData = problemData;
      }
      foreach (var trainingPartition in trainingPartitions.Values) {
        trainingPartition.Start = ProblemData.TrainingPartition.Start;
        trainingPartition.End = ProblemData.TrainingPartition.End;
      }
      foreach (var testPartition in testPartitions.Values) {
        testPartition.Start = ProblemData.TestPartition.Start;
        testPartition.End = ProblemData.TestPartition.End;
      }

      base.OnProblemDataChanged();
    }

    /// <summary>Adds member solutions; their models are added to the ensemble model.</summary>
    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
      classificationSolutions.AddRange(solutions);
    }

    /// <summary>Removes member solutions; their models are removed from the ensemble model.</summary>
    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
      classificationSolutions.RemoveRange(solutions);
    }

    private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
      foreach (var solution in e.Items) AddClassificationSolution(solution);
      RecalculateResults();
    }
    private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
      foreach (var solution in e.Items) RemoveClassificationSolution(solution);
      RecalculateResults();
    }
    private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
      foreach (var solution in e.OldItems) RemoveClassificationSolution(solution);
      foreach (var solution in e.Items) AddClassificationSolution(solution);
      RecalculateResults();
    }
    private void classificationSolutions_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
      RecalculateResults();
    }

    // Adds the solution's model to the ensemble model and records references to the solution's partitions.
    private void AddClassificationSolution(IClassificationSolution solution) {
      if (Model.Models.Contains(solution.Model))
        throw new ArgumentException("The model of the solution is already part of the ensemble.", "solution");
      Model.Add(solution.Model);
      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    }

    // Removes the solution's model from the ensemble model and forgets its recorded partitions.
    private void RemoveClassificationSolution(IClassificationSolution solution) {
      if (!Model.Models.Contains(solution.Model))
        throw new ArgumentException("The model of the solution is not part of the ensemble.", "solution");
      Model.Remove(solution.Model);
      trainingPartitions.Remove(solution.Model);
      testPartitions.Remove(solution.Model);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.