#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for weight calculators for classification solutions in an ensemble.
  /// </summary>
  /// <remarks>
  /// Derived classes supply raw per-solution weights via <see cref="CalculateWeights"/>.
  /// <see cref="CalculateNormalizedWeights"/> turns them into non-negative weights that sum to one;
  /// those weights are then used for weighted voting (<see cref="AggregateEstimatedClassValues(IEnumerable{IClassificationSolution}, Dataset, IEnumerable{int}, CheckPoint)"/>)
  /// and for confidence values (<c>GetConfidence</c>).
  /// </remarks>
  [StorableClass]
  public abstract class ClassificationWeightCalculator : NamedItem, IClassificationEnsembleSolutionWeightCalculator {
    [StorableConstructor]
    protected ClassificationWeightCalculator(bool deserializing) : base(deserializing) { }

    // NOTE(review): the normalized weights are neither [Storable] nor copied here, so a clone or a
    // deserialized instance has to call CalculateNormalizedWeights again before it can aggregate.
    protected ClassificationWeightCalculator(ClassificationWeightCalculator original, Cloner cloner)
      : base(original, cloner) {
    }

    public ClassificationWeightCalculator()
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
    }

    /// <summary>
    /// Normalized weight per solution, set by <see cref="CalculateNormalizedWeights"/>; null until then.
    /// </summary>
    private IDictionary<IClassificationSolution, double> weights;

    /// <summary>
    /// Calls <see cref="CalculateWeights"/>, clamps negative (and NaN) weights to zero and divides each
    /// weight by their sum so that the stored weights sum to one.
    /// </summary>
    /// <param name="classificationSolutions">The ensemble members to weight.</param>
    /// <remarks>
    /// The result is stored in this instance and used by the aggregation and confidence methods.
    /// If every weight is zero after clamping, each solution receives the same weight (1/n) instead of
    /// the NaN values that dividing by a zero sum would produce.
    /// </remarks>
    public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
      // Materialize once: the sequence is needed for the emptiness check, CalculateWeights and the pairing below.
      List<IClassificationSolution> solutions = classificationSolutions.ToList();
      List<double> clampedWeights = new List<double>();
      if (solutions.Count > 0) {
        // "w >= 0" is false for NaN, so NaN weights are clamped to zero as well.
        clampedWeights.AddRange(CalculateWeights(solutions).Select(w => w >= 0 ? w : 0));
      }
      double sum = clampedWeights.Sum();

      // Zip pairs weights with solutions positionally and stops at the shorter sequence.
      var pairs = solutions.Zip(clampedWeights, (solution, weight) => new { Solution = solution, Weight = weight }).ToList();
      if (sum > 0) {
        this.weights = pairs.ToDictionary(p => p.Solution, p => p.Weight / sum);
      } else {
        // All weights are zero (or there are no solutions): use equal weights instead of 0/0 = NaN.
        double uniform = pairs.Count > 0 ? 1.0 / pairs.Count : 0.0;
        this.weights = pairs.ToDictionary(p => p.Solution, p => uniform);
      }
    }

    /// <summary>
    /// Computes one raw weight per solution, in the order of <paramref name="classificationSolutions"/>.
    /// Negative values are allowed; <see cref="CalculateNormalizedWeights"/> clamps them to zero.
    /// </summary>
    /// <param name="classificationSolutions">The ensemble members to weight.</param>
    /// <returns>The raw (unnormalized) weights, positionally matching the solutions.</returns>
    protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);

    #region delegate CheckPoint
    /// <summary>Returns a <see cref="CheckPoint"/> that admits only test rows of a solution's problem data.</summary>
    public CheckPoint GetTestClassDelegate() {
      return PointInTest;
    }
    /// <summary>Returns a <see cref="CheckPoint"/> that admits only training rows of a solution's problem data.</summary>
    public CheckPoint GetTrainingClassDelegate() {
      return PointInTraining;
    }
    /// <summary>Returns a <see cref="CheckPoint"/> that admits every row.</summary>
    public CheckPoint GetAllClassDelegate() {
      return AllPoints;
    }
    #endregion

    /// <summary>
    /// Computes the ensemble's class value for each row by weighted voting of the solutions.
    /// </summary>
    /// <param name="solutions">Ensemble members; each must have been weighted by the last <see cref="CalculateNormalizedWeights"/> call.</param>
    /// <param name="dataset">Dataset on which the models are evaluated.</param>
    /// <param name="rows">Rows to classify.</param>
    /// <param name="handler">Decides per solution and row whether that solution takes part in the vote.</param>
    /// <returns>
    /// One aggregated class value per row (lazily evaluated), or NaN for rows where no solution takes part.
    /// </returns>
    public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      return from xs in GetEstimatedClassValues(solutions, dataset, rows, handler)
             select AggregateEstimatedClassValues(xs);
    }

    /// <summary>
    /// Weighted vote for a single row: sums the normalized weights of the solutions per predicted class
    /// and returns the class with the largest total.
    /// </summary>
    /// <param name="estimatedClassValues">Predicted class value per participating solution.</param>
    /// <returns>The winning class value, or NaN if <paramref name="estimatedClassValues"/> is empty.</returns>
    protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
      Dictionary<double, double> weightSum = new Dictionary<double, double>();
      foreach (var vote in estimatedClassValues) {
        double current;
        weightSum.TryGetValue(vote.Value, out current); // current stays 0.0 for a class not seen yet
        weightSum[vote.Value] = current + GetWeight(vote.Key);
      }
      if (weightSum.Count <= 0)
        return double.NaN;
      double maxWeight = weightSum.Max(x => x.Value);
      // Ties are resolved in favor of the first matching class in weightSum's enumeration order.
      return weightSum
        .Where(x => x.Value.Equals(maxWeight))
        .Select(x => x.Key)
        .First();
    }

    /// <summary>
    /// Evaluates every solution's model on <paramref name="rows"/> and yields, per row, the predicted
    /// class value of each solution admitted by <paramref name="handler"/> for that row.
    /// </summary>
    /// <remarks>
    /// Lazily evaluated. The model and row enumerators are disposed when iteration finishes or the
    /// consumer disposes the returned sequence early.
    /// </remarks>
    protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
      var estimatedValuesEnumerators = (from solution in solutions
                                        select new {
                                          Solution = solution,
                                          EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator()
                                        })
                                       .ToList();
      try {
        using (var rowEnumerator = rows.GetEnumerator()) {
          // Non-short-circuit '&': the row enumerator and the model enumerators are both advanced on
          // every iteration; the loop ends as soon as any of them is exhausted.
          while (rowEnumerator.MoveNext() & estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
            int row = rowEnumerator.Current;
            yield return estimatedValuesEnumerators
              .Where(x => handler(x.Solution.ProblemData, row))
              .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
          }
        }
      } finally {
        foreach (var entry in estimatedValuesEnumerators)
          entry.EstimatedValuesEnumerator.Dispose();
      }
    }

    /// <summary>
    /// Confidence of <paramref name="estimatedClassValue"/> at row <paramref name="index"/>: the sum of the
    /// normalized weights of all solutions that are admitted by <paramref name="handler"/> for that row and
    /// predict that class value there.
    /// </summary>
    /// <remarks>
    /// Models are evaluated on the dataset of the first solution's problem data. Class values are compared
    /// with <see cref="DoubleExtensions.IsAlmost"/>, the same comparison the multi-row overload uses.
    /// </remarks>
    /// <returns>The confidence, or NaN if <paramref name="solutions"/> is empty.</returns>
    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
      if (!solutions.Any())
        return double.NaN;
      Dataset dataset = solutions.First().ProblemData.Dataset;
      IEnumerable<int> singleRow = Enumerable.Repeat(index, 1);
      return (from solution in solutions
              where handler(solution.ProblemData, index)
              let predicted = solution.Model.GetEstimatedClassValues(dataset, singleRow).First()
              where DoubleExtensions.IsAlmost(predicted, estimatedClassValue)
              select GetWeight(solution)).Sum();
    }

    /// <summary>
    /// Multi-row variant of <see cref="GetConfidence(IEnumerable{IClassificationSolution}, int, double, CheckPoint)"/>:
    /// the i-th result is the confidence of the i-th entry of <paramref name="estimatedClassValue"/> at the
    /// i-th entry of <paramref name="indices"/>.
    /// </summary>
    /// <returns>One confidence per index; all NaN if <paramref name="solutions"/> is empty.</returns>
    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
      if (!solutions.Any())
        return Enumerable.Repeat(double.NaN, indices.Count());
      List<int> indicesList = indices.ToList();
      Dataset dataset = solutions.First().ProblemData.Dataset;
      // Evaluate each model once for all requested rows.
      Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray());
      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
      double[] confidences = new double[indicesList.Count];
      for (int i = 0; i < indicesList.Count; i++) {
        // Sum() runs the query immediately, so capturing the loop variable i is safe.
        confidences[i] = (from sol in solValues
                          where DoubleExtensions.IsAlmost(sol.Value[i], estimatedClassValueArr[i])
                          where handler(sol.Key.ProblemData, indicesList[i])
                          select GetWeight(sol.Key)).Sum();
      }
      return confidences;
    }

    /// <summary>
    /// Returns the normalized weight of <paramref name="solution"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="CalculateNormalizedWeights"/> has not been called on this instance.</exception>
    /// <exception cref="KeyNotFoundException">The solution was not part of the last <see cref="CalculateNormalizedWeights"/> call.</exception>
    private double GetWeight(IClassificationSolution solution) {
      if (weights == null)
        throw new InvalidOperationException("CalculateNormalizedWeights must be called before the weights can be used.");
      return weights[solution];
    }

    #region Helper
    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
      return problemData.IsTrainingSample(point);
    }
    protected bool PointInTest(IClassificationProblemData problemData, int point) {
      return problemData.IsTestSample(point);
    }
    protected bool AllPoints(IClassificationProblemData problemData, int point) {
      return true;
    }
    #endregion
  }
}