Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/10/11 10:00:09 (14 years ago)
Author:
gkronber
Message:

#1418 Implemented classes for classification based on a discriminant function and thresholds, and implemented interfaces and base classes for clustering.

Location:
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4
Files:
17 added
12 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/ClassificationProblem.cs

    r5625 r5649  
    2626namespace HeuristicLab.Problems.DataAnalysis {
    2727  [StorableClass]
    28   [Item("ClassificationProblem", "")]
     28  [Item("ClassificationProblem", "A general classification problem.")]
    2929  [Creatable("Problems")]
    3030  public class ClassificationProblem : DataAnalysisProblem<IClassificationProblemData>, IClassificationProblem {
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/ClassificationProblemData.cs

    r5601 r5649  
    194194    #endregion
    195195
    196     #region propeties
     196    #region properties
    197197    public string TargetVariable {
    198198      get { return TargetVariableParameter.Value.Value; }
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/ClassificationSolution.cs

    r5624 r5649  
    3737  [StorableClass]
    3838  public abstract class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    39     private const string ThresholdsResultsName = "Thresholds";
     39    private const string TrainingAccuracyResultName = "Accuracy (training)";
     40    private const string TestAccuracyResultName = "Accuracy (test)";
    4041    [StorableConstructor]
    4142    protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     
    4546    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    4647      : base(model, problemData) {
    47       DoubleArray thresholds = new DoubleArray();
    48       Add(new Result(ThresholdsResultsName, "The threshold values for class boundaries.", thresholds));
    49       thresholds.Reset += new EventHandler(thresholds_Reset);
    50       thresholds.ItemChanged += new EventHandler<EventArgs<int>>(thresholds_ItemChanged);
     48      double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
     49      IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     50      double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
     51      IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
     52
     53      double trainingAccuracy = OnlineAccuracyEvaluator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues);
     54      double testAccuracy = OnlineAccuracyEvaluator.Calculate(estimatedTestClassValues, originalTestClassValues);
     55
     56      Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue(trainingAccuracy)));
     57      Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue(testAccuracy)));
    5158    }
    5259
     
    6168    }
    6269
    63     public virtual IEnumerable<double> EstimatedValues {
    64       get {
    65         return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    66       }
    67     }
    68 
    69     public virtual IEnumerable<double> EstimatedTrainingValues {
    70       get {
    71         return GetEstimatedValues(ProblemData.TrainingIndizes);
    72       }
    73     }
    74 
    75     public virtual IEnumerable<double> EstimatedTestValues {
    76       get {
    77         return GetEstimatedValues(ProblemData.TestIndizes);
    78       }
    79     }
    80 
    81     public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    82       return Model.GetEstimatedValues(ProblemData, rows);
    83     }
    84 
    85     public IEnumerable<double> Thresholds {
    86       get {
    87         return (DoubleArray)this[ThresholdsResultsName].Value;
    88       }
    89     }
    90 
    91     public IEnumerable<double> EstimatedClassValues {
     70    public virtual IEnumerable<double> EstimatedClassValues {
    9271      get {
    9372        return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
     
    9574    }
    9675
    97     public IEnumerable<double> EstimatedTrainingClassValues {
     76    public virtual IEnumerable<double> EstimatedTrainingClassValues {
    9877      get {
    9978        return GetEstimatedClassValues(ProblemData.TrainingIndizes);
     
    10180    }
    10281
    103     public IEnumerable<double> EstimatedTestClassValues {
     82    public virtual IEnumerable<double> EstimatedTestClassValues {
    10483      get {
    10584        return GetEstimatedClassValues(ProblemData.TestIndizes);
     
    10786    }
    10887
    109     public IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    110       return Model.GetEstimatedClassValues(ProblemData, rows);
    111     }
    112 
    113     #endregion
    114     #region events
    115     private void thresholds_ItemChanged(object sender, EventArgs<int> e) {
    116       OnThresholdsChanged(EventArgs.Empty);
    117     }
    118 
    119     private void thresholds_Reset(object sender, EventArgs e) {
    120       OnThresholdsChanged(EventArgs.Empty);
    121     }
    122 
    123     public event EventHandler ThresholdsChanged;
    124     private void OnThresholdsChanged(EventArgs e) {
    125       var listeners = ThresholdsChanged;
    126       if (listeners != null) listeners(this, e);
     88    public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     89      return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
    12790    }
    12891    #endregion
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DataAnalysisProblemData.cs

    r5601 r5649  
    4444      get { return (IFixedValueParameter<Dataset>)Parameters[DatasetParameterName]; }
    4545    }
    46     public IFixedValueParameter<ICheckedItemCollection<StringValue>> InputVariablesParameter {
    47       get { return (IFixedValueParameter<ICheckedItemCollection<StringValue>>)Parameters[InputVariablesParameterName]; }
     46    public IFixedValueParameter<ICheckedItemList<StringValue>> InputVariablesParameter {
     47      get { return (IFixedValueParameter<ICheckedItemList<StringValue>>)Parameters[InputVariablesParameterName]; }
    4848    }
    4949    public IFixedValueParameter<IntValue> TrainingPartitionStartParameter {
     
    6565      get { return DatasetParameter.Value; }
    6666    }
    67     public ICheckedItemCollection<StringValue> InputVariables {
     67    public ICheckedItemList<StringValue> InputVariables {
    6868      get { return InputVariablesParameter.Value; }
    6969    }
    7070    public IEnumerable<string> AllowedInputVariables {
    71       get { return InputVariables.CheckedItems.Select(x => x.Value); }
     71      get { return InputVariables.CheckedItems.Select(x => x.Value.Value); }
    7272    }
    7373
     
    110110        throw new ArgumentException("All allowed input variables must be present in the dataset.");
    111111
    112       var inputVariables = new CheckedItemCollection<StringValue>(dataset.VariableNames.Select(x => new StringValue(x)));
     112      var inputVariables = new CheckedItemList<StringValue>(dataset.VariableNames.Select(x => new StringValue(x)));
    113113      foreach (StringValue x in inputVariables)
    114114        inputVariables.SetItemCheckedState(x, allowedInputVariables.Contains(x.Value));
     
    120120
    121121      Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", dataset));
    122       Parameters.Add(new FixedValueParameter<ICheckedItemCollection<StringValue>>(InputVariablesParameterName, "", inputVariables.AsReadOnly()));
     122      Parameters.Add(new FixedValueParameter<ICheckedItemList<StringValue>>(InputVariablesParameterName, "", inputVariables.AsReadOnly()));
    123123      Parameters.Add(new FixedValueParameter<IntValue>(TrainingPartitionStartParameterName, "", new IntValue(trainingPartitionStart)));
    124124      Parameters.Add(new FixedValueParameter<IntValue>(TrainingPartitionEndParameterName, "", new IntValue(trainingPartitionEnd)));
     
    132132    private void RegisterEventHandlers() {
    133133      DatasetParameter.ValueChanged += new EventHandler(Parameter_ValueChanged);
    134       InputVariables.CheckedItemsChanged += new CollectionItemsChangedEventHandler<StringValue>(InputVariables_CheckedItemsChanged);
     134      InputVariables.CheckedItemsChanged += new CollectionItemsChangedEventHandler<IndexedItem<StringValue>>(InputVariables_CheckedItemsChanged);
    135135      TrainingPartitionStart.ValueChanged += new EventHandler(Parameter_ValueChanged);
    136136      TrainingPartitionEnd.ValueChanged += new EventHandler(Parameter_ValueChanged);
     
    139139    }
    140140
    141     private void InputVariables_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<StringValue> e) {
     141    private void InputVariables_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IndexedItem<StringValue>> e) {
    142142      OnChanged();
    143143    }
     144
    144145    private void Parameter_ValueChanged(object sender, EventArgs e) {
    145146      OnChanged();
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DataAnalysisSolution.cs

    r5624 r5649  
    8282    public DataAnalysisSolution(IDataAnalysisModel model, IDataAnalysisProblemData problemData)
    8383      : base() {
    84       name = string.Empty;
    85       description = string.Empty;
     84      name = ItemName;
     85      description = ItemDescription;
    8686      Add(new Result(ModelResultName, "The symbolic data analysis model.", model));
    8787      Add(new Result(ProblemDataResultName, "The symbolic data analysis problem data.", problemData));
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLab.Problems.DataAnalysis-3.4.csproj

    r5620 r5649  
    111111    <Compile Include="ClassificationProblem.cs" />
    112112    <Compile Include="ClassificationSolution.cs" />
     113    <Compile Include="ClusteringProblem.cs" />
     114    <Compile Include="ClusteringProblemData.cs" />
     115    <Compile Include="ClusteringSolution.cs" />
     116    <Compile Include="DiscriminantFunctionClassificationModel.cs" />
     117    <Compile Include="DiscriminantFunctionClassificationSolution.cs" />
    113118    <Compile Include="DataAnalysisSolution.cs" />
     119    <Compile Include="Interfaces\Classification\IDiscriminantFunctionClassificationModel.cs" />
     120    <Compile Include="Interfaces\Classification\IDiscriminantFunctionClassificationSolution.cs" />
     121    <Compile Include="Interfaces\Clustering\IClusteringModel.cs" />
     122    <Compile Include="Interfaces\Clustering\IClusteringProblem.cs" />
     123    <Compile Include="Interfaces\Clustering\IClusteringProblemData.cs" />
     124    <Compile Include="Interfaces\Clustering\IClusteringSolution.cs" />
     125    <Compile Include="OnlineEvaluators\OnlineAccuracyEvaluator.cs" />
    114126    <Compile Include="RegressionProblemData.cs" />
    115127    <Compile Include="DataAnalysisProblemData.cs" />
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationModel.cs

    r5620 r5649  
    2323namespace HeuristicLab.Problems.DataAnalysis {
    2424  public interface IClassificationModel : IDataAnalysisModel {
    25     IEnumerable<double> GetEstimatedValues(IClassificationProblemData problemData, IEnumerable<int> rows);
    26     IEnumerable<double> GetEstimatedClassValues(IClassificationProblemData problemData, IEnumerable<int> rows);
     25    IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows);
    2726  }
    2827}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationSolution.cs

    r5509 r5649  
    2727    new IClassificationProblemData ProblemData { get; }
    2828
    29     IEnumerable<double> EstimatedValues { get; }
    30     IEnumerable<double> EstimatedTrainingValues { get; }
    31     IEnumerable<double> EstimatedTestValues { get; }
    32     IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows);
    33 
    34     IEnumerable<double> Thresholds { get; }
    3529    IEnumerable<double> EstimatedClassValues { get; }
    3630    IEnumerable<double> EstimatedTrainingClassValues { get; }
    3731    IEnumerable<double> EstimatedTestClassValues { get; }
    3832    IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows);
    39 
    40     event EventHandler ThresholdsChanged;
    4133  }
    4234}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/IDataAnalysisProblemData.cs

    r5601 r5649  
    2828  public interface IDataAnalysisProblemData : IParameterizedNamedItem {
    2929    Dataset Dataset { get; }
    30     ICheckedItemCollection<StringValue> InputVariables { get; }
     30    ICheckedItemList<StringValue> InputVariables { get; }
    3131    IEnumerable<string> AllowedInputVariables { get; }
    3232
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Regression/IRegressionModel.cs

    r5496 r5649  
    2323namespace HeuristicLab.Problems.DataAnalysis {
    2424  public interface IRegressionModel : IDataAnalysisModel {
    25     IEnumerable<double> GetEstimatedValues(IRegressionProblemData problemData, IEnumerable<int> rows);
     25    IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows);
    2626  }
    2727}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/RegressionProblem.cs

    r5625 r5649  
    2626namespace HeuristicLab.Problems.DataAnalysis {
    2727  [StorableClass]
    28   [Item("RegressionProblem", "")]
     28  [Item("RegressionProblem", "A general regression problem")]
    2929  [Creatable("Problems")]
    3030  public class RegressionProblem : DataAnalysisProblem<IRegressionProblemData>, IRegressionProblem {
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/RegressionSolution.cs

    r5624 r5649  
    3737  [StorableClass]
    3838  public abstract class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
     39    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
     40    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
     41    private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
     42    private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
     43    private const string TrainingRelativeErrorResultName = "Average relative error (training)";
     44    private const string TestRelativeErrorResultName = "Average relative error (test)";
     45
    3946    [StorableConstructor]
    4047    protected RegressionSolution(bool deserializing) : base(deserializing) { }
     
    4451    public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
    4552      : base(model, problemData) {
     53      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
     54      IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     55      double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
     56      IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
     57
     58      double trainingMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
     59      double testMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTestValues, originalTestValues);
     60      double trainingR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
     61      double testR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTestValues, originalTestValues);
     62      double trainingRelError = OnlineMeanAbsolutePercentageErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
     63      double testRelError = OnlineMeanAbsolutePercentageErrorEvaluator.Calculate(estimatedTestValues, originalTestValues);
     64
     65      Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue(trainingMSE)));
     66      Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue(testMSE)));
     67      Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue(trainingR2)));
     68      Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue(testR2)));
     69      Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue(trainingRelError)));
     70      Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue(testRelError)));
    4671    }
    4772
     73    protected override void OnProblemDataChanged(EventArgs e) {
     74      base.OnProblemDataChanged(e);
     75      throw new NotImplementedException(); // need to recalculate results
     76    }
     77    protected override void OnModelChanged(EventArgs e) {
     78      base.OnModelChanged(e);
     79      throw new NotImplementedException(); // need to recalculate results
     80    }
    4881    #region IRegressionSolution Members
    4982
     
    75108
    76109    public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    77       return Model.GetEstimatedValues(ProblemData, rows);
     110      return Model.GetEstimatedValues(ProblemData.Dataset, rows);
    78111    }
    79 
    80112    #endregion
    81113  }
Note: See TracChangeset for help on using the changeset viewer.