Changeset 7459 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
- Timestamp:
- 02/06/12 16:55:27 (13 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r7259 r7459 28 28 using HeuristicLab.Data; 29 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 30 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification; 30 31 31 32 namespace HeuristicLab.Problems.DataAnalysis { … … 55 56 private Dictionary<IClassificationModel, IntRange> testPartitions; 56 57 58 private IClassificationEnsembleSolutionWeightCalculator weightCalculator; 59 57 60 [StorableConstructor] 58 61 private ClassificationEnsembleSolution(bool deserializing) … … 95 98 classificationSolutions = new ItemCollection<IClassificationSolution>(); 96 99 100 weightCalculator = new AccuracyWeightCalculator(); 101 97 102 RegisterClassificationSolutionsEventHandler(); 98 103 } … … 153 158 .ToList(); 154 159 var rowsEnumerator = rows.GetEnumerator(); 160 IEnumerable<double> weights = weightCalculator.CalculateWeights(classificationSolutions); 155 161 // aggregate to make sure that MoveNext is called for all enumerators 156 162 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { … … 160 166 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 161 167 select pair.EstimatedValuesEnumerator; 162 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current) );168 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 163 169 } 164 170 } … … 172 178 .ToList(); 173 179 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 180 IEnumerable<double> weights = weightCalculator.CalculateWeights(classificationSolutions); 174 181 // aggregate to make sure that MoveNext is called for all enumerators 175 182 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { … … 180 187 select pair.EstimatedValuesEnumerator; 181 188 182 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current) );189 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 183 190 } 184 191 } … … 196 203 197 204 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 205 IEnumerable<double> weights = weightCalculator.CalculateWeights(classificationSolutions); 198 206 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 199 select AggregateEstimatedClassValues(xs );207 select AggregateEstimatedClassValues(xs, weights); 200 208 } 201 209 … … 212 220 } 213 221 214 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) { 215 return estimatedClassValues 216 .GroupBy(x => x) 217 .OrderBy(g => -g.Count()) 218 .Select(g => g.Key) 219 .DefaultIfEmpty(double.NaN) 220 .First(); 222 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> weights) { 223 IDictionary<double, double> weightSum = new Dictionary<double, double>(); 224 for (int i = 0; i < estimatedClassValues.Count(); i++) { 225 if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i))) 226 weightSum[estimatedClassValues.ElementAt(i)] = 0.0; 227 weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i); 228 } 229 if (weightSum.Count <= 0) 230 return double.NaN; 231 var max = weightSum.Max(x => x.Value); 232 max = weightSum 233 .Where(x => x.Value.Equals(max)) 234 .Select(x => x.Key) 235 .First(); 236 return max; 237 //old code 238 //return weightSum 239 // .Where(x => x.Value.Equals(max)) 240 // .Select(x => x.Key) 241 // .First(); 242 //return estimatedClassValues 243 //.GroupBy(x => x) 244 //.OrderBy(g => -g.Count()) 245 //.Select(g => g.Key) 246 //.DefaultIfEmpty(double.NaN) 247 //.First(); 221 248 } 222 249 #endregion
Note: See TracChangeset
for help on using the changeset viewer.