Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/05/12 17:02:37 (13 years ago)
Author:
sforsten
Message:

#1776:

  • models can be selected with a check box
  • all strategies are now finished
  • major changes have been made to provide the same behaviour when getting the estimated training or test values of an ensemble
Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
Files:
9 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r7531 r7549  
    4646    }
    4747
    48     private readonly ItemCollection<IClassificationSolution> classificationSolutions;
    49     public IItemCollection<IClassificationSolution> ClassificationSolutions {
     48    private readonly CheckedItemCollection<IClassificationSolution> classificationSolutions;
     49    public ICheckedItemCollection<IClassificationSolution> ClassificationSolutions {
    5050      get { return classificationSolutions; }
    5151    }
    5252
    53     [Storable]
    54     private Dictionary<IClassificationModel, IntRange> trainingPartitions;
    55     [Storable]
    56     private Dictionary<IClassificationModel, IntRange> testPartitions;
     53    //[Storable]
     54    //private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     55    //[Storable]
     56    //private Dictionary<IClassificationModel, IntRange> testPartitions;
    5757
    5858    private IClassificationEnsembleSolutionWeightCalculator weightCalculator;
     
    6262        if (value != null) {
    6363          weightCalculator = value;
    64           weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    6564          if (!ProblemData.IsEmpty)
    6665            RecalculateResults();
     
    7271    private ClassificationEnsembleSolution(bool deserializing)
    7372      : base(deserializing) {
    74       classificationSolutions = new ItemCollection<IClassificationSolution>();
     73      classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
    7574    }
    7675    [StorableHook(HookType.AfterDeserialization)]
     
    7877      foreach (var model in Model.Models) {
    7978        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
    80         problemData.TrainingPartition.Start = trainingPartitions[model].Start;
    81         problemData.TrainingPartition.End = trainingPartitions[model].End;
    82         problemData.TestPartition.Start = testPartitions[model].Start;
    83         problemData.TestPartition.End = testPartitions[model].End;
     79        //problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     80        //problemData.TrainingPartition.End = trainingPartitions[model].End;
     81        //problemData.TestPartition.Start = testPartitions[model].Start;
     82        //problemData.TestPartition.End = testPartitions[model].End;
    8483
    8584        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
     
    9089    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    9190      : base(original, cloner) {
    92       trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    93       testPartitions = new Dictionary<IClassificationModel, IntRange>();
    94       foreach (var pair in original.trainingPartitions) {
    95         trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
    96       }
    97       foreach (var pair in original.testPartitions) {
    98         testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
    99       }
     91      //trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     92      //testPartitions = new Dictionary<IClassificationModel, IntRange>();
     93      //foreach (var pair in original.trainingPartitions) {
     94      //  trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     95      //}
     96      //foreach (var pair in original.testPartitions) {
     97      //  testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     98      //}
    10099
    101100      classificationSolutions = cloner.Clone(original.classificationSolutions);
     
    105104    public ClassificationEnsembleSolution()
    106105      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
    107       trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    108       testPartitions = new Dictionary<IClassificationModel, IntRange>();
    109       classificationSolutions = new ItemCollection<IClassificationSolution>();
     106      //trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     107      //testPartitions = new Dictionary<IClassificationModel, IntRange>();
     108      classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
    110109
    111110      weightCalculator = new MajorityVoteWeightCalculator();
     
    122121    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    123122      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
    124       this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    125       this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
    126       this.classificationSolutions = new ItemCollection<IClassificationSolution>();
     123      //this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     124      //this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
     125      this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
    127126
    128127      List<IClassificationSolution> solutions = new List<IClassificationSolution>();
     
    155154      classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
    156155      classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
     156      classificationSolutions.CheckedItemsChanged += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CheckedItemsChanged);
    157157    }
    158158
    159159    protected override void RecalculateResults() {
     160      weightCalculator.CalculateNormalizedWeights(classificationSolutions.CheckedItems);
    160161      CalculateResults();
    161162    }
     
    163164    #region Evaluation
    164165    public override IEnumerable<double> EstimatedTrainingClassValues {
    165       get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); }
     166      get {
     167        return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
     168                                                              ProblemData.Dataset,
     169                                                              ProblemData.TrainingIndizes,
     170                                                              weightCalculator.GetTrainingClassDelegate());
     171      }
    166172    }
    167173
    168174    public override IEnumerable<double> EstimatedTestClassValues {
    169       get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); }
    170     }
    171 
    172     private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) {
    173       return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
    174               (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
    175     }
    176 
    177     private bool RowIsTestForModel(int currentRow, IClassificationModel model) {
    178       return testPartitions == null || !testPartitions.ContainsKey(model) ||
    179               (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
     175      get {
     176        return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
     177                                                              ProblemData.Dataset,
     178                                                              ProblemData.TestIndizes,
     179                                                              weightCalculator.GetTestClassDelegate());
     180      }
    180181    }
    181182
    182183    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    183       return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows);
     184      return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems,
     185                                                            ProblemData.Dataset,
     186                                                            rows,
     187                                                            weightCalculator.GetAllClassDelegate());
    184188    }
    185189
    186190    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    187       if (!Model.Models.Any()) yield break;
    188       var estimatedValuesEnumerators = (from model in Model.Models
     191      IEnumerable<IClassificationModel> models = classificationSolutions.CheckedItems.Select(sol => sol.Model);
     192      if (!models.Any()) yield break;
     193      var estimatedValuesEnumerators = (from model in models
    189194                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
    190195                                       .ToList();
     
    212217          solution.ProblemData = problemData;
    213218      }
    214       foreach (var trainingPartition in trainingPartitions.Values) {
    215         trainingPartition.Start = ProblemData.TrainingPartition.Start;
    216         trainingPartition.End = ProblemData.TrainingPartition.End;
    217       }
    218       foreach (var testPartition in testPartitions.Values) {
    219         testPartition.Start = ProblemData.TestPartition.Start;
    220         testPartition.End = ProblemData.TestPartition.End;
    221       }
     219      //foreach (var trainingPartition in trainingPartitions.Values) {
     220      //  trainingPartition.Start = ProblemData.TrainingPartition.Start;
     221      //  trainingPartition.End = ProblemData.TrainingPartition.End;
     222      //}
     223      //foreach (var testPartition in testPartitions.Values) {
     224      //  testPartition.Start = ProblemData.TestPartition.Start;
     225      //  testPartition.End = ProblemData.TestPartition.End;
     226      //}
    222227
    223228      base.OnProblemDataChanged();
     
    244249      RecalculateResults();
    245250    }
     251    private void classificationSolutions_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     252      RecalculateResults();
     253    }
    246254
    247255    private void AddClassificationSolution(IClassificationSolution solution) {
    248256      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
    249257      Model.Add(solution.Model);
    250       trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    251       testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    252       weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     258      //trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     259      //testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    253260    }
    254261
     
    256263      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
    257264      Model.Remove(solution.Model);
    258       trainingPartitions.Remove(solution.Model);
    259       testPartitions.Remove(solution.Model);
    260       weightCalculator.CalculateNormalizedWeights(classificationSolutions);
     265      //trainingPartitions.Remove(solution.Model);
     266      //testPartitions.Remove(solution.Model);
    261267    }
    262268  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AccuracyWeightCalculator.cs

    r7531 r7549  
    4646    }
    4747
    48     protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     48    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
    4949      return classificationSolutions.Select(s => s.TrainingAccuracy);
    5050    }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs

    r7531 r7549  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    2625using HeuristicLab.Core;
     26using HeuristicLab.Data;
    2727using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2828using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
     
    4545    }
    4646
    47     private IEnumerable<double> weights;
     47    private IDictionary<IClassificationSolution, double> weights;
    4848
    4949    /// <summary>
     
    5252    /// <param name="classificationSolutions"></param>
    5353    /// <returns>weights which are equal or bigger than zero</returns>
    54     public void CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     54    public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
    5555      List<double> weights = new List<double>();
    56       if (classificationSolutions.Count > 0) {
     56      if (classificationSolutions.Count() > 0) {
    5757        foreach (var weight in CalculateWeights(classificationSolutions)) {
    5858          weights.Add(weight >= 0 ? weight : 0);
    5959        }
    6060      }
    61       this.weights = weights.Select(x => x / weights.Sum());
     61      double sum = weights.Sum();
     62      this.weights = classificationSolutions.Zip(weights, (sol, wei) => new { sol, wei }).ToDictionary(x => x.sol, x => x.wei / sum);
    6263    }
    6364
    64     protected abstract IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions);
     65    protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);
    6566
    66     public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    67       return from xs in ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows)
     67    #region delegate CheckPoint
     68    public CheckPoint GetTestClassDelegate() {
     69      return PointInTest;
     70    }
     71    public CheckPoint GetTrainingClassDelegate() {
     72      return PointInTraining;
     73    }
     74    public CheckPoint GetAllClassDelegate() {
     75      return AllPoints;
     76    }
     77    #endregion
     78
     79    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     80      return from xs in GetEstimatedClassValues(solutions, dataset, rows, handler)
    6881             select AggregateEstimatedClassValues(xs);
    6982    }
    7083
    71     protected double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
    72       if (!estimatedClassValues.Count().Equals(weights.Count()))
    73         throw new ArgumentException("'estimatedClassValues' has " + estimatedClassValues.Count() + " elements, while 'weights' has" + weights.Count());
     84    protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
    7485      IDictionary<double, double> weightSum = new Dictionary<double, double>();
    75       for (int i = 0; i < estimatedClassValues.Count(); i++) {
    76         if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))
    77           weightSum[estimatedClassValues.ElementAt(i)] = 0.0;
    78         weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);
     86      foreach (var item in estimatedClassValues) {
     87        if (!weightSum.ContainsKey(item.Value))
     88          weightSum[item.Value] = 0.0;
     89        weightSum[item.Value] += weights[item.Key];
    7990      }
    8091      if (weightSum.Count <= 0)
     
    8899    }
    89100
    90     protected static IEnumerable<IEnumerable<double>> GetEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    91       if (!models.Any()) yield break;
    92       var estimatedValuesEnumerators = (from model in models
    93                                         select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
     101    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     102      if (!solutions.Any()) yield break;
     103      var estimatedValuesEnumerators = (from solution in solutions
     104                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
    94105                                       .ToList();
    95106
    96       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    97         yield return from enumerator in estimatedValuesEnumerators
    98                      select enumerator.Current;
     107      var rowEnumerator = rows.GetEnumerator();
     108      while (rowEnumerator.MoveNext() & estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
     109        yield return (from enumerator in estimatedValuesEnumerators
     110                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
     111                      select enumerator)
     112                     .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
    99113      }
    100114    }
     
    105119             select targetValues[i];
    106120    }
     121    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
     122      IntRange trainingPartition = problemData.TrainingPartition;
     123      IntRange testPartition = problemData.TestPartition;
     124      return (trainingPartition.Start <= point && point < trainingPartition.End)
     125        && !(testPartition.Start <= point && point < testPartition.End);
     126    }
     127    protected bool PointInTest(IClassificationProblemData problemData, int point) {
     128      IntRange testPartition = problemData.TestPartition;
     129      return testPartition.Start <= point && point < testPartition.End;
     130    }
     131    protected bool AllPoints(IClassificationProblemData problemData, int point) {
     132      return true;
     133    }
    107134    #endregion
    108135  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ContinuousPointCertaintyWeightCalculator.cs

    r7531 r7549  
    4949    }
    5050
    51     protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
     51    protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
    5252      List<double> weights = new List<double>();
    5353      IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData;
    54       IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes);
     54      IEnumerable<double> targetValues;
    5555      IEnumerator<double> trainingValues;
    5656
    5757      //only works for binary classification
    5858      if (!problemData.ClassValues.Count().Equals(2))
    59         return Enumerable.Repeat<double>(1, discriminantSolutions.Count);
     59        return Enumerable.Repeat<double>(1, discriminantSolutions.Count());
    6060
    6161      double maxClass = problemData.ClassValues.Max();
     
    6464
    6565      foreach (var solution in discriminantSolutions) {
     66        problemData = solution.ProblemData;
     67        targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes);
     68        trainingValues = targetValues.GetEnumerator();
     69
    6670        IEnumerator<double> estimatedTrainingVal = solution.EstimatedTrainingValues.GetEnumerator();
    6771        IEnumerator<double> estimatedTrainingClassVal = solution.EstimatedTrainingClassValues.GetEnumerator();
    6872
    69         trainingValues = targetValues.GetEnumerator();
    7073        double curWeight = 0.0;
    7174        while (estimatedTrainingVal.MoveNext() && estimatedTrainingClassVal.MoveNext() && trainingValues.MoveNext()) {
    72           //if (estimatedTrainingClassVal.Current.Equals(trainingValues.Current)) {
    7375          if (trainingValues.Current.Equals(maxClass)) {
    7476            if (estimatedTrainingVal.Current >= maxClass)
     
    8688            }
    8789          }
    88           //}
    8990        }
    90         weights.Add(curWeight);
     91        // normalize the weight by the number of training target values (otherwise a model with a bigger training partition would probably receive a larger weight)
     92        weights.Add(curWeight / targetValues.Count());
    9193      }
    9294      return weights;
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs

    r7531 r7549  
    2323using System.Linq;
    2424using HeuristicLab.Common;
    25 using HeuristicLab.Core;
    2625using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     26using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
    2727
    2828namespace HeuristicLab.Problems.DataAnalysis {
     
    4141    }
    4242
    43     protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     43    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
    4444      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    45         return Enumerable.Repeat<double>(1.0, classificationSolutions.Count);
     45        return Enumerable.Repeat<double>(1.0, classificationSolutions.Count());
    4646
    47       ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>();
    48       foreach (var solution in classificationSolutions) {
    49         discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution);
    50       }
     47      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();
    5148
    5249      return DiscriminantCalculateWeights(discriminantSolutions);
    5350    }
    5451
    55     protected abstract IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions);
     52    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);
    5653
    57     public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    58       if (!models.All(x => x is IDiscriminantFunctionClassificationModel))
     54    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     55      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    5956        return Enumerable.Repeat<double>(0.0, rows.Count());
    6057
    61       IEnumerable<IDiscriminantFunctionClassificationModel> discriminantModels = models.Cast<IDiscriminantFunctionClassificationModel>();
     58      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
    6259
    63       IEnumerable<IEnumerable<double>> estimatedClassValues = ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows);
    64       IEnumerable<IEnumerable<double>> estimatedValues = DiscriminantClassificationWeightCalculator.GetEstimatedValues(discriminantModels, dataset, rows);
     60      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
     61      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);
    6562
    6663      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
     
    6865    }
    6966
    70     protected virtual double DiscriminantAggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {
     67    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
    7168      return AggregateEstimatedClassValues(estimatedClassValues);
    7269    }
    7370
    74     protected static IEnumerable<IEnumerable<double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {
    75       if (!models.Any()) yield break;
    76       var estimatedValuesEnumerators = (from model in models
    77                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    78                                        .ToList();
     71    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     72      if (!solutions.Any()) yield break;
     73      var estimatedValuesEnumerators = (from solution in solutions
     74                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
     75                                        .ToList();
    7976
    80       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    81         yield return from enumerator in estimatedValuesEnumerators
    82                      select enumerator.Current;
     77      var rowEnumerator = rows.GetEnumerator();
     78      while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
     79        yield return (from enumerator in estimatedValuesEnumerators
     80                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
     81                      select enumerator)
     82                      .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
    8383      }
    8484    }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs

    r7531 r7549  
    4848    }
    4949
    50     protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
    51       return Enumerable.Repeat<double>(1, classificationSolutions.Count);
     50    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
     51      return Enumerable.Repeat<double>(1, classificationSolutions.Count());
    5252    }
    5353  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs

    r7531 r7549  
    4646    protected double[] classValues;
    4747
    48     protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
    49       List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>();
    50       List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>();
     48    /// <summary>
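     49    /// For each dataset row that is a training point of at least one solution, computes the median of those solutions' estimated values, then calculates the class values and thresholds from these medians and the matching target values.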
     50    /// </summary>
     51    /// <param name="discriminantSolutions"></param>
     52    /// <returns>the per-row medians rather than weights, because this calculator does not use any weights</returns>
     53    protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
     54      List<List<double>> estimatedValues = new List<List<double>>();
     55      List<List<double>> estimatedClassValues = new List<List<double>>();
     56
     57      List<IClassificationProblemData> solutionProblemData = discriminantSolutions.Select(sol => sol.ProblemData).ToList();
     58      Dataset dataSet = solutionProblemData[0].Dataset;
     59      IEnumerable<int> rows = Enumerable.Range(0, dataSet.Rows);
    5160      foreach (var solution in discriminantSolutions) {
    52         estimatedTrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList());
    53         estimatedTrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList());
     61        estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
     62        estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
    5463      }
    5564
    5665      List<double> median = new List<double>();
    57 
    58       IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData;
    59       List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList();
    60       IEnumerable<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes);
    61 
    62       for (int i = 0; i < estimatedTrainingClassValEnumerators.First().Count; i++) {
    63         var points = (from solution in estimatedTrainingValEnumerators
    64                       select solution[i])
    65                       .OrderBy(p => p)
    66                       .ToList();
    67 
    68         median.Add(GetMedian(points));
     66      List<double> targetValues = dataSet.GetDoubleValues(solutionProblemData[0].TargetVariable).ToList();
     67      IList<double> curTrainingpoints = new List<double>();
     68      int removed = 0;
     69      int count = targetValues.Count;
     70      for (int point = 0; point < count; point++) {
     71        curTrainingpoints.Clear();
     72        for (int solutionPos = 0; solutionPos < solutionProblemData.Count; solutionPos++) {
     73          if (PointInTraining(solutionProblemData[solutionPos], point)) {
     74            curTrainingpoints.Add(estimatedValues[solutionPos][point]);
     75          }
     76        }
     77        if (curTrainingpoints.Count > 0)
     78          median.Add(GetMedian(curTrainingpoints.OrderBy(p => p).ToList()));
     79        else {
     80          // remove unused points
     81          targetValues.RemoveAt(point - removed);
     82          removed++;
     83        }
    6984      }
    70       AccuracyMaximizationThresholdCalculator.CalculateThresholds(problemData, median, trainingVal, out classValues, out threshold);
     85      AccuracyMaximizationThresholdCalculator.CalculateThresholds(solutionProblemData[0], median, targetValues, out classValues, out threshold);
    7186      return median;
    7287    }
    7388
    74     protected override double DiscriminantAggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {
     89    protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
    7590      double classValue = classValues.First();
    76       double median = GetMedian(estimatedValues.ToList());
     91      IList<double> values = estimatedValues.Select(x => x.Value).ToList();
     92      if (values.Count <= 0)
     93        return double.NaN;
     94      double median = GetMedian(values);
    7795      for (int i = 0; i < classValues.Count(); i++) {
    7896        if (median > threshold[i])
     
    87105      int count = estimatedValues.Count;
    88106      if (count % 2 == 0)
    89         return 0.5 * (estimatedValues[count / 2] + estimatedValues[count / 2 + 1]);
     107        return 0.5 * (estimatedValues[count / 2 - 1] + estimatedValues[count / 2]);
    90108      else
    91         return estimatedValues[(count + 1) / 2];
     109        return estimatedValues[count / 2];
    92110    }
    93111  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs

    r7531 r7549  
    4949    }
    5050
    51     protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
    52       List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>();
    53       List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>();
     51    protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
     52      List<List<double>> estimatedValues = new List<List<double>>();
     53      List<List<double>> estimatedClassValues = new List<List<double>>();
     54
     55      List<IClassificationProblemData> solutionProblemData = discriminantSolutions.Select(sol => sol.ProblemData).ToList();
     56      Dataset dataSet = solutionProblemData[0].Dataset;
     57      IEnumerable<int> rows = Enumerable.Range(0, dataSet.Rows);
    5458      foreach (var solution in discriminantSolutions) {
    55         estimatedTrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList());
    56         estimatedTrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList());
     59        estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
     60        estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
    5761      }
    5862
    59       List<double> weights = Enumerable.Repeat<double>(0, discriminantSolutions.Count()).ToList<double>();
    60 
    61       IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData;
    62       List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList();
    63       List<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes).ToList();
     63      List<double> weights = Enumerable.Repeat<double>(0, solutionProblemData.Count).ToList<double>();
     64      List<double> targetValues = dataSet.GetDoubleValues(solutionProblemData[0].TargetVariable).ToList();
    6465
    6566      double pointAvg, help;
    6667      int count;
    67       for (int point = 0; point < estimatedTrainingClassValEnumerators.First().Count; point++) {
     68      for (int point = 0; point < targetValues.Count; point++) {
    6869        pointAvg = 0.0;
    6970        count = 0;
    70         for (int solution = 0; solution < estimatedTrainingClassValEnumerators.Count; solution++) {
    71           if (estimatedTrainingClassValEnumerators[solution][point].Equals(targetValues[point])) {
    72             pointAvg += estimatedTrainingValEnumerators[solution][point];
     71        for (int solutionPos = 0; solutionPos < estimatedClassValues.Count; solutionPos++) {
     72          if (PointInTraining(solutionProblemData[solutionPos], point)
     73              && estimatedClassValues[solutionPos][point].Equals(targetValues[point])) {
     74            pointAvg += estimatedValues[solutionPos][point];
    7375            count++;
    7476          }
    7577        }
    7678        pointAvg /= (double)count;
    77         for (int solution = 0; solution < estimatedTrainingClassValEnumerators.Count; solution++) {
    78           if (estimatedTrainingClassValEnumerators[solution][point].Equals(targetValues[point])) {
    79             weights[solution] += 0.5;
    80             help = Math.Abs(estimatedTrainingValEnumerators[solution][point] - 0.5);
    81             weights[solution] += help < 0.5 ? 0.5 - help : 0.0;
     79        for (int solutionPos = 0; solutionPos < estimatedClassValues.Count; solutionPos++) {
     80          if (PointInTraining(solutionProblemData[solutionPos], point)
     81              && estimatedClassValues[solutionPos][point].Equals(targetValues[point])) {
     82            weights[solutionPos] += 0.5;
     83            help = Math.Abs(estimatedValues[solutionPos][point] - 0.5);
     84            weights[solutionPos] += help < 0.5 ? 0.5 - help : 0.0;
    8285          }
    8386        }
     87      }
     88      // normalize the weight by the size of the training partition (otherwise a model with a larger training partition would probably get a higher weight)
     89      for (int i = 0; i < weights.Count; i++) {
     90        weights[i] = weights[i] / solutionProblemData[i].TrainingIndizes.Count();
    8491      }
    8592      return weights;
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/PointCertaintyWeightCalculator.cs

    r7531 r7549  
    4444    }
    4545
    46     protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
     46    protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
    4747      List<double> weights = new List<double>();
    4848      IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData;
    49       IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes);
    50       IEnumerator<double> trainingValues;
     49      // class Values are the same in all problem data sets
    5150      double avg = problemData.ClassValues.Average();
    5251
     52      IEnumerable<double> targetValues;
     53      IEnumerator<double> trainingValues;
     54
    5355      foreach (var solution in discriminantSolutions) {
     56        problemData = solution.ProblemData;
     57        targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes);
     58        trainingValues = targetValues.GetEnumerator();
     59
    5460        IEnumerator<double> estimatedTrainingVal = solution.EstimatedTrainingValues.GetEnumerator();
    5561        IEnumerator<double> estimatedTrainingClassVal = solution.EstimatedTrainingClassValues.GetEnumerator();
    5662
    57         trainingValues = targetValues.GetEnumerator();
    5863        double curWeight = 0.0;
    5964        while (estimatedTrainingVal.MoveNext() && estimatedTrainingClassVal.MoveNext() && trainingValues.MoveNext()) {
     
    6772          }
    6873        }
    69         weights.Add(curWeight);
     74        // normalize the weight by the number of target values (otherwise a model with a larger training partition would probably get a higher weight)
     75        weights.Add(curWeight / targetValues.Count());
    7076      }
    7177      return weights;
Note: See TracChangeset for help on using the changeset viewer.