Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
02/27/12 16:11:34 (12 years ago)
Author:
sforsten
Message:

#1776:

  • 2 more strategies have been implemented
  • major changes to the inheritance hierarchy have been made so that strategies which don't use a weighted voting strategy can be added
  • ClassificationEnsembleSolutionEstimatedClassValuesView currently doesn't show the confidence (it has been removed for testing purposes)
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r7504 r7531  
    5151    }
    5252
    53     public IEnumerable<double> Weights {
    54       get { return new List<double>(weights); }
    55     }
    56 
    5753    [Storable]
    5854    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     
    6056    private Dictionary<IClassificationModel, IntRange> testPartitions;
    6157
    62     private IEnumerable<double> weights;
    63 
    6458    private IClassificationEnsembleSolutionWeightCalculator weightCalculator;
    6559
     
    6862        if (value != null) {
    6963          weightCalculator = value;
    70           weights = weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     64          weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    7165          if (!ProblemData.IsEmpty)
    7266            RecalculateResults();
     
    169163    #region Evaluation
    170164    public override IEnumerable<double> EstimatedTrainingClassValues {
    171       get {
    172         var rows = ProblemData.TrainingIndizes;
    173         var estimatedValuesEnumerators = (from model in Model.Models
    174                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    175                                          .ToList();
    176         var rowsEnumerator = rows.GetEnumerator();
    177         // aggregate to make sure that MoveNext is called for all enumerators
    178         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    179           int currentRow = rowsEnumerator.Current;
    180 
    181           var selectedEnumerators = from pair in estimatedValuesEnumerators
    182                                     where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    183                                     select pair.EstimatedValuesEnumerator;
    184           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights);
    185         }
    186       }
     165      get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); }
    187166    }
    188167
    189168    public override IEnumerable<double> EstimatedTestClassValues {
    190       get {
    191         var rows = ProblemData.TestIndizes;
    192         var estimatedValuesEnumerators = (from model in Model.Models
    193                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    194                                          .ToList();
    195         var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
    196         // aggregate to make sure that MoveNext is called for all enumerators
    197         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    198           int currentRow = rowsEnumerator.Current;
    199 
    200           var selectedEnumerators = from pair in estimatedValuesEnumerators
    201                                     where RowIsTestForModel(currentRow, pair.Model)
    202                                     select pair.EstimatedValuesEnumerator;
    203 
    204           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights);
    205         }
    206       }
     169      get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); }
    207170    }
    208171
     
    218181
    219182    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    220       return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
    221              select AggregateEstimatedClassValues(xs, weights);
     183      return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows);
    222184    }
    223185
     
    232194                     select enumerator.Current;
    233195      }
    234     }
    235 
    236     private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> weights) {
    237       IDictionary<double, double> weightSum = new Dictionary<double, double>();
    238       for (int i = 0; i < estimatedClassValues.Count(); i++) {
    239         if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))
    240           weightSum[estimatedClassValues.ElementAt(i)] = 0.0;
    241         weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);
    242       }
    243       if (weightSum.Count <= 0)
    244         return double.NaN;
    245       var max = weightSum.Max(x => x.Value);
    246       max = weightSum
    247         .Where(x => x.Value.Equals(max))
    248         .Select(x => x.Key)
    249         .First();
    250       return max;
    251       //old code
    252       //return weightSum
    253       //  .Where(x => x.Value.Equals(max))
    254       //  .Select(x => x.Key)
    255       //  .First();
    256       //return estimatedClassValues
    257       //.GroupBy(x => x)
    258       //.OrderBy(g => -g.Count())
    259       //.Select(g => g.Key)
    260       //.DefaultIfEmpty(double.NaN)
    261       //.First();
    262196    }
    263197    #endregion
     
    316250      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    317251      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    318       weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     252      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    319253    }
    320254
     
    324258      trainingPartitions.Remove(solution.Model);
    325259      testPartitions.Remove(solution.Model);
    326       weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     260      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    327261    }
    328262  }
Note: See TracChangeset for help on using the changeset viewer.