Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/11 13:59:25 (13 years ago)
Author:
epitzer
Message:

#1530 integrate changes from trunk

Location:
branches/PersistenceSpeedUp
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/PersistenceSpeedUp

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r6184 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2730
     
    3235  [StorableClass]
    3336  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
    34   // [Creatable("Data Analysis")]
    35   public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution {
     37  [Creatable("Data Analysis - Ensembles")]
     38  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     39    public new IClassificationEnsembleModel Model {
     40      get { return (IClassificationEnsembleModel)base.Model; }
     41    }
     42    public new ClassificationEnsembleProblemData ProblemData {
     43      get { return (ClassificationEnsembleProblemData)base.ProblemData; }
     44      set { base.ProblemData = value; }
     45    }
     46
     47    private readonly ItemCollection<IClassificationSolution> classificationSolutions;
     48    public IItemCollection<IClassificationSolution> ClassificationSolutions {
     49      get { return classificationSolutions; }
     50    }
    3651
    3752    [Storable]
    38     private List<IClassificationModel> models;
    39     public IEnumerable<IClassificationModel> Models {
    40       get { return new List<IClassificationModel>(models); }
    41     }
     53    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     54    [Storable]
     55    private Dictionary<IClassificationModel, IntRange> testPartitions;
     56
    4257    [StorableConstructor]
    43     protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { }
    44     protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
     58    private ClassificationEnsembleSolution(bool deserializing)
     59      : base(deserializing) {
     60      classificationSolutions = new ItemCollection<IClassificationSolution>();
     61    }
     62    [StorableHook(HookType.AfterDeserialization)]
     63    private void AfterDeserialization() {
     64      foreach (var model in Model.Models) {
     65        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
     66        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     67        problemData.TrainingPartition.End = trainingPartitions[model].End;
     68        problemData.TestPartition.Start = testPartitions[model].Start;
     69        problemData.TestPartition.End = testPartitions[model].End;
     70
     71        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
     72      }
     73      RegisterClassificationSolutionsEventHandler();
     74    }
     75
     76    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    4577      : base(original, cloner) {
    46       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    47     }
    48     public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models)
    49       : base() {
    50       this.name = ItemName;
    51       this.description = ItemDescription;
    52       this.models = new List<IClassificationModel>(models);
     78      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     79      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     80      foreach (var pair in original.trainingPartitions) {
     81        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     82      }
     83      foreach (var pair in original.testPartitions) {
     84        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     85      }
     86
     87      classificationSolutions = cloner.Clone(original.classificationSolutions);
     88      RegisterClassificationSolutionsEventHandler();
     89    }
     90
     91    public ClassificationEnsembleSolution()
     92      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
     93      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     94      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     95      classificationSolutions = new ItemCollection<IClassificationSolution>();
     96
     97      RegisterClassificationSolutionsEventHandler();
     98    }
     99
     100    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
     101      : this(models, problemData,
     102             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     103             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     104      ) { }
     105
     106    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
     107      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
     108      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     109      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
     110      this.classificationSolutions = new ItemCollection<IClassificationSolution>();
     111
     112      List<IClassificationSolution> solutions = new List<IClassificationSolution>();
     113      var modelEnumerator = models.GetEnumerator();
     114      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
     115      var testPartitionEnumerator = testPartitions.GetEnumerator();
     116
     117      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
     118        var p = (IClassificationProblemData)problemData.Clone();
     119        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     120        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     121        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     122        p.TestPartition.End = testPartitionEnumerator.Current.End;
     123
     124        solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
     125      }
     126      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     127        throw new ArgumentException();
     128      }
     129
     130      RegisterClassificationSolutionsEventHandler();
     131      classificationSolutions.AddRange(solutions);
    53132    }
    54133
     
    56135      return new ClassificationEnsembleSolution(this, cloner);
    57136    }
    58 
    59     #region IClassificationEnsembleModel Members
     137    private void RegisterClassificationSolutionsEventHandler() {
     138      classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded);
     139      classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
     140      classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
     141    }
     142
     143    protected override void RecalculateResults() {
     144      CalculateResults();
     145    }
     146
     147    #region Evaluation
     148    public override IEnumerable<double> EstimatedTrainingClassValues {
     149      get {
     150        var rows = ProblemData.TrainingIndizes;
     151        var estimatedValuesEnumerators = (from model in Model.Models
     152                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     153                                         .ToList();
     154        var rowsEnumerator = rows.GetEnumerator();
     155        // aggregate to make sure that MoveNext is called for all enumerators
     156        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     157          int currentRow = rowsEnumerator.Current;
     158
     159          var selectedEnumerators = from pair in estimatedValuesEnumerators
     160                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
     161                                    select pair.EstimatedValuesEnumerator;
     162          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     163        }
     164      }
     165    }
     166
     167    public override IEnumerable<double> EstimatedTestClassValues {
     168      get {
     169        var rows = ProblemData.TestIndizes;
     170        var estimatedValuesEnumerators = (from model in Model.Models
     171                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     172                                         .ToList();
     173        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     174        // aggregate to make sure that MoveNext is called for all enumerators
     175        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     176          int currentRow = rowsEnumerator.Current;
     177
     178          var selectedEnumerators = from pair in estimatedValuesEnumerators
     179                                    where RowIsTestForModel(currentRow, pair.Model)
     180                                    select pair.EstimatedValuesEnumerator;
     181
     182          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     183        }
     184      }
     185    }
     186
     187    private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) {
     188      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
     189              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
     190    }
     191
     192    private bool RowIsTestForModel(int currentRow, IClassificationModel model) {
     193      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     194              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
     195    }
     196
     197    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     198      return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
     199             select AggregateEstimatedClassValues(xs);
     200    }
    60201
    61202    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    62       var estimatedValuesEnumerators = (from model in models
     203      var estimatedValuesEnumerators = (from model in Model.Models
    63204                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
    64205                                       .ToList();
     
    70211    }
    71212
     213    private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
     214      return estimatedClassValues
     215      .GroupBy(x => x)
     216      .OrderBy(g => -g.Count())
     217      .Select(g => g.Key)
     218      .DefaultIfEmpty(double.NaN)
     219      .First();
     220    }
    72221    #endregion
    73222
    74     #region IClassificationModel Members
    75 
    76     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    77       foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) {
    78         // return the class which is most often occurring
    79         yield return
    80           estimatedValuesVector
    81           .GroupBy(x => x)
    82           .OrderBy(g => -g.Count())
    83           .Select(g => g.Key)
    84           .First();
    85       }
    86     }
    87 
    88     #endregion
     223    protected override void OnProblemDataChanged() {
     224      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
     225                                                                     ProblemData.AllowedInputVariables,
     226                                                                     ProblemData.TargetVariable);
     227      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
     228      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
     229      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
     230      problemData.TestPartition.End = ProblemData.TestPartition.End;
     231
     232      foreach (var solution in ClassificationSolutions) {
     233        if (solution is ClassificationEnsembleSolution)
     234          solution.ProblemData = ProblemData;
     235        else
     236          solution.ProblemData = problemData;
     237      }
     238      foreach (var trainingPartition in trainingPartitions.Values) {
     239        trainingPartition.Start = ProblemData.TrainingPartition.Start;
     240        trainingPartition.End = ProblemData.TrainingPartition.End;
     241      }
     242      foreach (var testPartition in testPartitions.Values) {
     243        testPartition.Start = ProblemData.TestPartition.Start;
     244        testPartition.End = ProblemData.TestPartition.End;
     245      }
     246
     247      base.OnProblemDataChanged();
     248    }
     249
     250    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     251      classificationSolutions.AddRange(solutions);
     252    }
     253    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     254      classificationSolutions.RemoveRange(solutions);
     255    }
     256
     257    private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     258      foreach (var solution in e.Items) AddClassificationSolution(solution);
     259      RecalculateResults();
     260    }
     261    private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     262      foreach (var solution in e.Items) RemoveClassificationSolution(solution);
     263      RecalculateResults();
     264    }
     265    private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     266      foreach (var solution in e.OldItems) RemoveClassificationSolution(solution);
     267      foreach (var solution in e.Items) AddClassificationSolution(solution);
     268      RecalculateResults();
     269    }
     270
     271    private void AddClassificationSolution(IClassificationSolution solution) {
     272      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     273      Model.Add(solution.Model);
     274      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     275      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     276    }
     277
     278    private void RemoveClassificationSolution(IClassificationSolution solution) {
     279      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     280      Model.Remove(solution.Model);
     281      trainingPartitions.Remove(solution.Model);
     282      testPartitions.Remove(solution.Model);
     283    }
    89284  }
    90285}
Note: See TracChangeset for help on using the changeset viewer.