Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/05/12 17:02:37 (12 years ago)
Author:
sforsten
Message:

#1776:

  • models can be selected with a check box
  • all strategies are now finished
  • major changes have been made to provide the same behaviour when getting the estimated training or test values of an ensemble
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs

    r7531 r7549  
    2323using System.Linq;
    2424using HeuristicLab.Common;
    25 using HeuristicLab.Core;
    2625using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     26using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
    2727
    2828namespace HeuristicLab.Problems.DataAnalysis {
     
    4141    }
    4242
    43     protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     43    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
    4444      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    45         return Enumerable.Repeat<double>(1.0, classificationSolutions.Count);
     45        return Enumerable.Repeat<double>(1.0, classificationSolutions.Count());
    4646
    47       ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>();
    48       foreach (var solution in classificationSolutions) {
    49         discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution);
    50       }
     47      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();
    5148
    5249      return DiscriminantCalculateWeights(discriminantSolutions);
    5350    }
    5451
    55     protected abstract IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions);
     52    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);
    5653
    57     public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    58       if (!models.All(x => x is IDiscriminantFunctionClassificationModel))
     54    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     55      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    5956        return Enumerable.Repeat<double>(0.0, rows.Count());
    6057
    61       IEnumerable<IDiscriminantFunctionClassificationModel> discriminantModels = models.Cast<IDiscriminantFunctionClassificationModel>();
     58      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
    6259
    63       IEnumerable<IEnumerable<double>> estimatedClassValues = ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows);
    64       IEnumerable<IEnumerable<double>> estimatedValues = DiscriminantClassificationWeightCalculator.GetEstimatedValues(discriminantModels, dataset, rows);
     60      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
     61      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);
    6562
    6663      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
     
    6865    }
    6966
    70     protected virtual double DiscriminantAggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {
     67    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
    7168      return AggregateEstimatedClassValues(estimatedClassValues);
    7269    }
    7370
    74     protected static IEnumerable<IEnumerable<double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    75       if (!models.Any()) yield break;
    76       var estimatedValuesEnumerators = (from model in models
    77                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    78                                        .ToList();
     71    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     72      if (!solutions.Any()) yield break;
     73      var estimatedValuesEnumerators = (from solution in solutions
     74                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
     75                                        .ToList();
    7976
    80       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    81         yield return from enumerator in estimatedValuesEnumerators
    82                      select enumerator.Current;
     77      var rowEnumerator = rows.GetEnumerator();
     78      while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
     79        yield return (from enumerator in estimatedValuesEnumerators
     80                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
     81                      select enumerator)
     82                      .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
    8383      }
    8484    }
Note: See TracChangeset for help on using the changeset viewer.