1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21 


22  using System.Collections.Generic;


23  using System.Linq;


24  using HeuristicLab.Common;


25  using HeuristicLab.Data;


26  using HeuristicLab.Optimization;


27  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


28 


namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Base class for classification solutions. Registers accuracy and normalized Gini
  /// coefficient results for the training and test partitions, and provides
  /// <see cref="CalculateResults"/> to fill them from the estimated class values
  /// supplied by derived classes.
  /// </summary>
  [StorableClass]
  public abstract class ClassificationSolutionBase : DataAnalysisSolution, IClassificationSolution {
    private const string TrainingAccuracyResultName = "Accuracy (training)";
    private const string TestAccuracyResultName = "Accuracy (test)";
    private const string TrainingNormalizedGiniCoefficientResultName = "Normalized Gini Coefficient (training)";
    private const string TestNormalizedGiniCoefficientResultName = "Normalized Gini Coefficient (test)";

    // Descriptions are shared by the constructor and AfterDeserialization, so that results
    // back-filled for older files are registered with exactly the same text as new ones.
    private const string TrainingAccuracyResultDescription = "Accuracy of the model on the training partition (percentage of correctly classified instances).";
    private const string TestAccuracyResultDescription = "Accuracy of the model on the test partition (percentage of correctly classified instances).";
    private const string TrainingNormalizedGiniCoefficientResultDescription = "Normalized Gini coefficient of the model on the training partition.";
    private const string TestNormalizedGiniCoefficientResultDescription = "Normalized Gini coefficient of the model on the test partition.";

    /// <summary>
    /// The solution's model, exposed as an <see cref="IClassificationModel"/>
    /// (a cast over the base class's model).
    /// </summary>
    public new IClassificationModel Model {
      get { return (IClassificationModel)base.Model; }
      protected set { base.Model = value; }
    }

    /// <summary>
    /// The solution's problem data, exposed as an <see cref="IClassificationProblemData"/>
    /// (a cast over the base class's problem data).
    /// </summary>
    public new IClassificationProblemData ProblemData {
      get { return (IClassificationProblemData)base.ProblemData; }
      set { base.ProblemData = value; }
    }

    #region Results
    /// <summary>Value stored in the "Accuracy (training)" result.</summary>
    public double TrainingAccuracy {
      get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; }
      private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; }
    }

    /// <summary>Value stored in the "Accuracy (test)" result.</summary>
    public double TestAccuracy {
      get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; }
      private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; }
    }

    /// <summary>Value stored in the "Normalized Gini Coefficient (training)" result.</summary>
    public double TrainingNormalizedGiniCoefficient {
      get { return ((DoubleValue)this[TrainingNormalizedGiniCoefficientResultName].Value).Value; }
      protected set { ((DoubleValue)this[TrainingNormalizedGiniCoefficientResultName].Value).Value = value; }
    }

    /// <summary>Value stored in the "Normalized Gini Coefficient (test)" result.</summary>
    public double TestNormalizedGiniCoefficient {
      get { return ((DoubleValue)this[TestNormalizedGiniCoefficientResultName].Value).Value; }
      protected set { ((DoubleValue)this[TestNormalizedGiniCoefficientResultName].Value).Value = value; }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework during deserialization.</summary>
    [StorableConstructor]
    protected ClassificationSolutionBase(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; delegates entirely to the base class.</summary>
    protected ClassificationSolutionBase(ClassificationSolutionBase original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates a solution for <paramref name="model"/> on <paramref name="problemData"/> and
    /// registers the accuracy and normalized Gini results. The values are not computed here;
    /// call <see cref="CalculateResults"/> to fill them.
    /// </summary>
    protected ClassificationSolutionBase(IClassificationModel model, IClassificationProblemData problemData)
      : base(model, problemData) {
      Add(new Result(TrainingAccuracyResultName, TrainingAccuracyResultDescription, new PercentValue()));
      Add(new Result(TestAccuracyResultName, TestAccuracyResultDescription, new PercentValue()));
      Add(new Result(TrainingNormalizedGiniCoefficientResultName, TrainingNormalizedGiniCoefficientResultDescription, new DoubleValue()));
      Add(new Result(TestNormalizedGiniCoefficientResultName, TestNormalizedGiniCoefficientResultDescription, new DoubleValue()));
    }

    /// <summary>
    /// Back-fills the normalized Gini results when they are missing from a deserialized solution.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // The back-filled results are not recalculated here. Start them as NaN, the same
      // "unavailable" marker CalculateResults uses, rather than a default 0.0 that would
      // be indistinguishable from a genuinely computed coefficient.
      if (!ContainsKey(TrainingNormalizedGiniCoefficientResultName))
        Add(new Result(TrainingNormalizedGiniCoefficientResultName, TrainingNormalizedGiniCoefficientResultDescription, new DoubleValue(double.NaN)));
      if (!ContainsKey(TestNormalizedGiniCoefficientResultName))
        Add(new Result(TestNormalizedGiniCoefficientResultName, TestNormalizedGiniCoefficientResultDescription, new DoubleValue(double.NaN)));
    }

    /// <summary>
    /// Recomputes accuracy and normalized Gini coefficient on the training and test partitions.
    /// Each metric compares the target variable values at the partition's indices with the
    /// corresponding estimated class values. A metric whose calculator reports an error is
    /// stored as <see cref="double.NaN"/>.
    /// </summary>
    protected void CalculateResults() {
      // Materialize once: every sequence is consumed by both the accuracy and the Gini calculator.
      double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray();
      double[] originalTrainingClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).ToArray();
      double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray();
      double[] originalTestClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndizes).ToArray();

      OnlineCalculatorError errorState;
      double trainingAccuracy = OnlineAccuracyCalculator.Calculate(originalTrainingClassValues, estimatedTrainingClassValues, out errorState);
      trainingAccuracy = ValueOrNaN(trainingAccuracy, errorState);
      double testAccuracy = OnlineAccuracyCalculator.Calculate(originalTestClassValues, estimatedTestClassValues, out errorState);
      testAccuracy = ValueOrNaN(testAccuracy, errorState);

      TrainingAccuracy = trainingAccuracy;
      TestAccuracy = testAccuracy;

      double trainingNormalizedGini = NormalizedGiniCalculator.Calculate(originalTrainingClassValues, estimatedTrainingClassValues, out errorState);
      trainingNormalizedGini = ValueOrNaN(trainingNormalizedGini, errorState);
      double testNormalizedGini = NormalizedGiniCalculator.Calculate(originalTestClassValues, estimatedTestClassValues, out errorState);
      testNormalizedGini = ValueOrNaN(testNormalizedGini, errorState);

      TrainingNormalizedGiniCoefficient = trainingNormalizedGini;
      TestNormalizedGiniCoefficient = testNormalizedGini;
    }

    /// <summary>
    /// Returns <paramref name="value"/> when <paramref name="errorState"/> is
    /// <c>OnlineCalculatorError.None</c>, otherwise <see cref="double.NaN"/>.
    /// </summary>
    private static double ValueOrNaN(double value, OnlineCalculatorError errorState) {
      return errorState == OnlineCalculatorError.None ? value : double.NaN;
    }

    /// <summary>Estimated class values; implemented by derived classes.</summary>
    public abstract IEnumerable<double> EstimatedClassValues { get; }

    /// <summary>
    /// Estimated class values for the training partition. <see cref="CalculateResults"/>
    /// compares them element-wise with the target values at <c>ProblemData.TrainingIndizes</c>.
    /// </summary>
    public abstract IEnumerable<double> EstimatedTrainingClassValues { get; }

    /// <summary>
    /// Estimated class values for the test partition. <see cref="CalculateResults"/>
    /// compares them element-wise with the target values at <c>ProblemData.TestIndizes</c>.
    /// </summary>
    public abstract IEnumerable<double> EstimatedTestClassValues { get; }

    /// <summary>
    /// Estimated class values for the given <paramref name="rows"/>. Presumably one value per
    /// row, in row order; the exact contract is defined by implementers.
    /// </summary>
    public abstract IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows);
  }
}

