Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/11 13:59:25 (13 years ago)
Author:
epitzer
Message:

#1530 integrate changes from trunk

Location:
branches/PersistenceSpeedUp
Files:
8 edited
3 copied

Legend:

Unmodified
Added
Removed
  • branches/PersistenceSpeedUp

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs

    r5809 r6760  
    3939      get { return new List<IClassificationModel>(models); }
    4040    }
     41
    4142    [StorableConstructor]
    4243    protected ClassificationEnsembleModel(bool deserializing) : base(deserializing) { }
     
    4546      this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    4647    }
     48
     49    public ClassificationEnsembleModel() : this(Enumerable.Empty<IClassificationModel>()) { }
    4750    public ClassificationEnsembleModel(IEnumerable<IClassificationModel> models)
    4851      : base() {
     
    5760
    5861    #region IClassificationEnsembleModel Members
     62    public void Add(IClassificationModel model) {
     63      models.Add(model);
     64    }
     65    public void Remove(IClassificationModel model) {
     66      models.Remove(model);
     67    }
    5968
    6069    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
     
    8594    }
    8695
     96    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
     97      return new ClassificationEnsembleSolution(models, problemData);
     98    }
    8799    #endregion
    88100  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r6184 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2730
     
    3235  [StorableClass]
    3336  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
    34   // [Creatable("Data Analysis")]
    35   public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution {
     37  [Creatable("Data Analysis - Ensembles")]
     38  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     39    public new IClassificationEnsembleModel Model {
     40      get { return (IClassificationEnsembleModel)base.Model; }
     41    }
     42    public new ClassificationEnsembleProblemData ProblemData {
     43      get { return (ClassificationEnsembleProblemData)base.ProblemData; }
     44      set { base.ProblemData = value; }
     45    }
     46
     47    private readonly ItemCollection<IClassificationSolution> classificationSolutions;
     48    public IItemCollection<IClassificationSolution> ClassificationSolutions {
     49      get { return classificationSolutions; }
     50    }
    3651
    3752    [Storable]
    38     private List<IClassificationModel> models;
    39     public IEnumerable<IClassificationModel> Models {
    40       get { return new List<IClassificationModel>(models); }
    41     }
     53    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     54    [Storable]
     55    private Dictionary<IClassificationModel, IntRange> testPartitions;
     56
    4257    [StorableConstructor]
    43     protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { }
    44     protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
     58    private ClassificationEnsembleSolution(bool deserializing)
     59      : base(deserializing) {
     60      classificationSolutions = new ItemCollection<IClassificationSolution>();
     61    }
     62    [StorableHook(HookType.AfterDeserialization)]
     63    private void AfterDeserialization() {
     64      foreach (var model in Model.Models) {
     65        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
     66        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     67        problemData.TrainingPartition.End = trainingPartitions[model].End;
     68        problemData.TestPartition.Start = testPartitions[model].Start;
     69        problemData.TestPartition.End = testPartitions[model].End;
     70
     71        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
     72      }
     73      RegisterClassificationSolutionsEventHandler();
     74    }
     75
     76    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    4577      : base(original, cloner) {
    46       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    47     }
    48     public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models)
    49       : base() {
    50       this.name = ItemName;
    51       this.description = ItemDescription;
    52       this.models = new List<IClassificationModel>(models);
     78      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     79      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     80      foreach (var pair in original.trainingPartitions) {
     81        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     82      }
     83      foreach (var pair in original.testPartitions) {
     84        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     85      }
     86
     87      classificationSolutions = cloner.Clone(original.classificationSolutions);
     88      RegisterClassificationSolutionsEventHandler();
     89    }
     90
     91    public ClassificationEnsembleSolution()
     92      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
     93      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     94      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     95      classificationSolutions = new ItemCollection<IClassificationSolution>();
     96
     97      RegisterClassificationSolutionsEventHandler();
     98    }
     99
     100    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
     101      : this(models, problemData,
     102             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     103             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     104      ) { }
     105
     106    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
     107      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
     108      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     109      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
     110      this.classificationSolutions = new ItemCollection<IClassificationSolution>();
     111
     112      List<IClassificationSolution> solutions = new List<IClassificationSolution>();
     113      var modelEnumerator = models.GetEnumerator();
     114      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
     115      var testPartitionEnumerator = testPartitions.GetEnumerator();
     116
     117      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
     118        var p = (IClassificationProblemData)problemData.Clone();
     119        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     120        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     121        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     122        p.TestPartition.End = testPartitionEnumerator.Current.End;
     123
     124        solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
     125      }
     126      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     127        throw new ArgumentException();
     128      }
     129
     130      RegisterClassificationSolutionsEventHandler();
     131      classificationSolutions.AddRange(solutions);
    53132    }
    54133
     
    56135      return new ClassificationEnsembleSolution(this, cloner);
    57136    }
    58 
    59     #region IClassificationEnsembleModel Members
     137    private void RegisterClassificationSolutionsEventHandler() {
     138      classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded);
     139      classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
     140      classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
     141    }
     142
     143    protected override void RecalculateResults() {
     144      CalculateResults();
     145    }
     146
     147    #region Evaluation
     148    public override IEnumerable<double> EstimatedTrainingClassValues {
     149      get {
     150        var rows = ProblemData.TrainingIndizes;
     151        var estimatedValuesEnumerators = (from model in Model.Models
     152                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     153                                         .ToList();
     154        var rowsEnumerator = rows.GetEnumerator();
     155        // aggregate to make sure that MoveNext is called for all enumerators
     156        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     157          int currentRow = rowsEnumerator.Current;
     158
     159          var selectedEnumerators = from pair in estimatedValuesEnumerators
     160                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
     161                                    select pair.EstimatedValuesEnumerator;
     162          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     163        }
     164      }
     165    }
     166
     167    public override IEnumerable<double> EstimatedTestClassValues {
     168      get {
     169        var rows = ProblemData.TestIndizes;
     170        var estimatedValuesEnumerators = (from model in Model.Models
     171                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     172                                         .ToList();
     173        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     174        // aggregate to make sure that MoveNext is called for all enumerators
     175        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     176          int currentRow = rowsEnumerator.Current;
     177
     178          var selectedEnumerators = from pair in estimatedValuesEnumerators
     179                                    where RowIsTestForModel(currentRow, pair.Model)
     180                                    select pair.EstimatedValuesEnumerator;
     181
     182          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     183        }
     184      }
     185    }
     186
     187    private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) {
     188      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
     189              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
     190    }
     191
     192    private bool RowIsTestForModel(int currentRow, IClassificationModel model) {
     193      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     194              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
     195    }
     196
     197    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     198      return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
     199             select AggregateEstimatedClassValues(xs);
     200    }
    60201
    61202    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    62       var estimatedValuesEnumerators = (from model in models
     203      var estimatedValuesEnumerators = (from model in Model.Models
    63204                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
    64205                                       .ToList();
     
    70211    }
    71212
     213    private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
     214      return estimatedClassValues
     215      .GroupBy(x => x)
     216      .OrderBy(g => -g.Count())
     217      .Select(g => g.Key)
     218      .DefaultIfEmpty(double.NaN)
     219      .First();
     220    }
    72221    #endregion
    73222
    74     #region IClassificationModel Members
    75 
    76     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    77       foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) {
    78         // return the class which is most often occuring
    79         yield return
    80           estimatedValuesVector
    81           .GroupBy(x => x)
    82           .OrderBy(g => -g.Count())
    83           .Select(g => g.Key)
    84           .First();
    85       }
    86     }
    87 
    88     #endregion
     223    protected override void OnProblemDataChanged() {
     224      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
     225                                                                     ProblemData.AllowedInputVariables,
     226                                                                     ProblemData.TargetVariable);
     227      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
     228      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
     229      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
     230      problemData.TestPartition.End = ProblemData.TestPartition.End;
     231
     232      foreach (var solution in ClassificationSolutions) {
     233        if (solution is ClassificationEnsembleSolution)
     234          solution.ProblemData = ProblemData;
     235        else
     236          solution.ProblemData = problemData;
     237      }
     238      foreach (var trainingPartition in trainingPartitions.Values) {
     239        trainingPartition.Start = ProblemData.TrainingPartition.Start;
     240        trainingPartition.End = ProblemData.TrainingPartition.End;
     241      }
     242      foreach (var testPartition in testPartitions.Values) {
     243        testPartition.Start = ProblemData.TestPartition.Start;
     244        testPartition.End = ProblemData.TestPartition.End;
     245      }
     246
     247      base.OnProblemDataChanged();
     248    }
     249
     250    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     251      classificationSolutions.AddRange(solutions);
     252    }
     253    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     254      classificationSolutions.RemoveRange(solutions);
     255    }
     256
     257    private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     258      foreach (var solution in e.Items) AddClassificationSolution(solution);
     259      RecalculateResults();
     260    }
     261    private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     262      foreach (var solution in e.Items) RemoveClassificationSolution(solution);
     263      RecalculateResults();
     264    }
     265    private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     266      foreach (var solution in e.OldItems) RemoveClassificationSolution(solution);
     267      foreach (var solution in e.Items) AddClassificationSolution(solution);
     268      RecalculateResults();
     269    }
     270
     271    private void AddClassificationSolution(IClassificationSolution solution) {
     272      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     273      Model.Add(solution.Model);
     274      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     275      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     276    }
     277
     278    private void RemoveClassificationSolution(IClassificationSolution solution) {
     279      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     280      Model.Remove(solution.Model);
     281      trainingPartitions.Remove(solution.Model);
     282      testPartitions.Remove(solution.Model);
     283    }
    89284  }
    90285}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs

    r6232 r6760  
    3434  [Item("ClassificationProblemData", "Represents an item containing all data defining a classification problem.")]
    3535  public class ClassificationProblemData : DataAnalysisProblemData, IClassificationProblemData {
    36     private const string TargetVariableParameterName = "TargetVariable";
    37     private const string ClassNamesParameterName = "ClassNames";
    38     private const string ClassificationPenaltiesParameterName = "ClassificationPenalties";
    39     private const int MaximumNumberOfClasses = 20;
    40     private const int InspectedRowsToDetermineTargets = 500;
     36    protected const string TargetVariableParameterName = "TargetVariable";
     37    protected const string ClassNamesParameterName = "ClassNames";
     38    protected const string ClassificationPenaltiesParameterName = "ClassificationPenalties";
     39    protected const int MaximumNumberOfClasses = 20;
     40    protected const int InspectedRowsToDetermineTargets = 500;
    4141
    4242    #region default data
     
    171171     {1176881,7,5,3,7,4,10,7,5,5,4        }
    172172};
    173     private static Dataset defaultDataset;
    174     private static IEnumerable<string> defaultAllowedInputVariables;
    175     private static string defaultTargetVariable;
     173    private static readonly Dataset defaultDataset;
     174    private static readonly IEnumerable<string> defaultAllowedInputVariables;
     175    private static readonly string defaultTargetVariable;
     176
     177    private static readonly ClassificationProblemData emptyProblemData;
     178    public static ClassificationProblemData EmptyProblemData {
     179      get { return EmptyProblemData; }
     180    }
     181
    176182    static ClassificationProblemData() {
    177183      defaultDataset = new Dataset(defaultVariableNames, defaultData);
     
    181187      defaultAllowedInputVariables = defaultVariableNames.Except(new List<string>() { "sample", "class" });
    182188      defaultTargetVariable = "class";
     189
     190      var problemData = new ClassificationProblemData();
     191      problemData.Parameters.Clear();
     192      problemData.Name = "Empty Classification ProblemData";
     193      problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded.";
     194      problemData.isEmpty = true;
     195
     196      problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset()));
     197      problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, ""));
     198      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     199      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     200      problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>()));
     201      problemData.Parameters.Add(new FixedValueParameter<StringMatrix>(ClassNamesParameterName, "", new StringMatrix(0, 0).AsReadOnly()));
     202      problemData.Parameters.Add(new FixedValueParameter<DoubleMatrix>(ClassificationPenaltiesParameterName, "", (DoubleMatrix)new DoubleMatrix(0, 0).AsReadOnly()));
     203      emptyProblemData = problemData;
    183204    }
    184205    #endregion
    185206
    186207    #region parameter properties
    187     public IValueParameter<StringValue> TargetVariableParameter {
    188       get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     208    public ConstrainedValueParameter<StringValue> TargetVariableParameter {
     209      get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
    189210    }
    190211    public IFixedValueParameter<StringMatrix> ClassNamesParameter {
     
    205226      get {
    206227        if (classValues == null) {
    207           classValues = Dataset.GetEnumeratedVariableValues(TargetVariableParameter.Value.Value).Distinct().ToList();
     228          classValues = Dataset.GetDoubleValues(TargetVariableParameter.Value.Value).Distinct().ToList();
    208229          classValues.Sort();
    209230        }
     
    249270      RegisterParameterEvents();
    250271    }
    251     public override IDeepCloneable Clone(Cloner cloner) { return new ClassificationProblemData(this, cloner); }
     272    public override IDeepCloneable Clone(Cloner cloner) {
     273      if (this == emptyProblemData) return emptyProblemData;
     274      return new ClassificationProblemData(this, cloner);
     275    }
    252276
    253277    public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) { }
     
    267291    private static IEnumerable<string> CheckVariablesForPossibleTargetVariables(Dataset dataset) {
    268292      int maxSamples = Math.Min(InspectedRowsToDetermineTargets, dataset.Rows);
    269       var validTargetVariables = from v in dataset.VariableNames
    270                                  let DistinctValues = dataset.GetVariableValues(v)
    271                                    .Take(maxSamples)
    272                                    .Distinct()
    273                                    .Count()
    274                                  where DistinctValues < MaximumNumberOfClasses
    275                                  select v;
     293      var validTargetVariables = (from v in dataset.DoubleVariables
     294                                  let distinctValues = dataset.GetDoubleValues(v)
     295                                    .Take(maxSamples)
     296                                    .Distinct()
     297                                    .Count()
     298                                  where distinctValues < MaximumNumberOfClasses
     299                                  select v).ToArray();
    276300
    277301      if (!validTargetVariables.Any())
     
    283307
    284308    private void ResetTargetVariableDependentMembers() {
    285       DergisterParameterEvents();
     309      DeregisterParameterEvents();
    286310
    287311      classNames = null;
     
    357381      ClassificationPenaltiesParameter.Value.ItemChanged += new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged);
    358382    }
    359     private void DergisterParameterEvents() {
     383    private void DeregisterParameterEvents() {
    360384      TargetVariableParameter.ValueChanged -= new EventHandler(TargetVariableParameter_ValueChanged);
    361385      ClassNamesParameter.Value.Reset -= new EventHandler(Parameter_ValueChanged);
     
    386410      dataset.Name = Path.GetFileName(fileName);
    387411
    388       ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.VariableNames.Skip(1), dataset.VariableNames.First());
     412      ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First());
    389413      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    390414      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs

    r6184 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    26 using HeuristicLab.Data;
    27 using HeuristicLab.Optimization;
    2825using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2926
     
    3330  /// </summary>
    3431  [StorableClass]
    35   public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    36     private const string TrainingAccuracyResultName = "Accuracy (training)";
    37     private const string TestAccuracyResultName = "Accuracy (test)";
    38 
    39     public new IClassificationModel Model {
    40       get { return (IClassificationModel)base.Model; }
    41       protected set { base.Model = value; }
    42     }
    43 
    44     public new IClassificationProblemData ProblemData {
    45       get { return (IClassificationProblemData)base.ProblemData; }
    46       protected set { base.ProblemData = value; }
    47     }
    48 
    49     public double TrainingAccuracy {
    50       get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; }
    51       private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; }
    52     }
    53 
    54     public double TestAccuracy {
    55       get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; }
    56       private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; }
    57     }
     32  public abstract class ClassificationSolution : ClassificationSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    5834
    5935    [StorableConstructor]
    60     protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     36    protected ClassificationSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
    6140    protected ClassificationSolution(ClassificationSolution original, Cloner cloner)
    6241      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
    6343    }
    6444    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    6545      : base(model, problemData) {
    66       Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
    67       Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
    68       RecalculateResults();
     46      evaluationCache = new Dictionary<int, double>();
    6947    }
    7048
    71     public override IDeepCloneable Clone(Cloner cloner) {
    72       return new ClassificationSolution(this, cloner);
     49    public override IEnumerable<double> EstimatedClassValues {
     50      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     51    }
     52    public override IEnumerable<double> EstimatedTrainingClassValues {
     53      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     54    }
     55    public override IEnumerable<double> EstimatedTestClassValues {
     56      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    7357    }
    7458
    75     protected override void OnProblemDataChanged(EventArgs e) {
    76       base.OnProblemDataChanged(e);
    77       RecalculateResults();
     59    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     60      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     61      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     62      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     63
     64      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     65        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     66      }
     67
     68      return rows.Select(row => evaluationCache[row]);
    7869    }
    7970
    80     protected override void OnModelChanged(EventArgs e) {
    81       base.OnModelChanged(e);
    82       RecalculateResults();
     71    protected override void OnProblemDataChanged() {
     72      evaluationCache.Clear();
     73      base.OnProblemDataChanged();
    8374    }
    8475
    85     protected void RecalculateResults() {
    86       double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
    87       IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    88       double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
    89       IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    90 
    91       OnlineCalculatorError errorState;
    92       double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState);
    93       if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN;
    94       double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState);
    95       if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN;
    96 
    97       TrainingAccuracy = trainingAccuracy;
    98       TestAccuracy = testAccuracy;
    99     }
    100 
    101     public virtual IEnumerable<double> EstimatedClassValues {
    102       get {
    103         return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    104       }
    105     }
    106 
    107     public virtual IEnumerable<double> EstimatedTrainingClassValues {
    108       get {
    109         return GetEstimatedClassValues(ProblemData.TrainingIndizes);
    110       }
    111     }
    112 
    113     public virtual IEnumerable<double> EstimatedTestClassValues {
    114       get {
    115         return GetEstimatedClassValues(ProblemData.TestIndizes);
    116       }
    117     }
    118 
    119     public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    120       return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
     76    protected override void OnModelChanged() {
     77      evaluationCache.Clear();
     78      base.OnModelChanged();
    12179    }
    12280  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r5809 r6760  
    3333  [StorableClass]
    3434  [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")]
    35   public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
     35  public abstract class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
    3636    [Storable]
    3737    private IRegressionModel model;
     
    7070    }
    7171
    72     public override IDeepCloneable Clone(Cloner cloner) {
    73       return new DiscriminantFunctionClassificationModel(this, cloner);
    74     }
    75 
    7672    public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) {
    7773      var classValuesArr = classValues.ToArray();
     
    106102    }
    107103    #endregion
     104
     105    public abstract IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData);
     106    public abstract IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData);
    108107  }
    109108}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r5942 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
     
    2625using HeuristicLab.Core;
    2726using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    28 using HeuristicLab.Data;
    29 using HeuristicLab.Optimization;
    3027
    3128namespace HeuristicLab.Problems.DataAnalysis {
     
    3532  [StorableClass]
    3633  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
    37   public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    38     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    39     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    40     private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    41     private const string TestRSquaredResultName = "Pearson's R² (test)";
     34  public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase {
     35    protected readonly Dictionary<int, double> valueEvaluationCache;
     36    protected readonly Dictionary<int, double> classValueEvaluationCache;
    4237
    43     public new IDiscriminantFunctionClassificationModel Model {
    44       get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    45       protected set {
    46         if (value != null && value != Model) {
    47           if (Model != null) {
    48             Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    49           }
    50           value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    51           base.Model = value;
    52         }
    53       }
     38    [StorableConstructor]
     39    protected DiscriminantFunctionClassificationSolution(bool deserializing)
     40      : base(deserializing) {
     41      valueEvaluationCache = new Dictionary<int, double>();
     42      classValueEvaluationCache = new Dictionary<int, double>();
     43    }
     44    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
     45      : base(original, cloner) {
     46      valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache);
     47      classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache);
     48    }
     49    protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
     50      : base(model, problemData) {
     51      valueEvaluationCache = new Dictionary<int, double>();
     52      classValueEvaluationCache = new Dictionary<int, double>();
     53
     54      SetAccuracyMaximizingThresholds();
    5455    }
    5556
    56     public double TrainingMeanSquaredError {
    57       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    58       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     57    public override IEnumerable<double> EstimatedClassValues {
     58      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     59    }
     60    public override IEnumerable<double> EstimatedTrainingClassValues {
     61      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     62    }
     63    public override IEnumerable<double> EstimatedTestClassValues {
     64      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    5965    }
    6066
    61     public double TestMeanSquaredError {
    62       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    63       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     67    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     68      var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys);
     69      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     70      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     71
     72      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     73        classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     74      }
     75
     76      return rows.Select(row => classValueEvaluationCache[row]);
    6477    }
    6578
    66     public double TrainingRSquared {
    67       get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
    68       private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    69     }
    7079
    71     public double TestRSquared {
    72       get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
    73       private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    74     }
    75 
    76     [StorableConstructor]
    77     protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { }
    78     protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
    79       : base(original, cloner) {
    80       RegisterEventHandler();
    81     }
    82     public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
    83       : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    84     }
    85     public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    86       : base(model, problemData) {
    87       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    88       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    89       Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    90       Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    91       RegisterEventHandler();
    92       SetAccuracyMaximizingThresholds();
    93       RecalculateResults();
    94     }
    95 
    96     [StorableHook(HookType.AfterDeserialization)]
    97     private void AfterDeserialization() {
    98       RegisterEventHandler();
    99     }
    100 
    101     protected new void RecalculateResults() {
    102       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    103       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    104       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    105       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    106 
    107       OnlineCalculatorError errorState;
    108       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    109       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    110       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    111       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    112 
    113       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    114       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    115       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    116       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    117     }
    118 
    119     private void RegisterEventHandler() {
    120       Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    121     }
    122     private void Model_ThresholdsChanged(object sender, EventArgs e) {
    123       OnModelThresholdsChanged(e);
    124     }
    125 
    126     public void SetAccuracyMaximizingThresholds() {
    127       double[] classValues;
    128       double[] thresholds;
    129       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    130       AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    131 
    132       Model.SetThresholdsAndClassValues(thresholds, classValues);
    133     }
    134 
    135     public void SetClassDistibutionCutPointThresholds() {
    136       double[] classValues;
    137       double[] thresholds;
    138       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    139       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    140 
    141       Model.SetThresholdsAndClassValues(thresholds, classValues);
    142     }
    143 
    144     protected override void OnModelChanged(EventArgs e) {
    145       base.OnModelChanged(e);
    146       SetAccuracyMaximizingThresholds();
    147       RecalculateResults();
    148     }
    149 
    150     protected override void OnProblemDataChanged(EventArgs e) {
    151       base.OnProblemDataChanged(e);
    152       SetAccuracyMaximizingThresholds();
    153       RecalculateResults();
    154     }
    155     protected virtual void OnModelThresholdsChanged(EventArgs e) {
    156       base.OnModelChanged(e);
    157       RecalculateResults();
    158     }
    159 
    160     public IEnumerable<double> EstimatedValues {
     80    public override IEnumerable<double> EstimatedValues {
    16181      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    16282    }
    163 
    164     public IEnumerable<double> EstimatedTrainingValues {
     83    public override IEnumerable<double> EstimatedTrainingValues {
    16584      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    16685    }
    167 
    168     public IEnumerable<double> EstimatedTestValues {
     86    public override IEnumerable<double> EstimatedTestValues {
    16987      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    17088    }
    17189
    172     public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    173       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     90    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     91      var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys);
     92      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     93      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     94
     95      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     96        valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     97      }
     98
     99      return rows.Select(row => valueEvaluationCache[row]);
     100    }
     101
     102    protected override void OnModelChanged() {
     103      valueEvaluationCache.Clear();
     104      classValueEvaluationCache.Clear();
     105      base.OnModelChanged();
     106    }
     107    protected override void OnModelThresholdsChanged(System.EventArgs e) {
     108      classValueEvaluationCache.Clear();
     109      base.OnModelThresholdsChanged(e);
     110    }
     111    protected override void OnProblemDataChanged() {
     112      valueEvaluationCache.Clear();
     113      classValueEvaluationCache.Clear();
     114      base.OnProblemDataChanged();
    174115    }
    175116  }
Note: See TracChangeset for help on using the changeset viewer.