#region License Information
/* HeuristicLab
* Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble that can make use of
  /// the discriminant function values of <see cref="IDiscriminantFunctionClassificationSolution"/>s.
  /// If any solution in the ensemble is not a discriminant function solution, the members below fall back
  /// to fixed values instead (see the individual members).
  /// </summary>
  [StorableClass]
  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
    [StorableConstructor]
    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    public DiscriminantClassificationWeightCalculator()
      : base() {
    }

    /// <summary>
    /// Calculates one weight per solution.
    /// </summary>
    /// <param name="classificationSolutions">The ensemble members.</param>
    /// <returns>
    /// The result of <see cref="DiscriminantCalculateWeights"/> if every solution is an
    /// <see cref="IDiscriminantFunctionClassificationSolution"/>; otherwise a weight of 1.0 for every solution.
    /// </returns>
    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat(1.0, classificationSolutions.Count());

      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions =
        classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();
      return DiscriminantCalculateWeights(discriminantSolutions);
    }

    /// <summary>
    /// Calculates one weight per solution; only called when every ensemble member is a discriminant function solution.
    /// </summary>
    /// <param name="discriminantSolutions">The ensemble members, already cast to discriminant function solutions.</param>
    /// <returns>One weight per solution.</returns>
    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);

    /// <summary>
    /// Aggregates the estimated class values of all solutions into a single class value per row.
    /// </summary>
    /// <param name="solutions">The ensemble members.</param>
    /// <param name="dataset">The dataset to evaluate the models on.</param>
    /// <param name="rows">The rows to evaluate; enumerated several times, so pass a materialized collection.</param>
    /// <param name="handler">Decides, per solution and row, whether that solution takes part in the aggregation.</param>
    /// <returns>
    /// One aggregated value per row, computed by <see cref="DiscriminantAggregateEstimatedClassValues"/>.
    /// If any solution is not an <see cref="IDiscriminantFunctionClassificationSolution"/>, 0.0 is returned for every row.
    /// </returns>
    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      // NOTE(review): this fallback yields class value 0.0 for every row rather than an aggregate of the
      // class values - confirm this is intended for mixed (non-discriminant) ensembles.
      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat(0.0, rows.Count());

      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions =
        solutions.Cast<IDiscriminantFunctionClassificationSolution>();

      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);

      // Both sequences yield one dictionary per row, so they are paired up row by row.
      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
             select DiscriminantAggregateEstimatedClassValues(zip.ClassValues, zip.Values);
    }

    /// <summary>
    /// Aggregates the per-solution class values and discriminant function values of a single row.
    /// The default implementation ignores <paramref name="estimatedValues"/> and delegates to
    /// <c>AggregateEstimatedClassValues(estimatedClassValues)</c>.
    /// </summary>
    /// <param name="estimatedClassValues">Estimated class value per participating solution for one row.</param>
    /// <param name="estimatedValues">Discriminant function value per participating solution for the same row.</param>
    /// <returns>The aggregated class value for the row.</returns>
    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
      return AggregateEstimatedClassValues(estimatedClassValues);
    }

    /// <summary>
    /// Yields, for every row, a dictionary mapping each participating solution to the value of its
    /// discriminant function (not its class value) at that row.
    /// </summary>
    /// <param name="solutions">The discriminant function solutions to evaluate.</param>
    /// <param name="dataset">The dataset to evaluate the models on.</param>
    /// <param name="rows">
    /// The rows to evaluate; enumerated once per solution plus once more here, so pass a materialized collection.
    /// </param>
    /// <param name="handler">
    /// A solution is included in a row's dictionary only if <c>handler(solution.ProblemData, row)</c> returns true.
    /// </param>
    /// <returns>One dictionary per row; enumeration stops when the rows or any model's values run out.</returns>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      if (!solutions.Any()) yield break;

      // Use the models' discriminant function values; their class values are already supplied by GetEstimatedClassValues.
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() })
                                       .ToList();
      try {
        using (var rowEnumerator = rows.GetEnumerator()) {
          // Advance the row and all model enumerators in lock-step.
          while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
            int row = rowEnumerator.Current;
            yield return (from enumerator in estimatedValuesEnumerators
                          where handler(enumerator.Solution.ProblemData, row)
                          select enumerator)
                         .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
          }
        }
      } finally {
        // Release the model enumerators even if the consumer stops iterating early.
        foreach (var entry in estimatedValuesEnumerators)
          entry.EstimatedValuesEnumerator.Dispose();
      }
    }
  }
}