Changeset 7531 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
- Timestamp:
- 02/27/12 16:11:34 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r7504 r7531 51 51 } 52 52 53 public IEnumerable<double> Weights {54 get { return new List<double>(weights); }55 }56 57 53 [Storable] 58 54 private Dictionary<IClassificationModel, IntRange> trainingPartitions; … … 60 56 private Dictionary<IClassificationModel, IntRange> testPartitions; 61 57 62 private IEnumerable<double> weights;63 64 58 private IClassificationEnsembleSolutionWeightCalculator weightCalculator; 65 59 … … 68 62 if (value != null) { 69 63 weightCalculator = value; 70 weight s = weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);64 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 71 65 if (!ProblemData.IsEmpty) 72 66 RecalculateResults(); … … 169 163 #region Evaluation 170 164 public override IEnumerable<double> EstimatedTrainingClassValues { 171 get { 172 var rows = ProblemData.TrainingIndizes; 173 var estimatedValuesEnumerators = (from model in Model.Models 174 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 175 .ToList(); 176 var rowsEnumerator = rows.GetEnumerator(); 177 // aggregate to make sure that MoveNext is called for all enumerators 178 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 179 int currentRow = rowsEnumerator.Current; 180 181 var selectedEnumerators = from pair in estimatedValuesEnumerators 182 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 183 select pair.EstimatedValuesEnumerator; 184 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 185 } 186 } 165 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); } 187 166 } 188 167 189 168 public override IEnumerable<double> EstimatedTestClassValues { 190 get { 191 var rows = ProblemData.TestIndizes; 192 var estimatedValuesEnumerators = (from model in Model.Models 193 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 194 .ToList(); 195 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 196 // aggregate to make sure that MoveNext is called for all enumerators 197 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 198 int currentRow = rowsEnumerator.Current; 199 200 var selectedEnumerators = from pair in estimatedValuesEnumerators 201 where RowIsTestForModel(currentRow, pair.Model) 202 select pair.EstimatedValuesEnumerator; 203 204 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current), weights); 205 } 206 } 169 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); } 207 170 } 208 171 … … 218 181 219 182 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 220 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 221 select AggregateEstimatedClassValues(xs, weights); 183 return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows); 222 184 } 223 185 … … 232 194 select enumerator.Current; 233 195 } 234 }235 236 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues, IEnumerable<double> weights) {237 IDictionary<double, double> weightSum = new Dictionary<double, double>();238 for (int i = 0; i < estimatedClassValues.Count(); i++) {239 if (!weightSum.ContainsKey(estimatedClassValues.ElementAt(i)))240 weightSum[estimatedClassValues.ElementAt(i)] = 0.0;241 weightSum[estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);242 }243 if (weightSum.Count <= 0)244 return double.NaN;245 var max = weightSum.Max(x => x.Value);246 max = weightSum247 .Where(x => x.Value.Equals(max))248 .Select(x => x.Key)249 .First();250 return max;251 //old code252 //return weightSum253 // .Where(x => x.Value.Equals(max))254 // .Select(x => x.Key)255 // .First();256 //return estimatedClassValues257 //.GroupBy(x => x)258 //.OrderBy(g => -g.Count())259 //.Select(g => g.Key)260 //.DefaultIfEmpty(double.NaN)261 //.First();262 196 } 263 197 #endregion … … 316 250 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 317 251 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 318 weight s = weightCalculator.CalculateNormalizedWeights(classificationSolutions);252 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 319 253 } 320 254 … … 324 258 trainingPartitions.Remove(solution.Model); 325 259 testPartitions.Remove(solution.Model); 326 weight s = weightCalculator.CalculateNormalizedWeights(classificationSolutions);260 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 327 261 } 328 262 }
Note: See TracChangeset
for help on using the changeset viewer.