Changeset 7549 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
- Timestamp:
- 03/05/12 17:02:37 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
r7531 r7549 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 25 using HeuristicLab.Core; 26 using HeuristicLab.Data; 27 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 28 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification; … … 45 45 } 46 46 47 private I Enumerable<double> weights;47 private IDictionary<IClassificationSolution, double> weights; 48 48 49 49 /// <summary> … … 52 52 /// <param name="classificationSolutions"></param> 53 53 /// <returns>weights which are equal or bigger than zero</returns> 54 public void CalculateNormalizedWeights(I temCollection<IClassificationSolution> classificationSolutions) {54 public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) { 55 55 List<double> weights = new List<double>(); 56 if (classificationSolutions.Count > 0) {56 if (classificationSolutions.Count() > 0) { 57 57 foreach (var weight in CalculateWeights(classificationSolutions)) { 58 58 weights.Add(weight >= 0 ? weight : 0); 59 59 } 60 60 } 61 this.weights = weights.Select(x => x / weights.Sum()); 61 double sum = weights.Sum(); 62 this.weights = classificationSolutions.Zip(weights, (sol, wei) => new { sol, wei }).ToDictionary(x => x.sol, x => x.wei / sum); 62 63 } 63 64 64 protected abstract IEnumerable<double> CalculateWeights(I temCollection<IClassificationSolution> classificationSolutions);65 protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions); 65 66 66 public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) { 67 return from xs in ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows) 67 #region delegate CheckPoint 68 public CheckPoint GetTestClassDelegate() { 69 return PointInTest; 70 } 71 public CheckPoint GetTrainingClassDelegate() { 72 return PointInTraining; 73 } 74 public CheckPoint GetAllClassDelegate() { 75 return AllPoints; 76 } 77 #endregion 78 79 public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 80 return from xs in GetEstimatedClassValues(solutions, dataset, rows, handler) 68 81 select AggregateEstimatedClassValues(xs); 69 82 } 70 83 71 protected double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) { 72 if (!estimatedClassValues.Count().Equals(weights.Count())) 73 throw new ArgumentException("'estimatedClassValues' has " + estimatedClassValues.Count() + " elements, while 'weights' has" + weights.Count()); 84 protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) { 74 85 IDictionary<double, double> weightSum = new Dictionary<double, double>(); 75 for (int i = 0; i < estimatedClassValues.Count(); i++) {76 if (!weightSum.ContainsKey( estimatedClassValues.ElementAt(i)))77 weightSum[ estimatedClassValues.ElementAt(i)] = 0.0;78 weightSum[ estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);86 foreach (var item in estimatedClassValues) { 87 if (!weightSum.ContainsKey(item.Value)) 88 weightSum[item.Value] = 0.0; 89 weightSum[item.Value] += weights[item.Key]; 79 90 } 80 91 if (weightSum.Count <= 0) … … 88 99 } 89 100 90 protected static IEnumerable<IEnumerable<double>> GetEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {91 if (! models.Any()) yield break;92 var estimatedValuesEnumerators = (from model in models93 select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())101 protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 102 if (!solutions.Any()) yield break; 103 var estimatedValuesEnumerators = (from solution in solutions 104 select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() }) 94 105 .ToList(); 95 106 96 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 97 yield return from enumerator in estimatedValuesEnumerators 98 select enumerator.Current; 107 var rowEnumerator = rows.GetEnumerator(); 108 while (rowEnumerator.MoveNext() & estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) { 109 yield return (from enumerator in estimatedValuesEnumerators 110 where handler(enumerator.Solution.ProblemData, rowEnumerator.Current) 111 select enumerator) 112 .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current); 99 113 } 100 114 } … … 105 119 select targetValues[i]; 106 120 } 121 protected bool PointInTraining(IClassificationProblemData problemData, int point) { 122 IntRange trainingPartition = problemData.TrainingPartition; 123 IntRange testPartition = problemData.TestPartition; 124 return (trainingPartition.Start <= point && point < trainingPartition.End) 125 && !(testPartition.Start <= point && point < testPartition.End); 126 } 127 protected bool PointInTest(IClassificationProblemData problemData, int point) { 128 IntRange testPartition = problemData.TestPartition; 129 return testPartition.Start <= point && point < testPartition.End; 130 } 131 protected bool AllPoints(IClassificationProblemData problemData, int point) { 132 return true; 133 } 107 134 #endregion 108 135 }
Note: See TracChangeset
for help on using the changeset viewer.