Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
02/06/12 16:55:27 (13 years ago)
Author:
sforsten
Message:

#1776: first implementation of different voting strategies (currently no gui elements are available to choose between the strategies)

Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
Files:
4 added
1 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r7259 r7459  
    2828using HeuristicLab.Data;
    2929using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     30using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
    3031
    3132namespace HeuristicLab.Problems.DataAnalysis {
     
    5556    private Dictionary<IClassificationModel, IntRange> testPartitions;
    5657
     58    private IClassificationEnsembleSolutionWeightCalculator weightCalculator;
     59
    5760    [StorableConstructor]
    5861    private ClassificationEnsembleSolution(bool deserializing)
     
    9598      classificationSolutions = new ItemCollection<IClassificationSolution>();
    9699
     100      weightCalculator = new AccuracyWeightCalculator();
     101
    97102      RegisterClassificationSolutionsEventHandler();
    98103    }
     
    153158                                         .ToList();
    154159        var rowsEnumerator = rows.GetEnumerator();
     160        IEnumerable<double> weights = weightCalculator.CalculateWeights(classificationSolutions);
    155161        // aggregate to make sure that MoveNext is called for all enumerators
    156162        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     
    160166                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    161167                                    select pair.EstimatedValuesEnumerator;
    162           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     168          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights);
    163169        }
    164170      }
     
    172178                                         .ToList();
    173179        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     180        IEnumerable<double> weights = weightCalculator.CalculateWeights(classificationSolutions);
    174181        // aggregate to make sure that MoveNext is called for all enumerators
    175182        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     
    180187                                    select pair.EstimatedValuesEnumerator;
    181188
    182           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     189          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights);
    183190        }
    184191      }
     
    196203
    197204    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     205      IEnumerable<double> weights = weightCalculator.CalculateWeights(classificationSolutions);
    198206      return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
    199              select AggregateEstimatedClassValues(xs);
     207             select AggregateEstimatedClassValues(xs, weights);
    200208    }
    201209
     
    212220    }
    213221
    214     private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
    215       return estimatedClassValues
    216       .GroupBy(x => x)
    217       .OrderBy(g => -g.Count())
    218       .Select(g => g.Key)
    219       .DefaultIfEmpty(double.NaN)
    220       .First();
     222    private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> weights) {
     223      IDictionary<double, double> weightSum = new Dictionary<double, double>();
     224      for (int i = 0; i < estimatedClassValues.Count(); i++) {
     225        if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))
     226          weightSum[estimatedClassValues.ElementAt(i)] = 0.0;
     227        weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);
     228      }
     229      if (weightSum.Count <= 0)
     230        return double.NaN;
     231      var max = weightSum.Max(x => x.Value);
     232      max = weightSum
     233        .Where(x => x.Value.Equals(max))
     234        .Select(x => x.Key)
     235        .First();
     236      return max;
     237      //old code
     238      //return weightSum
     239      //  .Where(x => x.Value.Equals(max))
     240      //  .Select(x => x.Key)
     241      //  .First();
     242      //return estimatedClassValues
     243      //.GroupBy(x => x)
     244      //.OrderBy(g => -g.Count())
     245      //.Select(g => g.Key)
     246      //.DefaultIfEmpty(double.NaN)
     247      //.First();
    221248    }
    222249    #endregion
Note: See TracChangeset for help on using the changeset viewer.