Changeset 6618 for branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
- Timestamp:
- 08/01/11 17:48:53 (13 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r6415 r6618 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Data;26 using HeuristicLab.Optimization;27 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 26 … … 32 30 /// </summary> 33 31 [StorableClass] 34 public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution { 35 private const string TrainingAccuracyResultName = "Accuracy (training)"; 36 private const string TestAccuracyResultName = "Accuracy (test)"; 37 38 public new IClassificationModel Model { 39 get { return (IClassificationModel)base.Model; } 40 protected set { base.Model = value; } 41 } 42 43 public new IClassificationProblemData ProblemData { 44 get { return (IClassificationProblemData)base.ProblemData; } 45 protected set { base.ProblemData = value; } 46 } 47 48 public double TrainingAccuracy { 49 get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; } 50 private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; } 51 } 52 53 public double TestAccuracy { 54 get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; } 55 private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; } 56 } 32 public abstract class ClassificationSolution : ClassificationSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 57 34 58 35 [StorableConstructor] 59 protected ClassificationSolution(bool deserializing) : base(deserializing) { } 36 protected ClassificationSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 60 40 protected ClassificationSolution(ClassificationSolution original, Cloner cloner) 61 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 62 43 } 63 44 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData) 64 45 : base(model, problemData) { 65 Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue())); 66 Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue())); 67 CalculateResults(); 46 evaluationCache = new Dictionary<int, double>(); 68 47 } 69 48 70 public override IDeepCloneable Clone(Cloner cloner) { 71 return new ClassificationSolution(this, cloner); 49 public override IEnumerable<double> EstimatedClassValues { 50 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 51 } 52 public override IEnumerable<double> EstimatedTrainingClassValues { 53 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 54 } 55 public override IEnumerable<double> EstimatedTestClassValues { 56 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 72 57 } 73 58 74 protected override void RecalculateResults() { 75 CalculateResults(); 59 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 60 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 61 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 62 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 63 64 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 65 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 66 } 67 68 return rows.Select(row => evaluationCache[row]); 76 69 } 77 70 78 private void CalculateResults() { 79 double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values 80 IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 81 double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values 82 IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 83 84 OnlineCalculatorError errorState; 85 double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState); 86 if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN; 87 double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState); 88 if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN; 89 90 TrainingAccuracy = trainingAccuracy; 91 TestAccuracy = testAccuracy; 71 protected override void OnProblemDataChanged() { 72 evaluationCache.Clear(); 73 base.OnProblemDataChanged(); 92 74 } 93 75 94 public virtual IEnumerable<double> EstimatedClassValues { 95 get { 96 return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 97 } 98 } 99 100 public virtual IEnumerable<double> EstimatedTrainingClassValues { 101 get { 102 return GetEstimatedClassValues(ProblemData.TrainingIndizes); 103 } 104 } 105 106 public virtual IEnumerable<double> EstimatedTestClassValues { 107 get { 108 return GetEstimatedClassValues(ProblemData.TestIndizes); 109 } 110 } 111 112 public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 113 return Model.GetEstimatedClassValues(ProblemData.Dataset, rows); 76 protected override void OnModelChanged() { 77 evaluationCache.Clear(); 78 base.OnModelChanged(); 114 79 } 115 80 }
Note: See TracChangeset
for help on using the changeset viewer.