Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/05/12 17:02:37 (12 years ago)
Author:
sforsten
Message:

#1776:

  • models can be selected with a check box
  • all strategies are now finished
  • major changes have been made to provide the same behaviour when getting the estimated training or test values of an ensemble
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs

    r7531 r7549  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    2625using HeuristicLab.Core;
     26using HeuristicLab.Data;
    2727using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2828using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
     
    4545    }
    4646
    47     private IEnumerable<double> weights;
     47    private IDictionary<IClassificationSolution, double> weights;
    4848
    4949    /// <summary>
     
    5252    /// <param name="classificationSolutions"></param>
    5353    /// <returns>weights which are equal or bigger than zero</returns>
    54     public void CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     54    public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
    5555      List<double> weights = new List<double>();
    56       if (classificationSolutions.Count > 0) {
     56      if (classificationSolutions.Count() > 0) {
    5757        foreach (var weight in CalculateWeights(classificationSolutions)) {
    5858          weights.Add(weight >= 0 ? weight : 0);
    5959        }
    6060      }
    61       this.weights = weights.Select(x => x / weights.Sum());
     61      double sum = weights.Sum();
     62      this.weights = classificationSolutions.Zip(weights, (sol, wei) => new { sol, wei }).ToDictionary(x => x.sol, x => x.wei / sum);
    6263    }
    6364
    64     protected abstract IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions);
     65    protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);
    6566
    66     public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    67       return from xs in ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows)
     67    #region delegate CheckPoint
     68    public CheckPoint GetTestClassDelegate() {
     69      return PointInTest;
     70    }
     71    public CheckPoint GetTrainingClassDelegate() {
     72      return PointInTraining;
     73    }
     74    public CheckPoint GetAllClassDelegate() {
     75      return AllPoints;
     76    }
     77    #endregion
     78
     79    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     80      return from xs in GetEstimatedClassValues(solutions, dataset, rows, handler)
    6881             select AggregateEstimatedClassValues(xs);
    6982    }
    7083
    71     protected double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
    72       if (!estimatedClassValues.Count().Equals(weights.Count()))
    73         throw new ArgumentException("'estimatedClassValues' has " + estimatedClassValues.Count() + " elements, while 'weights' has" + weights.Count());
     84    protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
    7485      IDictionary<double, double> weightSum = new Dictionary<double, double>();
    75       for (int i = 0; i < estimatedClassValues.Count(); i++) {
    76         if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))
    77           weightSum[estimatedClassValues.ElementAt(i)] = 0.0;
    78         weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);
     86      foreach (var item in estimatedClassValues) {
     87        if (!weightSum.ContainsKey(item.Value))
     88          weightSum[item.Value] = 0.0;
     89        weightSum[item.Value] += weights[item.Key];
    7990      }
    8091      if (weightSum.Count <= 0)
     
    8899    }
    89100
    90     protected static IEnumerable<IEnumerable<double>> GetEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    91       if (!models.Any()) yield break;
    92       var estimatedValuesEnumerators = (from model in models
    93                                         select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
     101    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     102      if (!solutions.Any()) yield break;
     103      var estimatedValuesEnumerators = (from solution in solutions
     104                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
    94105                                       .ToList();
    95106
    96       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    97         yield return from enumerator in estimatedValuesEnumerators
    98                      select enumerator.Current;
     107      var rowEnumerator = rows.GetEnumerator();
     108      while (rowEnumerator.MoveNext() & estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
     109        yield return (from enumerator in estimatedValuesEnumerators
     110                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
     111                      select enumerator)
     112                     .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
    99113      }
    100114    }
     
    105119             select targetValues[i];
    106120    }
     121    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
     122      IntRange trainingPartition = problemData.TrainingPartition;
     123      IntRange testPartition = problemData.TestPartition;
     124      return (trainingPartition.Start <= point && point < trainingPartition.End)
     125        && !(testPartition.Start <= point && point < testPartition.End);
     126    }
     127    protected bool PointInTest(IClassificationProblemData problemData, int point) {
     128      IntRange testPartition = problemData.TestPartition;
     129      return testPartition.Start <= point && point < testPartition.End;
     130    }
     131    protected bool AllPoints(IClassificationProblemData problemData, int point) {
     132      return true;
     133    }
    107134    #endregion
    108135  }
Note: See TracChangeset for help on using the changeset viewer.