Changeset 7549 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
- Timestamp:
- 03/05/12 17:02:37 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
r7531 r7549 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Core;26 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 26 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification; 27 27 28 28 namespace HeuristicLab.Problems.DataAnalysis { … … 41 41 } 42 42 43 protected override IEnumerable<double> CalculateWeights(I temCollection<IClassificationSolution> classificationSolutions) {43 protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) { 44 44 if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 45 return Enumerable.Repeat<double>(1.0, classificationSolutions.Count );45 return Enumerable.Repeat<double>(1.0, classificationSolutions.Count()); 46 46 47 ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>(); 48 foreach (var solution in classificationSolutions) { 49 discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution); 50 } 47 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>(); 51 48 52 49 return DiscriminantCalculateWeights(discriminantSolutions); 53 50 } 54 51 55 protected abstract IEnumerable<double> DiscriminantCalculateWeights(I temCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions);52 protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions); 56 53 57 public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassification Model> models, Dataset dataset, IEnumerable<int> rows) {58 if (! models.All(x => x is IDiscriminantFunctionClassificationModel))54 public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 55 if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 59 56 return Enumerable.Repeat<double>(0.0, rows.Count()); 60 57 61 IEnumerable<IDiscriminantFunctionClassification Model> discriminantModels = models.Cast<IDiscriminantFunctionClassificationModel>();58 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 62 59 63 IEnumerable<I Enumerable<double>> estimatedClassValues = ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows);64 IEnumerable<I Enumerable<double>> estimatedValues = DiscriminantClassificationWeightCalculator.GetEstimatedValues(discriminantModels, dataset, rows);60 IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler); 61 IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler); 65 62 66 63 return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values }) … … 68 65 } 69 66 70 protected virtual double DiscriminantAggregateEstimatedClassValues(I Enumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {67 protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) { 71 68 return AggregateEstimatedClassValues(estimatedClassValues); 72 69 } 73 70 74 protected static IEnumerable<IEnumerable<double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {75 if (! models.Any()) yield break;76 var estimatedValuesEnumerators = (from model in models77 select model.GetEstimatedValues(dataset, rows).GetEnumerator())78 .ToList();71 protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 72 if (!solutions.Any()) yield break; 73 var estimatedValuesEnumerators = (from solution in solutions 74 select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() }) 75 .ToList(); 79 76 80 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 81 yield return from enumerator in estimatedValuesEnumerators 82 select enumerator.Current; 77 var rowEnumerator = rows.GetEnumerator(); 78 while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) { 79 yield return (from enumerator in estimatedValuesEnumerators 80 where handler(enumerator.Solution.ProblemData, rowEnumerator.Current) 81 select enumerator) 82 .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current); 83 83 } 84 84 }
Note: See TracChangeset
for help on using the changeset viewer.