Changeset 6589 for trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
- Timestamp:
- 07/25/11 16:54:15 (13 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r6411 r6589 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Data;26 using HeuristicLab.Optimization;27 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 26 … … 32 30 /// </summary> 33 31 [StorableClass] 34 public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution { 35 private const string TrainingAccuracyResultName = "Accuracy (training)"; 36 private const string TestAccuracyResultName = "Accuracy (test)"; 37 38 public new IClassificationModel Model { 39 get { return (IClassificationModel)base.Model; } 40 protected set { base.Model = value; } 41 } 42 43 public new IClassificationProblemData ProblemData { 44 get { return (IClassificationProblemData)base.ProblemData; } 45 protected set { base.ProblemData = value; } 46 } 47 48 public double TrainingAccuracy { 49 get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; } 50 private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; } 51 } 52 53 public double TestAccuracy { 54 get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; } 55 private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; } 56 } 57 32 public abstract class ClassificationSolution : ClassificationSolutionBase { 58 33 [StorableConstructor] 59 34 protected ClassificationSolution(bool deserializing) : base(deserializing) { } … … 63 38 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData) 64 39 : base(model, problemData) { 65 Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));66 Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));67 CalculateResults();68 40 } 69 41 70 public override IDeepCloneable Clone(Cloner cloner) { 71 return new ClassificationSolution(this, cloner); 42 public override IEnumerable<double> EstimatedClassValues { 43 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 44 } 45 public override IEnumerable<double> EstimatedTrainingClassValues { 46 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 47 } 48 public override IEnumerable<double> EstimatedTestClassValues { 49 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 72 50 } 73 51 74 protected override void RecalculateResults() { 75 CalculateResults(); 76 } 77 78 private void CalculateResults() { 79 double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values 80 IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 81 double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values 82 IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 83 84 OnlineCalculatorError errorState; 85 double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState); 86 if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN; 87 double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState); 88 if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN; 89 90 TrainingAccuracy = trainingAccuracy; 91 TestAccuracy = testAccuracy; 92 } 93 94 public virtual IEnumerable<double> EstimatedClassValues { 95 get { 96 return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 97 } 98 } 99 100 public virtual IEnumerable<double> EstimatedTrainingClassValues { 101 get { 102 return GetEstimatedClassValues(ProblemData.TrainingIndizes); 103 } 104 } 105 106 public virtual IEnumerable<double> EstimatedTestClassValues { 107 get { 108 return GetEstimatedClassValues(ProblemData.TestIndizes); 109 } 110 } 111 112 public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 52 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 113 53 return Model.GetEstimatedClassValues(ProblemData.Dataset, rows); 114 54 }
Note: See TracChangeset
for help on using the changeset viewer.