#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;

namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble
  /// that can additionally take the discriminant function values of the solutions into account.
  /// </summary>
  /// <remarks>
  /// The discriminant-aware code paths are only used when every solution is an
  /// <see cref="IDiscriminantFunctionClassificationSolution"/>; otherwise fixed fallback values are returned.
  /// </remarks>
  [StorableClass]
  public abstract class DiscriminantClassificationWeightCalculator : ClassificationWeightCalculator {
    /// <summary>Constructor marked with [StorableConstructor] for the persistence framework.</summary>
    [StorableConstructor]
    protected DiscriminantClassificationWeightCalculator(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; all state handling is delegated to the base class.</summary>
    protected DiscriminantClassificationWeightCalculator(DiscriminantClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }

    public DiscriminantClassificationWeightCalculator()
      : base() {
    }

    /// <summary>
    /// Calculates one weight per solution.
    /// </summary>
    /// <param name="classificationSolutions">The ensemble members.</param>
    /// <returns>
    /// A weight of 1.0 for every solution if at least one solution is not an
    /// <see cref="IDiscriminantFunctionClassificationSolution"/>; otherwise the result of
    /// <see cref="DiscriminantCalculateWeights"/>.
    /// </returns>
    protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat(1.0, classificationSolutions.Count());

      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>();

      return DiscriminantCalculateWeights(discriminantSolutions);
    }

    /// <summary>
    /// Calculates one weight per solution when all solutions expose a discriminant function.
    /// </summary>
    /// <param name="discriminantSolutions">The ensemble members, all discriminant-function based.</param>
    /// <returns>One weight per solution.</returns>
    protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions);

    /// <summary>
    /// Aggregates the estimated class values of the ensemble members into one class value per row.
    /// </summary>
    /// <param name="solutions">The ensemble members.</param>
    /// <param name="dataset">The dataset to evaluate the models on.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <param name="handler">
    /// Per-row filter: only solutions for which it returns true for a row contribute to that row.
    /// </param>
    /// <returns>
    /// 0.0 for every row if at least one solution is not an <see cref="IDiscriminantFunctionClassificationSolution"/>;
    /// otherwise, for each row, the result of <see cref="DiscriminantAggregateEstimatedClassValues"/>
    /// applied to that row's class values and discriminant function values.
    /// </returns>
    public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      // NOTE(review): 0.0 may not be a valid class value for every problem; kept for backward compatibility.
      if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
        return Enumerable.Repeat(0.0, rows.Count());

      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();

      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);

      // Both sequences produce one dictionary per row (filtered by the same handler), so they are paired row by row.
      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
             select DiscriminantAggregateEstimatedClassValues(zip.ClassValues, zip.Values);
    }

    /// <summary>
    /// Aggregates a single row. The default implementation ignores the discriminant function values
    /// and delegates to the base class' per-row <c>AggregateEstimatedClassValues</c> overload.
    /// </summary>
    /// <param name="estimatedClassValues">Estimated class value per participating solution.</param>
    /// <param name="estimatedValues">Discriminant function value per participating solution.</param>
    /// <returns>The aggregated class value for the row.</returns>
    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
      return AggregateEstimatedClassValues(estimatedClassValues);
    }

    /// <summary>
    /// Lazily yields, for each row, the discriminant function value of every solution accepted by
    /// <paramref name="handler"/> for that row.
    /// </summary>
    /// <param name="solutions">Discriminant-function based solutions.</param>
    /// <param name="dataset">The dataset to evaluate the models on.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <param name="handler">Per-row filter deciding which solutions are included for a row.</param>
    /// <returns>
    /// One dictionary per row; nothing if <paramref name="solutions"/> is empty. Enumeration stops as soon as
    /// <paramref name="rows"/> or any model's value sequence is exhausted.
    /// </returns>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      if (!solutions.Any())
        yield break;

      // Use the discriminant function values (GetEstimatedValues), not the class values: the class values
      // are supplied separately by GetEstimatedClassValues, and DiscriminantAggregateEstimatedClassValues
      // needs both.
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() })
                                       .ToList();
      try {
        using (var rowEnumerator = rows.GetEnumerator()) {
          while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
            int row = rowEnumerator.Current;
            yield return (from enumerator in estimatedValuesEnumerators
                          where handler(enumerator.Solution.ProblemData, row)
                          select enumerator)
                         .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
          }
        }
      } finally {
        // Release the per-model enumerators even if the consumer stops early.
        foreach (var enumerator in estimatedValuesEnumerators)
          enumerator.EstimatedValuesEnumerator.Dispose();
      }
    }
  }
}