Changeset 7531 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation
- Timestamp:
- 02/27/12 16:11:34 (13 years ago)
- Location:
- branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Files:
-
- 4 added
- 1 deleted
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r7504 r7531 51 51 } 52 52 53 public IEnumerable<double> Weights {54 get { return new List<double>(weights); }55 }56 57 53 [Storable] 58 54 private Dictionary<IClassificationModel, IntRange> trainingPartitions; … … 60 56 private Dictionary<IClassificationModel, IntRange> testPartitions; 61 57 62 private IEnumerable<double> weights;63 64 58 private IClassificationEnsembleSolutionWeightCalculator weightCalculator; 65 59 … … 68 62 if (value != null) { 69 63 weightCalculator = value; 70 weight s = weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);64 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 71 65 if (!ProblemData.IsEmpty) 72 66 RecalculateResults(); … … 169 163 #region Evaluation 170 164 public override IEnumerable<double> EstimatedTrainingClassValues { 171 get { 172 var rows = ProblemData.TrainingIndizes; 173 var estimatedValuesEnumerators = (from model in Model.Models 174 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 175 .ToList(); 176 var rowsEnumerator = rows.GetEnumerator(); 177 // aggregate to make sure that MoveNext is called for all enumerators 178 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 179 int currentRow = rowsEnumerator.Current; 180 181 var selectedEnumerators = from pair in estimatedValuesEnumerators 182 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 183 select pair.EstimatedValuesEnumerator; 184 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 185 } 186 } 165 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); } 187 166 } 188 167 189 168 public override IEnumerable<double> EstimatedTestClassValues { 190 get { 191 var rows = ProblemData.TestIndizes; 192 var estimatedValuesEnumerators = (from model in Model.Models 193 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 194 .ToList(); 195 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 196 // aggregate to make sure that MoveNext is called for all enumerators 197 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 198 int currentRow = rowsEnumerator.Current; 199 200 var selectedEnumerators = from pair in estimatedValuesEnumerators 201 where RowIsTestForModel(currentRow, pair.Model) 202 select pair.EstimatedValuesEnumerator; 203 204 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 205 } 206 } 169 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); } 207 170 } 208 171 … … 218 181 219 182 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 220 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 221 select AggregateEstimatedClassValues(xs, weights); 183 return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows); 222 184 } 223 185 … … 232 194 select enumerator.Current; 233 195 } 234 }235 236 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> weights) {237 IDictionary<double, double> weightSum = new Dictionary<double, double>();238 for (int i = 0; i < estimatedClassValues.Count(); i++) {239 if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))240 weightSum[estimatedClassValues.ElementAt(i)] = 0.0;241 weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);242 }243 if (weightSum.Count <= 0)244 return double.NaN;245 var max = weightSum.Max(x => x.Value);246 max = weightSum247 .Where(x => x.Value.Equals(max))248 .Select(x => x.Key)249 .First();250 return max;251 //old code252 //return weightSum253 // .Where(x => x.Value.Equals(max))254 // .Select(x => x.Key)255 // .First();256 //return estimatedClassValues257 //.GroupBy(x => x)258 //.OrderBy(g => -g.Count())259 //.Select(g => g.Key)260 //.DefaultIfEmpty(double.NaN)261 //.First();262 196 } 263 197 #endregion … … 316 250 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 317 251 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 318 weight s = weightCalculator.CalculateNormalizedWeights(classificationSolutions);252 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 319 253 } 320 254 … … 324 258 trainingPartitions.Remove(solution.Model); 325 259 testPartitions.Remove(solution.Model); 326 weight s = weightCalculator.CalculateNormalizedWeights(classificationSolutions);260 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 327 261 } 328 262 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AccuracyWeightCalculator.cs
r7504 r7531 32 32 [StorableClass] 33 33 [Item("AccuracyWeightCalculator", "Represents a weight calculator that gives every classification solution a weight based on the accuracy.")] 34 public class AccuracyWeightCalculator : WeightCalculator {34 public class AccuracyWeightCalculator : ClassificationWeightCalculator { 35 35 36 36 public AccuracyWeightCalculator() 37 37 : base() { 38 38 } 39 40 39 [StorableConstructor] 41 40 protected AccuracyWeightCalculator(bool deserializing) : base(deserializing) { } … … 43 42 : base(original, cloner) { 44 43 } 45 46 44 public override IDeepCloneable Clone(Cloner cloner) { 47 45 return new AccuracyWeightCalculator(this, cloner); -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs
r7504 r7531 32 32 [StorableClass] 33 33 [Item("MajorityVoteWeightCalculator", "Represents a weight calculator that gives every classification solution the same weight.")] 34 public class MajorityVoteWeightCalculator : WeightCalculator {34 public class MajorityVoteWeightCalculator : ClassificationWeightCalculator { 35 35 36 36 public MajorityVoteWeightCalculator() -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs
r7504 r7531 33 33 [StorableClass] 34 34 [Item("NeighbourhoodWeightCalculator", "")] 35 public class NeighbourhoodWeightCalculator : WeightCalculator {35 public class NeighbourhoodWeightCalculator : DiscriminantClassificationWeightCalculator { 36 36 37 37 public NeighbourhoodWeightCalculator() … … 49 49 } 50 50 51 protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) { 52 if (classificationSolutions.Count <= 0) 53 return new List<double>(); 54 55 if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 56 return Enumerable.Repeat<double>(1, classificationSolutions.Count); 57 51 protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 58 52 List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>(); 59 53 List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>(); 60 IDiscriminantFunctionClassificationSolution discriminantSolution; 61 foreach (var solution in classificationSolutions) { 62 discriminantSolution = (IDiscriminantFunctionClassificationSolution)solution; 63 estimatedTrainingValEnumerators.Add(discriminantSolution.EstimatedTrainingValues.ToList()); 64 estimatedTrainingClassValEnumerators.Add(discriminantSolution.EstimatedTrainingClassValues.ToList()); 54 foreach (var solution in discriminantSolutions) { 55 estimatedTrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList()); 56 estimatedTrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList()); 65 57 } 66 58 67 List<double> weights = Enumerable.Repeat<double>(0, classificationSolutions.Count()).ToList<double>();59 List<double> weights = Enumerable.Repeat<double>(0, discriminantSolutions.Count()).ToList<double>(); 68 60 69 IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;61 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 70 62 List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(); 71 63 List<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes).ToList(); … … 93 85 return weights; 94 86 } 95 96 private IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {97 return from i in indizes98 select targetValues[i];99 }100 87 } 101 88 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/PointCertaintyWeightCalculator.cs
r7504 r7531 28 28 29 29 namespace HeuristicLab.Problems.DataAnalysis { 30 /// <summary>31 ///32 /// </summary>33 30 [StorableClass] 34 31 [Item("PointCertaintyWeightCalculator", "")] 35 public class PointCertaintyWeightCalculator : WeightCalculator {32 public class PointCertaintyWeightCalculator : DiscriminantClassificationWeightCalculator { 36 33 37 34 public PointCertaintyWeightCalculator() 38 35 : base() { 39 36 } 40 41 37 [StorableConstructor] 42 38 protected PointCertaintyWeightCalculator(bool deserializing) : base(deserializing) { } … … 44 40 : base(original, cloner) { 45 41 } 46 47 42 public override IDeepCloneable Clone(Cloner cloner) { 48 43 return new PointCertaintyWeightCalculator(this, cloner); 49 44 } 50 45 51 protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) { 52 if (classificationSolutions.Count <= 0) 53 return new List<double>(); 54 55 if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 56 return Enumerable.Repeat<double>(1, classificationSolutions.Count); 57 58 ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>(); 59 foreach (var solution in classificationSolutions) { 60 discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution); 61 } 62 46 protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 63 47 List<double> weights = new List<double>(); 64 IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;48 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 65 49 IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes); 66 50 IEnumerator<double> trainingValues; … … 87 71 return weights; 88 72 } 89 90 private IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {91 return from i in indizes92 select targetValues[i];93 }94 73 } 95 74 }
Note: See TracChangeset
for help on using the changeset viewer.