1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21 


22  using System.Collections.Generic;


23  using System.Linq;


24  using HeuristicLab.Common;


25  using HeuristicLab.Core;


26  using HeuristicLab.Data;


27  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


28  using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;


29 


30  namespace HeuristicLab.Problems.DataAnalysis {


/// <summary>
/// Base class for weight calculators for classification solutions in an ensemble.
/// Derived classes supply raw weights through <see cref="CalculateWeights"/>; this class
/// normalizes them and uses them for a weighted vote over the estimated class values.
/// </summary>
[StorableClass]
public abstract class ClassificationWeightCalculator : NamedItem, IClassificationEnsembleSolutionWeightCalculator {
  [StorableConstructor]
  protected ClassificationWeightCalculator(bool deserializing) : base(deserializing) { }
  // NOTE(review): the weights field is not marked storable and is not copied here, so a clone
  // or a deserialized instance needs CalculateNormalizedWeights to be called again before it
  // can aggregate - confirm this is intended.
  protected ClassificationWeightCalculator(ClassificationWeightCalculator original, Cloner cloner)
    : base(original, cloner) {
  }
  public ClassificationWeightCalculator()
    : base() {
    this.name = ItemName;
    this.description = ItemDescription;
  }

  // Normalized weight per solution; null until CalculateNormalizedWeights has been called.
  private IDictionary<IClassificationSolution, double> weights;

  /// <summary>
  /// Calls <see cref="CalculateWeights"/>, replaces negative weights by zero and stores the
  /// weights, scaled so that they sum to one, for later use by the aggregation methods.
  /// If the clamped weights do not sum to a positive value, every solution receives the same
  /// weight instead of dividing by zero (which would turn every weight into NaN).
  /// </summary>
  /// <param name="classificationSolutions">The solutions of the ensemble; an empty sequence yields an empty weight table.</param>
  /// <exception cref="System.ArgumentNullException">Thrown if <paramref name="classificationSolutions"/> is null.</exception>
  public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) {
    if (classificationSolutions == null)
      throw new System.ArgumentNullException("classificationSolutions");

    // Materialize once: the sequence is needed for the emptiness check, for the weight
    // calculation and for pairing solutions with weights, and it may be a deferred query.
    List<IClassificationSolution> solutionList = classificationSolutions.ToList();

    List<double> clampedWeights = new List<double>(solutionList.Count);
    if (solutionList.Count > 0) {
      foreach (var weight in CalculateWeights(solutionList)) {
        // NaN fails the comparison and is therefore clamped to zero as well.
        clampedWeights.Add(weight >= 0 ? weight : 0);
      }
    }

    double sum = clampedWeights.Sum();
    if (!(sum > 0)) {
      // No usable weight at all: fall back to a uniform weighting.
      for (int i = 0; i < clampedWeights.Count; i++)
        clampedWeights[i] = 1.0;
      sum = clampedWeights.Count;
    }

    // Zip stops at the shorter sequence, so solutions without a calculated weight are left out.
    this.weights = solutionList
      .Zip(clampedWeights, (sol, wei) => new { sol, wei })
      .ToDictionary(x => x.sol, x => x.wei / sum);
  }

  /// <summary>
  /// Calculates one raw weight per solution, in the order of <paramref name="classificationSolutions"/>.
  /// Negative values are treated as zero by <see cref="CalculateNormalizedWeights"/>.
  /// </summary>
  protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions);

  #region delegate CheckPoint
  /// <summary>Returns a delegate that accepts only rows inside a solution's test partition.</summary>
  public CheckPoint GetTestClassDelegate() {
    return PointInTest;
  }
  /// <summary>Returns a delegate that accepts only rows inside a solution's training partition and outside its test partition.</summary>
  public CheckPoint GetTrainingClassDelegate() {
    return PointInTraining;
  }
  /// <summary>Returns a delegate that accepts every row.</summary>
  public CheckPoint GetAllClassDelegate() {
    return AllPoints;
  }
  #endregion

  /// <summary>
  /// Produces one aggregated class value per row by a weighted vote of the solutions that
  /// <paramref name="handler"/> accepts for that row. The result is evaluated lazily.
  /// </summary>
  public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
    return GetEstimatedClassValues(solutions, dataset, rows, handler)
      .Select(xs => AggregateEstimatedClassValues(xs));
  }

  /// <summary>
  /// Sums the normalized weights of the solutions voting for each class value and returns the
  /// class value with the largest sum (the first one encountered on ties).
  /// </summary>
  /// <returns>The winning class value, or <see cref="double.NaN"/> if no solution voted.</returns>
  /// <exception cref="System.InvalidOperationException">Thrown if votes are present but
  /// <see cref="CalculateNormalizedWeights"/> has not been called yet.</exception>
  protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) {
    if (estimatedClassValues.Count <= 0)
      return double.NaN;
    if (weights == null)
      throw new System.InvalidOperationException("CalculateNormalizedWeights must be called before estimated class values can be aggregated.");

    Dictionary<double, double> weightPerClass = new Dictionary<double, double>();
    foreach (var item in estimatedClassValues) {
      double accumulated;
      // On a miss TryGetValue leaves the default 0.0, which is the right starting sum.
      weightPerClass.TryGetValue(item.Value, out accumulated);
      weightPerClass[item.Value] = accumulated + weights[item.Key];
    }

    double maxWeight = weightPerClass.Max(x => x.Value);
    // Equals instead of == so that a NaN maximum still matches its own entry.
    return weightPerClass
      .Where(x => x.Value.Equals(maxWeight))
      .Select(x => x.Key)
      .First();
  }

  /// <summary>
  /// Walks the rows and the estimated class values of all solutions in lockstep and yields, per
  /// row, the estimates of those solutions for which <paramref name="handler"/> returns true.
  /// Iteration stops as soon as the rows or any solution's estimates are exhausted.
  /// </summary>
  protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
    var estimatedValuesEnumerators = (from solution in solutions
                                      select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
                                     .ToList();

    var rowEnumerator = rows.GetEnumerator();
    try {
      while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) {
        yield return (from enumerator in estimatedValuesEnumerators
                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
                      select enumerator)
                     .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
      }
    }
    finally {
      // Runs when the sequence is exhausted or the consumer disposes it early.
      rowEnumerator.Dispose();
      foreach (var enumerator in estimatedValuesEnumerators)
        enumerator.EstimatedValuesEnumerator.Dispose();
    }
  }

  #region Helper
  /// <summary>Lazily selects the entries of <paramref name="targetValues"/> at the given indices.</summary>
  protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
    return indizes.Select(i => targetValues[i]);
  }
  /// <summary>
  /// True if <paramref name="point"/> lies in the training partition [Start, End) of
  /// <paramref name="problemData"/> and not in its test partition.
  /// </summary>
  protected bool PointInTraining(IClassificationProblemData problemData, int point) {
    IntRange trainingPartition = problemData.TrainingPartition;
    IntRange testPartition = problemData.TestPartition;
    return (trainingPartition.Start <= point && point < trainingPartition.End)
        && !(testPartition.Start <= point && point < testPartition.End);
  }
  /// <summary>True if <paramref name="point"/> lies in the test partition [Start, End) of <paramref name="problemData"/>.</summary>
  protected bool PointInTest(IClassificationProblemData problemData, int point) {
    IntRange testPartition = problemData.TestPartition;
    return testPartition.Start <= point && point < testPartition.End;
  }
  /// <summary>Accepts every point regardless of the partitions.</summary>
  protected bool AllPoints(IClassificationProblemData problemData, int point) {
    return true;
  }
  #endregion
}


135  }

