Changeset 7531
- Timestamp:
- 02/27/12 16:11:34 (13 years ago)
- Location:
- branches/ClassificationEnsembleVoting
- Files:
-
- 4 added
- 1 deleted
- 8 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionEstimatedClassValuesView.cs
r7464 r7531 102 102 List<List<double?>> estimatedValuesVector = GetEstimatedValues(SamplesComboBox.SelectedItem.ToString(), indizes, 103 103 Content.ClassificationSolutions); 104 List<double> weights = Content.Weights.ToList();105 double weightSum = weights.Sum();104 //List<double> weights = Content.Weights.ToList(); 105 //double weightSum = weights.Sum(); 106 106 107 107 for (int i = 0; i < indizes.Length; i++) { … … 114 114 values[i, 3] = (target[i].IsAlmost(estimatedClassValues[i])).ToString(); 115 115 116 IEnumerable<int> indices = FindAllIndices(estimatedValuesVector[i], estimatedClassValues[i]); 117 double confidence = 0.0; 118 foreach (var index in indices) { 119 confidence += weights[index]; 120 } 121 values[i, 4] = (confidence / weightSum).ToString(); 116 //currently disabled for test purpose 117 118 //IEnumerable<int> indices = FindAllIndices(estimatedValuesVector[i], estimatedClassValues[i]); 119 //double confidence = 0.0; 120 //foreach (var index in indices) { 121 // confidence += weights[index]; 122 //} 123 //values[i, 4] = (confidence / weightSum).ToString(); 122 124 //var estimationCount = groups.Where(g => g.Key != null).Select(g => g.Count).Sum(); 123 125 //values[i, 4] = 124 126 // (((double)groups.Where(g => g.Key == estimatedClassValues[i]).Single().Count) / estimationCount).ToString(); 127 values[i, 4] = "1.0"; 125 128 126 129 var groups = -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLab.Problems.DataAnalysis-3.4.csproj
r7504 r7531 130 130 <Compile Include="Implementation\Classification\WeightCalculators\AccuracyWeightCalculator.cs" /> 131 131 <Compile Include="Implementation\Classification\WeightCalculators\ContinuousPointCertaintyWeightCalculator.cs" /> 132 <Compile Include="Implementation\Classification\WeightCalculators\DiscriminantClassificationWeightCalculator.cs" /> 133 <Compile Include="Implementation\Classification\WeightCalculators\MedianThresholdCalculator.cs" /> 132 134 <Compile Include="Implementation\Classification\WeightCalculators\NeighbourhoodWeightCalculator.cs" /> 133 135 <Compile Include="Implementation\Classification\WeightCalculators\PointCertaintyWeightCalculator.cs" /> 134 136 <Compile Include="Implementation\Classification\WeightCalculators\MajorityVoteWeightCalculator.cs" /> 135 <Compile Include="Implementation\Classification\WeightCalculators\ WeightCalculator.cs" />137 <Compile Include="Implementation\Classification\WeightCalculators\ClassificationWeightCalculator.cs" /> 136 138 <Compile Include="Implementation\Clustering\ClusteringProblem.cs" /> 137 139 <Compile Include="Implementation\Clustering\ClusteringProblemData.cs" /> -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r7504 r7531 51 51 } 52 52 53 public IEnumerable<double> Weights {54 get { return new List<double>(weights); }55 }56 57 53 [Storable] 58 54 private Dictionary<IClassificationModel, IntRange> trainingPartitions; … … 60 56 private Dictionary<IClassificationModel, IntRange> testPartitions; 61 57 62 private IEnumerable<double> weights;63 64 58 private IClassificationEnsembleSolutionWeightCalculator weightCalculator; 65 59 … … 68 62 if (value != null) { 69 63 weightCalculator = value; 70 weight s = weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);64 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 71 65 if (!ProblemData.IsEmpty) 72 66 RecalculateResults(); … … 169 163 #region Evaluation 170 164 public override IEnumerable<double> EstimatedTrainingClassValues { 171 get { 172 var rows = ProblemData.TrainingIndizes; 173 var estimatedValuesEnumerators = (from model in Model.Models 174 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 175 .ToList(); 176 var rowsEnumerator = rows.GetEnumerator(); 177 // aggregate to make sure that MoveNext is called for all enumerators 178 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 179 int currentRow = rowsEnumerator.Current; 180 181 var selectedEnumerators = from pair in estimatedValuesEnumerators 182 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 183 select pair.EstimatedValuesEnumerator; 184 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 185 } 186 } 165 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); } 187 166 } 188 167 189 168 public override IEnumerable<double> EstimatedTestClassValues { 190 get { 191 var rows = ProblemData.TestIndizes; 192 var estimatedValuesEnumerators = (from model in Model.Models 193 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 194 .ToList(); 195 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 196 // aggregate to make sure that MoveNext is called for all enumerators 197 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 198 int currentRow = rowsEnumerator.Current; 199 200 var selectedEnumerators = from pair in estimatedValuesEnumerators 201 where RowIsTestForModel(currentRow, pair.Model) 202 select pair.EstimatedValuesEnumerator; 203 204 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 205 } 206 } 169 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); } 207 170 } 208 171 … … 218 181 219 182 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 220 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 221 select AggregateEstimatedClassValues(xs, weights); 183 return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows); 222 184 } 223 185 … … 232 194 select enumerator.Current; 233 195 } 234 }235 236 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> weights) {237 IDictionary<double, double> weightSum = new Dictionary<double, double>();238 for (int i = 0; i < estimatedClassValues.Count(); i++) {239 if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))240 weightSum[estimatedClassValues.ElementAt(i)] = 0.0;241 weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);242 }243 if (weightSum.Count <= 0)244 return double.NaN;245 var max = weightSum.Max(x => x.Value);246 max = weightSum247 .Where(x => x.Value.Equals(max))248 .Select(x => x.Key)249 .First();250 return max;251 //old code252 //return weightSum253 // .Where(x => x.Value.Equals(max))254 // .Select(x => x.Key)255 // .First();256 //return estimatedClassValues257 //.GroupBy(x => x)258 //.OrderBy(g => -g.Count())259 //.Select(g => g.Key)260 //.DefaultIfEmpty(double.NaN)261 //.First();262 196 } 263 197 #endregion … … 316 250 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 317 251 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 318 weight s = weightCalculator.CalculateNormalizedWeights(classificationSolutions);252 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 319 253 } 320 254 … … 324 258 trainingPartitions.Remove(solution.Model); 325 259 testPartitions.Remove(solution.Model); 326 weight s = weightCalculator.CalculateNormalizedWeights(classificationSolutions);260 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 327 261 } 328 262 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AccuracyWeightCalculator.cs
r7504 r7531 32 32 [StorableClass] 33 33 [Item("AccuracyWeightCalculator", "Represents a weight calculator that gives every classification solution a weight based on the accuracy.")] 34 public class AccuracyWeightCalculator : WeightCalculator {34 public class AccuracyWeightCalculator : ClassificationWeightCalculator { 35 35 36 36 public AccuracyWeightCalculator() 37 37 : base() { 38 38 } 39 40 39 [StorableConstructor] 41 40 protected AccuracyWeightCalculator(bool deserializing) : base(deserializing) { } … … 43 42 : base(original, cloner) { 44 43 } 45 46 44 public override IDeepCloneable Clone(Cloner cloner) { 47 45 return new AccuracyWeightCalculator(this, cloner); -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs
r7504 r7531 32 32 [StorableClass] 33 33 [Item("MajorityVoteWeightCalculator", "Represents a weight calculator that gives every classification solution the same weight.")] 34 public class MajorityVoteWeightCalculator : WeightCalculator {34 public class MajorityVoteWeightCalculator : ClassificationWeightCalculator { 35 35 36 36 public MajorityVoteWeightCalculator() -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs
r7504 r7531 33 33 [StorableClass] 34 34 [Item("NeighbourhoodWeightCalculator", "")] 35 public class NeighbourhoodWeightCalculator : WeightCalculator {35 public class NeighbourhoodWeightCalculator : DiscriminantClassificationWeightCalculator { 36 36 37 37 public NeighbourhoodWeightCalculator() … … 49 49 } 50 50 51 protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) { 52 if (classificationSolutions.Count <= 0) 53 return new List<double>(); 54 55 if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 56 return Enumerable.Repeat<double>(1, classificationSolutions.Count); 57 51 protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 58 52 List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>(); 59 53 List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>(); 60 IDiscriminantFunctionClassificationSolution discriminantSolution; 61 foreach (var solution in classificationSolutions) { 62 discriminantSolution = (IDiscriminantFunctionClassificationSolution)solution; 63 estimatedTrainingValEnumerators.Add(discriminantSolution.EstimatedTrainingValues.ToList()); 64 estimatedTrainingClassValEnumerators.Add(discriminantSolution.EstimatedTrainingClassValues.ToList()); 54 foreach (var solution in discriminantSolutions) { 55 estimatedTrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList()); 56 estimatedTrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList()); 65 57 } 66 58 67 List<double> weights = Enumerable.Repeat<double>(0, classificationSolutions.Count()).ToList<double>();59 List<double> weights = Enumerable.Repeat<double>(0, discriminantSolutions.Count()).ToList<double>(); 68 60 69 IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;61 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 70 62 List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(); 71 63 List<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes).ToList(); … … 93 85 return weights; 94 86 } 95 96 private IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {97 return from i in indizes98 select targetValues[i];99 }100 87 } 101 88 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/PointCertaintyWeightCalculator.cs
r7504 r7531 28 28 29 29 namespace HeuristicLab.Problems.DataAnalysis { 30 /// <summary>31 ///32 /// </summary>33 30 [StorableClass] 34 31 [Item("PointCertaintyWeightCalculator", "")] 35 public class PointCertaintyWeightCalculator : WeightCalculator {32 public class PointCertaintyWeightCalculator : DiscriminantClassificationWeightCalculator { 36 33 37 34 public PointCertaintyWeightCalculator() 38 35 : base() { 39 36 } 40 41 37 [StorableConstructor] 42 38 protected PointCertaintyWeightCalculator(bool deserializing) : base(deserializing) { } … … 44 40 : base(original, cloner) { 45 41 } 46 47 42 public override IDeepCloneable Clone(Cloner cloner) { 48 43 return new PointCertaintyWeightCalculator(this, cloner); 49 44 } 50 45 51 protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) { 52 if (classificationSolutions.Count <= 0) 53 return new List<double>(); 54 55 if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 56 return Enumerable.Repeat<double>(1, classificationSolutions.Count); 57 58 ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>(); 59 foreach (var solution in classificationSolutions) { 60 discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution); 61 } 62 46 protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 63 47 List<double> weights = new List<double>(); 64 IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;48 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 65 49 IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes); 66 50 IEnumerator<double> trainingValues; … … 87 71 return weights; 88 72 } 89 90 private IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {91 return from i in indizes92 select targetValues[i];93 }94 73 } 95 74 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs
r7504 r7531 25 25 namespace HeuristicLab.Problems.DataAnalysis.Interfaces.Classification { 26 26 public interface IClassificationEnsembleSolutionWeightCalculator : INamedItem { 27 IEnumerable<double> CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions); 27 void CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions); 28 IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows); 28 29 } 29 30 }
Note: See TracChangeset
for help on using the changeset viewer.