Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
02/27/12 16:11:34 (13 years ago)
Author:
sforsten
Message:

#1776:

  • 2 more strategies have been implemented
  • major changes in the inheritance have been made to make it possible to add strategies which don't use a voting strategy with weights
  • ClassificationEnsembleSolutionEstimatedClassValuesView doesn't currently show the confidence (has been removed for test purpose)
Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
Files:
4 added
1 deleted
5 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r7504 r7531  
    5151    }
    5252
    53     public IEnumerable<double> Weights {
    54       get { return new List<double>(weights); }
    55     }
    56 
    5753    [Storable]
    5854    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     
    6056    private Dictionary<IClassificationModel, IntRange> testPartitions;
    6157
    62     private IEnumerable<double> weights;
    63 
    6458    private IClassificationEnsembleSolutionWeightCalculator weightCalculator;
    6559
     
    6862        if (value != null) {
    6963          weightCalculator = value;
    70           weights = weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     64          weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    7165          if (!ProblemData.IsEmpty)
    7266            RecalculateResults();
     
    169163    #region Evaluation
    170164    public override IEnumerable<double> EstimatedTrainingClassValues {
    171       get {
    172         var rows = ProblemData.TrainingIndizes;
    173         var estimatedValuesEnumerators = (from model in Model.Models
    174                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    175                                          .ToList();
    176         var rowsEnumerator = rows.GetEnumerator();
    177         // aggregate to make sure that MoveNext is called for all enumerators
    178         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    179           int currentRow = rowsEnumerator.Current;
    180 
    181           var selectedEnumerators = from pair in estimatedValuesEnumerators
    182                                     where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    183                                     select pair.EstimatedValuesEnumerator;
    184           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights);
    185         }
    186       }
     165      get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); }
    187166    }
    188167
    189168    public override IEnumerable<double> EstimatedTestClassValues {
    190       get {
    191         var rows = ProblemData.TestIndizes;
    192         var estimatedValuesEnumerators = (from model in Model.Models
    193                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
    194                                          .ToList();
    195         var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
    196         // aggregate to make sure that MoveNext is called for all enumerators
    197         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    198           int currentRow = rowsEnumerator.Current;
    199 
    200           var selectedEnumerators = from pair in estimatedValuesEnumerators
    201                                     where RowIsTestForModel(currentRow, pair.Model)
    202                                     select pair.EstimatedValuesEnumerator;
    203 
    204           yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights);
    205         }
    206       }
     169      get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); }
    207170    }
    208171
     
    218181
    219182    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    220       return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
    221              select AggregateEstimatedClassValues(xs, weights);
     183      return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows);
    222184    }
    223185
     
    232194                     select enumerator.Current;
    233195      }
    234     }
    235 
    236     private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> weights) {
    237       IDictionary<double, double> weightSum = new Dictionary<double, double>();
    238       for (int i = 0; i < estimatedClassValues.Count(); i++) {
    239         if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))
    240           weightSum[estimatedClassValues.ElementAt(i)] = 0.0;
    241         weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);
    242       }
    243       if (weightSum.Count <= 0)
    244         return double.NaN;
    245       var max = weightSum.Max(x => x.Value);
    246       max = weightSum
    247         .Where(x => x.Value.Equals(max))
    248         .Select(x => x.Key)
    249         .First();
    250       return max;
    251       //old code
    252       //return weightSum
    253       //  .Where(x => x.Value.Equals(max))
    254       //  .Select(x => x.Key)
    255       //  .First();
    256       //return estimatedClassValues
    257       //.GroupBy(x => x)
    258       //.OrderBy(g => -g.Count())
    259       //.Select(g => g.Key)
    260       //.DefaultIfEmpty(double.NaN)
    261       //.First();
    262196    }
    263197    #endregion
     
    316250      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    317251      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    318       weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     252      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    319253    }
    320254
     
    324258      trainingPartitions.Remove(solution.Model);
    325259      testPartitions.Remove(solution.Model);
    326       weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     260      weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    327261    }
    328262  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AccuracyWeightCalculator.cs

    r7504 r7531  
    3232  [StorableClass]
    3333  [Item("AccuracyWeightCalculator", "Represents a weight calculator that gives every classification solution a weight based on the accuracy.")]
    34   public class AccuracyWeightCalculator : WeightCalculator {
     34  public class AccuracyWeightCalculator : ClassificationWeightCalculator {
    3535
    3636    public AccuracyWeightCalculator()
    3737      : base() {
    3838    }
    39 
    4039    [StorableConstructor]
    4140    protected AccuracyWeightCalculator(bool deserializing) : base(deserializing) { }
     
    4342      : base(original, cloner) {
    4443    }
    45 
    4644    public override IDeepCloneable Clone(Cloner cloner) {
    4745      return new AccuracyWeightCalculator(this, cloner);
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs

    r7504 r7531  
    3232  [StorableClass]
    3333  [Item("MajorityVoteWeightCalculator", "Represents a weight calculator that gives every classification solution the same weight.")]
    34   public class MajorityVoteWeightCalculator : WeightCalculator {
     34  public class MajorityVoteWeightCalculator : ClassificationWeightCalculator {
    3535
    3636    public MajorityVoteWeightCalculator()
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs

    r7504 r7531  
    3333  [StorableClass]
    3434  [Item("NeighbourhoodWeightCalculator", "")]
    35   public class NeighbourhoodWeightCalculator : WeightCalculator {
     35  public class NeighbourhoodWeightCalculator : DiscriminantClassificationWeightCalculator {
    3636
    3737    public NeighbourhoodWeightCalculator()
     
    4949    }
    5050
    51     protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
    52       if (classificationSolutions.Count <= 0)
    53         return new List<double>();
    54 
    55       if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    56         return Enumerable.Repeat<double>(1, classificationSolutions.Count);
    57 
     51    protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
    5852      List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>();
    5953      List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>();
    60       IDiscriminantFunctionClassificationSolution discriminantSolution;
    61       foreach (var solution in classificationSolutions) {
    62         discriminantSolution = (IDiscriminantFunctionClassificationSolution)solution;
    63         estimatedTrainingValEnumerators.Add(discriminantSolution.EstimatedTrainingValues.ToList());
    64         estimatedTrainingClassValEnumerators.Add(discriminantSolution.EstimatedTrainingClassValues.ToList());
     54      foreach (var solution in discriminantSolutions) {
     55        estimatedTrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList());
     56        estimatedTrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList());
    6557      }
    6658
    67       List<double> weights = Enumerable.Repeat<double>(0, classificationSolutions.Count()).ToList<double>();
     59      List<double> weights = Enumerable.Repeat<double>(0, discriminantSolutions.Count()).ToList<double>();
    6860
    69       IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;
     61      IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData;
    7062      List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList();
    7163      List<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes).ToList();
     
    9385      return weights;
    9486    }
    95 
    96     private IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
    97       return from i in indizes
    98              select targetValues[i];
    99     }
    10087  }
    10188}
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/PointCertaintyWeightCalculator.cs

    r7504 r7531  
    2828
    2929namespace HeuristicLab.Problems.DataAnalysis {
    30   /// <summary>
    31   ///
    32   /// </summary>
    3330  [StorableClass]
    3431  [Item("PointCertaintyWeightCalculator", "")]
    35   public class PointCertaintyWeightCalculator : WeightCalculator {
     32  public class PointCertaintyWeightCalculator : DiscriminantClassificationWeightCalculator {
    3633
    3734    public PointCertaintyWeightCalculator()
    3835      : base() {
    3936    }
    40 
    4137    [StorableConstructor]
    4238    protected PointCertaintyWeightCalculator(bool deserializing) : base(deserializing) { }
     
    4440      : base(original, cloner) {
    4541    }
    46 
    4742    public override IDeepCloneable Clone(Cloner cloner) {
    4843      return new PointCertaintyWeightCalculator(this, cloner);
    4944    }
    5045
    51     protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
    52       if (classificationSolutions.Count <= 0)
    53         return new List<double>();
    54 
    55       if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    56         return Enumerable.Repeat<double>(1, classificationSolutions.Count);
    57 
    58       ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>();
    59       foreach (var solution in classificationSolutions) {
    60         discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution);
    61       }
    62 
     46    protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
    6347      List<double> weights = new List<double>();
    64       IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;
     48      IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData;
    6549      IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes);
    6650      IEnumerator<double> trainingValues;
     
    8771      return weights;
    8872    }
    89 
    90     private IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
    91       return from i in indizes
    92              select targetValues[i];
    93     }
    9473  }
    9574}
Note: See TracChangeset for help on using the changeset viewer.