Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/20/11 16:10:07 (12 years ago)
Author:
gkronber
Message:

#1450: implemented support for ensemble solutions for classification.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r6184 r6239  
    2525using HeuristicLab.Core;
    2626using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     27using HeuristicLab.Data;
     28using System;
    2729
    2830namespace HeuristicLab.Problems.DataAnalysis {
     
    3335  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
    3436  // [Creatable("Data Analysis")]
    35   public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution {
     37  public class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     38
     39    public new IClassificationEnsembleModel Model {
     40      set { base.Model = value; }
     41      get { return (IClassificationEnsembleModel)base.Model; }
     42    }
    3643
    3744    [Storable]
    38     private List<IClassificationModel> models;
    39     public IEnumerable<IClassificationModel> Models {
    40       get { return new List<IClassificationModel>(models); }
    41     }
     45    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     46    [Storable]
     47    private Dictionary<IClassificationModel, IntRange> testPartitions;
     48
     49
    4250    [StorableConstructor]
    4351    protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { }
    4452    protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    4553      : base(original, cloner) {
    46       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
     54      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     55      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     56      foreach (var model in Model.Models) {
     57        trainingPartitions[model] = (IntRange)ProblemData.TrainingPartition.Clone();
     58        testPartitions[model] = (IntRange)ProblemData.TestPartition.Clone();
     59      }
     60      RecalculateResults();
    4761    }
    48     public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models)
    49       : base() {
     62    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
     63      : base(new ClassificationEnsembleModel(models), new ClassificationEnsembleProblemData(problemData)) {
    5064      this.name = ItemName;
    5165      this.description = ItemDescription;
    52       this.models = new List<IClassificationModel>(models);
     66      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     67      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     68      foreach (var model in models) {
     69        trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone();
     70        testPartitions[model] = (IntRange)problemData.TestPartition.Clone();
     71      }
     72      RecalculateResults();
     73    }
     74
     75    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
     76      : base(new ClassificationEnsembleModel(models), new ClassificationEnsembleProblemData(problemData)) {
     77      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     78      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
     79      var modelEnumerator = models.GetEnumerator();
     80      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
     81      var testPartitionEnumerator = testPartitions.GetEnumerator();
     82      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
     83        this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone();
     84        this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone();
     85      }
     86      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     87        throw new ArgumentException();
     88      }
     89      RecalculateResults();
    5390    }
    5491
     
    5794    }
    5895
    59     #region IClassificationEnsembleModel Members
     96    public override IEnumerable<double> EstimatedTrainingClassValues {
     97      get {
     98        var rows = ProblemData.TrainingIndizes;
     99        var estimatedValuesEnumerators = (from model in Model.Models
     100                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     101                                         .ToList();
     102        var rowsEnumerator = rows.GetEnumerator();
     103        // aggregate to make sure that MoveNext is called for all enumerators
     104        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     105          int currentRow = rowsEnumerator.Current;
     106
     107          var selectedEnumerators = from pair in estimatedValuesEnumerators
     108                                    where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) ||
     109                                         (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End)
     110                                    select pair.EstimatedValuesEnumerator;
     111          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     112        }
     113      }
     114    }
     115
     116    public override IEnumerable<double> EstimatedTestClassValues {
     117      get {
     118        var rows = ProblemData.TestIndizes;
     119        var estimatedValuesEnumerators = (from model in Model.Models
     120                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     121                                         .ToList();
     122        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     123        // aggregate to make sure that MoveNext is called for all enumerators
     124        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     125          int currentRow = rowsEnumerator.Current;
     126
     127          var selectedEnumerators = from pair in estimatedValuesEnumerators
     128                                    where testPartitions == null || !testPartitions.ContainsKey(pair.Model) ||
     129                                      (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End)
     130                                    select pair.EstimatedValuesEnumerator;
     131
     132          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     133        }
     134      }
     135    }
     136
     137    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     138      return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
     139             select AggregateEstimatedClassValues(xs);
     140    }
    60141
    61142    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    62       var estimatedValuesEnumerators = (from model in models
     143      var estimatedValuesEnumerators = (from model in Model.Models
    63144                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
    64145                                       .ToList();
     
    70151    }
    71152
    72     #endregion
    73 
    74     #region IClassificationModel Members
    75 
    76     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    77       foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) {
    78         // return the class which is most often occuring
    79         yield return
    80           estimatedValuesVector
    81           .GroupBy(x => x)
    82           .OrderBy(g => -g.Count())
    83           .Select(g => g.Key)
    84           .First();
    85       }
     153    private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
     154      return estimatedClassValues
     155      .GroupBy(x => x)
     156      .OrderBy(g => -g.Count())
     157      .Select(g => g.Key)
     158      .First();
    86159    }
    87 
    88     #endregion
    89160  }
    90161}
Note: See TracChangeset for help on using the changeset viewer.