Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/11 13:59:25 (13 years ago)
Author:
epitzer
Message:

#1530 integrate changes from trunk

Location:
branches/PersistenceSpeedUp
Files:
17 edited
5 copied

Legend:

Unmodified
Added
Removed
  • branches/PersistenceSpeedUp

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs

    r5809 r6760  
    3939      get { return new List<IClassificationModel>(models); }
    4040    }
     41
    4142    [StorableConstructor]
    4243    protected ClassificationEnsembleModel(bool deserializing) : base(deserializing) { }
     
    4546      this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    4647    }
     48
     49    public ClassificationEnsembleModel() : this(Enumerable.Empty<IClassificationModel>()) { }
    4750    public ClassificationEnsembleModel(IEnumerable<IClassificationModel> models)
    4851      : base() {
     
    5760
    5861    #region IClassificationEnsembleModel Members
     62    public void Add(IClassificationModel model) {
     63      models.Add(model);
     64    }
     65    public void Remove(IClassificationModel model) {
     66      models.Remove(model);
     67    }
    5968
    6069    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
     
    8594    }
    8695
     96    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
     97      return new ClassificationEnsembleSolution(models, problemData);
     98    }
    8799    #endregion
    88100  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r6184 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2730
     
    3235  [StorableClass]
    3336  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
    34   // [Creatable("Data Analysis")]
    35   public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution {
     37  [Creatable("Data Analysis - Ensembles")]
     38  public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     39    public new IClassificationEnsembleModel Model {
     40      get { return (IClassificationEnsembleModel)base.Model; }
     41    }
     42    public new ClassificationEnsembleProblemData ProblemData {
     43      get { return (ClassificationEnsembleProblemData)base.ProblemData; }
     44      set { base.ProblemData = value; }
     45    }
     46
     47    private readonly ItemCollection<IClassificationSolution> classificationSolutions;
     48    public IItemCollection<IClassificationSolution> ClassificationSolutions {
     49      get { return classificationSolutions; }
     50    }
    3651
    3752    [Storable]
    38     private List<IClassificationModel> models;
    39     public IEnumerable<IClassificationModel> Models {
    40       get { return new List<IClassificationModel>(models); }
    41     }
     53    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     54    [Storable]
     55    private Dictionary<IClassificationModel, IntRange> testPartitions;
     56
    4257    [StorableConstructor]
    43     protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { }
    44     protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
     58    private ClassificationEnsembleSolution(bool deserializing)
     59      : base(deserializing) {
     60      classificationSolutions = new ItemCollection<IClassificationSolution>();
     61    }
     62    [StorableHook(HookType.AfterDeserialization)]
     63    private void AfterDeserialization() {
     64      foreach (var model in Model.Models) {
     65        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
     66        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     67        problemData.TrainingPartition.End = trainingPartitions[model].End;
     68        problemData.TestPartition.Start = testPartitions[model].Start;
     69        problemData.TestPartition.End = testPartitions[model].End;
     70
     71        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
     72      }
     73      RegisterClassificationSolutionsEventHandler();
     74    }
     75
     76    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    4577      : base(original, cloner) {
    46       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    47     }
    48     public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models)
    49       : base() {
    50       this.name = ItemName;
    51       this.description = ItemDescription;
    52       this.models = new List<IClassificationModel>(models);
     78      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     79      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     80      foreach (var pair in original.trainingPartitions) {
     81        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     82      }
     83      foreach (var pair in original.testPartitions) {
     84        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     85      }
     86
     87      classificationSolutions = cloner.Clone(original.classificationSolutions);
     88      RegisterClassificationSolutionsEventHandler();
     89    }
     90
     91    public ClassificationEnsembleSolution()
     92      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
     93      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     94      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     95      classificationSolutions = new ItemCollection<IClassificationSolution>();
     96
     97      RegisterClassificationSolutionsEventHandler();
     98    }
     99
     100    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
     101      : this(models, problemData,
     102             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     103             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     104      ) { }
     105
     106    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
     107      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
     108      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     109      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
     110      this.classificationSolutions = new ItemCollection<IClassificationSolution>();
     111
     112      List<IClassificationSolution> solutions = new List<IClassificationSolution>();
     113      var modelEnumerator = models.GetEnumerator();
     114      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
     115      var testPartitionEnumerator = testPartitions.GetEnumerator();
     116
     117      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
     118        var p = (IClassificationProblemData)problemData.Clone();
     119        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     120        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     121        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     122        p.TestPartition.End = testPartitionEnumerator.Current.End;
     123
     124        solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
     125      }
     126      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     127        throw new ArgumentException();
     128      }
     129
     130      RegisterClassificationSolutionsEventHandler();
     131      classificationSolutions.AddRange(solutions);
    53132    }
    54133
     
    56135      return new ClassificationEnsembleSolution(this, cloner);
    57136    }
    58 
    59     #region IClassificationEnsembleModel Members
     137    private void RegisterClassificationSolutionsEventHandler() {
     138      classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded);
     139      classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved);
     140      classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset);
     141    }
     142
     143    protected override void RecalculateResults() {
     144      CalculateResults();
     145    }
     146
     147    #region Evaluation
     148    public override IEnumerable<double> EstimatedTrainingClassValues {
     149      get {
     150        var rows = ProblemData.TrainingIndizes;
     151        var estimatedValuesEnumerators = (from model in Model.Models
     152                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     153                                         .ToList();
     154        var rowsEnumerator = rows.GetEnumerator();
     155        // aggregate to make sure that MoveNext is called for all enumerators
     156        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     157          int currentRow = rowsEnumerator.Current;
     158
     159          var selectedEnumerators = from pair in estimatedValuesEnumerators
     160                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
     161                                    select pair.EstimatedValuesEnumerator;
     162          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     163        }
     164      }
     165    }
     166
     167    public override IEnumerable<double> EstimatedTestClassValues {
     168      get {
     169        var rows = ProblemData.TestIndizes;
     170        var estimatedValuesEnumerators = (from model in Model.Models
     171                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     172                                         .ToList();
     173        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     174        // aggregate to make sure that MoveNext is called for all enumerators
     175        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     176          int currentRow = rowsEnumerator.Current;
     177
     178          var selectedEnumerators = from pair in estimatedValuesEnumerators
     179                                    where RowIsTestForModel(currentRow, pair.Model)
     180                                    select pair.EstimatedValuesEnumerator;
     181
     182          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     183        }
     184      }
     185    }
     186
     187    private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) {
     188      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
     189              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
     190    }
     191
     192    private bool RowIsTestForModel(int currentRow, IClassificationModel model) {
     193      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     194              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
     195    }
     196
     197    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     198      return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
     199             select AggregateEstimatedClassValues(xs);
     200    }
    60201
    61202    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    62       var estimatedValuesEnumerators = (from model in models
     203      var estimatedValuesEnumerators = (from model in Model.Models
    63204                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
    64205                                       .ToList();
     
    70211    }
    71212
     213    private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
     214      return estimatedClassValues
     215      .GroupBy(x => x)
     216      .OrderBy(g => -g.Count())
     217      .Select(g => g.Key)
     218      .DefaultIfEmpty(double.NaN)
     219      .First();
     220    }
    72221    #endregion
    73222
    74     #region IClassificationModel Members
    75 
    76     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    77       foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) {
    78         // return the class which is most often occuring
    79         yield return
    80           estimatedValuesVector
    81           .GroupBy(x => x)
    82           .OrderBy(g => -g.Count())
    83           .Select(g => g.Key)
    84           .First();
    85       }
    86     }
    87 
    88     #endregion
     223    protected override void OnProblemDataChanged() {
     224      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
     225                                                                     ProblemData.AllowedInputVariables,
     226                                                                     ProblemData.TargetVariable);
     227      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
     228      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
     229      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
     230      problemData.TestPartition.End = ProblemData.TestPartition.End;
     231
     232      foreach (var solution in ClassificationSolutions) {
     233        if (solution is ClassificationEnsembleSolution)
     234          solution.ProblemData = ProblemData;
     235        else
     236          solution.ProblemData = problemData;
     237      }
     238      foreach (var trainingPartition in trainingPartitions.Values) {
     239        trainingPartition.Start = ProblemData.TrainingPartition.Start;
     240        trainingPartition.End = ProblemData.TrainingPartition.End;
     241      }
     242      foreach (var testPartition in testPartitions.Values) {
     243        testPartition.Start = ProblemData.TestPartition.Start;
     244        testPartition.End = ProblemData.TestPartition.End;
     245      }
     246
     247      base.OnProblemDataChanged();
     248    }
     249
     250    public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     251      classificationSolutions.AddRange(solutions);
     252    }
     253    public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) {
     254      classificationSolutions.RemoveRange(solutions);
     255    }
     256
     257    private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     258      foreach (var solution in e.Items) AddClassificationSolution(solution);
     259      RecalculateResults();
     260    }
     261    private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     262      foreach (var solution in e.Items) RemoveClassificationSolution(solution);
     263      RecalculateResults();
     264    }
     265    private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) {
     266      foreach (var solution in e.OldItems) RemoveClassificationSolution(solution);
     267      foreach (var solution in e.Items) AddClassificationSolution(solution);
     268      RecalculateResults();
     269    }
     270
     271    private void AddClassificationSolution(IClassificationSolution solution) {
     272      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     273      Model.Add(solution.Model);
     274      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     275      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     276    }
     277
     278    private void RemoveClassificationSolution(IClassificationSolution solution) {
     279      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     280      Model.Remove(solution.Model);
     281      trainingPartitions.Remove(solution.Model);
     282      testPartitions.Remove(solution.Model);
     283    }
    89284  }
    90285}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs

    r6232 r6760  
    3434  [Item("ClassificationProblemData", "Represents an item containing all data defining a classification problem.")]
    3535  public class ClassificationProblemData : DataAnalysisProblemData, IClassificationProblemData {
    36     private const string TargetVariableParameterName = "TargetVariable";
    37     private const string ClassNamesParameterName = "ClassNames";
    38     private const string ClassificationPenaltiesParameterName = "ClassificationPenalties";
    39     private const int MaximumNumberOfClasses = 20;
    40     private const int InspectedRowsToDetermineTargets = 500;
     36    protected const string TargetVariableParameterName = "TargetVariable";
     37    protected const string ClassNamesParameterName = "ClassNames";
     38    protected const string ClassificationPenaltiesParameterName = "ClassificationPenalties";
     39    protected const int MaximumNumberOfClasses = 20;
     40    protected const int InspectedRowsToDetermineTargets = 500;
    4141
    4242    #region default data
     
    171171     {1176881,7,5,3,7,4,10,7,5,5,4        }
    172172};
    173     private static Dataset defaultDataset;
    174     private static IEnumerable<string> defaultAllowedInputVariables;
    175     private static string defaultTargetVariable;
     173    private static readonly Dataset defaultDataset;
     174    private static readonly IEnumerable<string> defaultAllowedInputVariables;
     175    private static readonly string defaultTargetVariable;
     176
     177    private static readonly ClassificationProblemData emptyProblemData;
     178    public static ClassificationProblemData EmptyProblemData {
     179      get { return EmptyProblemData; }
     180    }
     181
    176182    static ClassificationProblemData() {
    177183      defaultDataset = new Dataset(defaultVariableNames, defaultData);
     
    181187      defaultAllowedInputVariables = defaultVariableNames.Except(new List<string>() { "sample", "class" });
    182188      defaultTargetVariable = "class";
     189
     190      var problemData = new ClassificationProblemData();
     191      problemData.Parameters.Clear();
     192      problemData.Name = "Empty Classification ProblemData";
     193      problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded.";
     194      problemData.isEmpty = true;
     195
     196      problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset()));
     197      problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, ""));
     198      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     199      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     200      problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>()));
     201      problemData.Parameters.Add(new FixedValueParameter<StringMatrix>(ClassNamesParameterName, "", new StringMatrix(0, 0).AsReadOnly()));
     202      problemData.Parameters.Add(new FixedValueParameter<DoubleMatrix>(ClassificationPenaltiesParameterName, "", (DoubleMatrix)new DoubleMatrix(0, 0).AsReadOnly()));
     203      emptyProblemData = problemData;
    183204    }
    184205    #endregion
    185206
    186207    #region parameter properties
    187     public IValueParameter<StringValue> TargetVariableParameter {
    188       get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     208    public ConstrainedValueParameter<StringValue> TargetVariableParameter {
     209      get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
    189210    }
    190211    public IFixedValueParameter<StringMatrix> ClassNamesParameter {
     
    205226      get {
    206227        if (classValues == null) {
    207           classValues = Dataset.GetEnumeratedVariableValues(TargetVariableParameter.Value.Value).Distinct().ToList();
     228          classValues = Dataset.GetDoubleValues(TargetVariableParameter.Value.Value).Distinct().ToList();
    208229          classValues.Sort();
    209230        }
     
    249270      RegisterParameterEvents();
    250271    }
    251     public override IDeepCloneable Clone(Cloner cloner) { return new ClassificationProblemData(this, cloner); }
     272    public override IDeepCloneable Clone(Cloner cloner) {
     273      if (this == emptyProblemData) return emptyProblemData;
     274      return new ClassificationProblemData(this, cloner);
     275    }
    252276
    253277    public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) { }
     
    267291    private static IEnumerable<string> CheckVariablesForPossibleTargetVariables(Dataset dataset) {
    268292      int maxSamples = Math.Min(InspectedRowsToDetermineTargets, dataset.Rows);
    269       var validTargetVariables = from v in dataset.VariableNames
    270                                  let DistinctValues = dataset.GetVariableValues(v)
    271                                    .Take(maxSamples)
    272                                    .Distinct()
    273                                    .Count()
    274                                  where DistinctValues < MaximumNumberOfClasses
    275                                  select v;
     293      var validTargetVariables = (from v in dataset.DoubleVariables
     294                                  let distinctValues = dataset.GetDoubleValues(v)
     295                                    .Take(maxSamples)
     296                                    .Distinct()
     297                                    .Count()
     298                                  where distinctValues < MaximumNumberOfClasses
     299                                  select v).ToArray();
    276300
    277301      if (!validTargetVariables.Any())
     
    283307
    284308    private void ResetTargetVariableDependentMembers() {
    285       DergisterParameterEvents();
     309      DeregisterParameterEvents();
    286310
    287311      classNames = null;
     
    357381      ClassificationPenaltiesParameter.Value.ItemChanged += new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged);
    358382    }
    359     private void DergisterParameterEvents() {
     383    private void DeregisterParameterEvents() {
    360384      TargetVariableParameter.ValueChanged -= new EventHandler(TargetVariableParameter_ValueChanged);
    361385      ClassNamesParameter.Value.Reset -= new EventHandler(Parameter_ValueChanged);
     
    386410      dataset.Name = Path.GetFileName(fileName);
    387411
    388       ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.VariableNames.Skip(1), dataset.VariableNames.First());
     412      ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First());
    389413      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    390414      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs

    r6184 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    26 using HeuristicLab.Data;
    27 using HeuristicLab.Optimization;
    2825using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2926
     
    3330  /// </summary>
    3431  [StorableClass]
    35   public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    36     private const string TrainingAccuracyResultName = "Accuracy (training)";
    37     private const string TestAccuracyResultName = "Accuracy (test)";
    38 
    39     public new IClassificationModel Model {
    40       get { return (IClassificationModel)base.Model; }
    41       protected set { base.Model = value; }
    42     }
    43 
    44     public new IClassificationProblemData ProblemData {
    45       get { return (IClassificationProblemData)base.ProblemData; }
    46       protected set { base.ProblemData = value; }
    47     }
    48 
    49     public double TrainingAccuracy {
    50       get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; }
    51       private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; }
    52     }
    53 
    54     public double TestAccuracy {
    55       get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; }
    56       private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; }
    57     }
     32  public abstract class ClassificationSolution : ClassificationSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    5834
    5935    [StorableConstructor]
    60     protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     36    protected ClassificationSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
    6140    protected ClassificationSolution(ClassificationSolution original, Cloner cloner)
    6241      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
    6343    }
    6444    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    6545      : base(model, problemData) {
    66       Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
    67       Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
    68       RecalculateResults();
     46      evaluationCache = new Dictionary<int, double>();
    6947    }
    7048
    71     public override IDeepCloneable Clone(Cloner cloner) {
    72       return new ClassificationSolution(this, cloner);
     49    public override IEnumerable<double> EstimatedClassValues {
     50      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     51    }
     52    public override IEnumerable<double> EstimatedTrainingClassValues {
     53      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     54    }
     55    public override IEnumerable<double> EstimatedTestClassValues {
     56      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    7357    }
    7458
    75     protected override void OnProblemDataChanged(EventArgs e) {
    76       base.OnProblemDataChanged(e);
    77       RecalculateResults();
     59    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     60      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     61      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     62      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     63
     64      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     65        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     66      }
     67
     68      return rows.Select(row => evaluationCache[row]);
    7869    }
    7970
    80     protected override void OnModelChanged(EventArgs e) {
    81       base.OnModelChanged(e);
    82       RecalculateResults();
     71    protected override void OnProblemDataChanged() {
     72      evaluationCache.Clear();
     73      base.OnProblemDataChanged();
    8374    }
    8475
    85     protected void RecalculateResults() {
    86       double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
    87       IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    88       double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
    89       IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    90 
    91       OnlineCalculatorError errorState;
    92       double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState);
    93       if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN;
    94       double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState);
    95       if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN;
    96 
    97       TrainingAccuracy = trainingAccuracy;
    98       TestAccuracy = testAccuracy;
    99     }
    100 
    101     public virtual IEnumerable<double> EstimatedClassValues {
    102       get {
    103         return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    104       }
    105     }
    106 
    107     public virtual IEnumerable<double> EstimatedTrainingClassValues {
    108       get {
    109         return GetEstimatedClassValues(ProblemData.TrainingIndizes);
    110       }
    111     }
    112 
    113     public virtual IEnumerable<double> EstimatedTestClassValues {
    114       get {
    115         return GetEstimatedClassValues(ProblemData.TestIndizes);
    116       }
    117     }
    118 
    119     public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    120       return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
     76    protected override void OnModelChanged() {
     77      evaluationCache.Clear();
     78      base.OnModelChanged();
    12179    }
    12280  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r5809 r6760  
    3333  [StorableClass]
    3434  [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")]
    35   public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
     35  public abstract class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
    3636    [Storable]
    3737    private IRegressionModel model;
     
    7070    }
    7171
    72     public override IDeepCloneable Clone(Cloner cloner) {
    73       return new DiscriminantFunctionClassificationModel(this, cloner);
    74     }
    75 
    7672    public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) {
    7773      var classValuesArr = classValues.ToArray();
     
    106102    }
    107103    #endregion
     104
     105    public abstract IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData);
     106    public abstract IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData);
    108107  }
    109108}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r5942 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
     
    2625using HeuristicLab.Core;
    2726using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    28 using HeuristicLab.Data;
    29 using HeuristicLab.Optimization;
    3027
    3128namespace HeuristicLab.Problems.DataAnalysis {
     
    3532  [StorableClass]
    3633  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
    37   public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    38     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    39     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    40     private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    41     private const string TestRSquaredResultName = "Pearson's R² (test)";
     34  public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase {
     35    protected readonly Dictionary<int, double> valueEvaluationCache;
     36    protected readonly Dictionary<int, double> classValueEvaluationCache;
    4237
    43     public new IDiscriminantFunctionClassificationModel Model {
    44       get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    45       protected set {
    46         if (value != null && value != Model) {
    47           if (Model != null) {
    48             Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    49           }
    50           value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    51           base.Model = value;
    52         }
    53       }
     38    [StorableConstructor]
     39    protected DiscriminantFunctionClassificationSolution(bool deserializing)
     40      : base(deserializing) {
     41      valueEvaluationCache = new Dictionary<int, double>();
     42      classValueEvaluationCache = new Dictionary<int, double>();
     43    }
     44    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
     45      : base(original, cloner) {
     46      valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache);
     47      classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache);
     48    }
     49    protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
     50      : base(model, problemData) {
     51      valueEvaluationCache = new Dictionary<int, double>();
     52      classValueEvaluationCache = new Dictionary<int, double>();
     53
     54      SetAccuracyMaximizingThresholds();
    5455    }
    5556
    56     public double TrainingMeanSquaredError {
    57       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    58       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     57    public override IEnumerable<double> EstimatedClassValues {
     58      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     59    }
     60    public override IEnumerable<double> EstimatedTrainingClassValues {
     61      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     62    }
     63    public override IEnumerable<double> EstimatedTestClassValues {
     64      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    5965    }
    6066
    61     public double TestMeanSquaredError {
    62       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    63       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     67    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     68      var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys);
     69      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     70      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     71
     72      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     73        classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     74      }
     75
     76      return rows.Select(row => classValueEvaluationCache[row]);
    6477    }
    6578
    66     public double TrainingRSquared {
    67       get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
    68       private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    69     }
    7079
    71     public double TestRSquared {
    72       get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
    73       private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    74     }
    75 
    76     [StorableConstructor]
    77     protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { }
    78     protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
    79       : base(original, cloner) {
    80       RegisterEventHandler();
    81     }
    82     public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
    83       : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    84     }
    85     public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    86       : base(model, problemData) {
    87       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    88       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    89       Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    90       Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    91       RegisterEventHandler();
    92       SetAccuracyMaximizingThresholds();
    93       RecalculateResults();
    94     }
    95 
    96     [StorableHook(HookType.AfterDeserialization)]
    97     private void AfterDeserialization() {
    98       RegisterEventHandler();
    99     }
    100 
    101     protected new void RecalculateResults() {
    102       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    103       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    104       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    105       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    106 
    107       OnlineCalculatorError errorState;
    108       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    109       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    110       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    111       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    112 
    113       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    114       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    115       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    116       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    117     }
    118 
    119     private void RegisterEventHandler() {
    120       Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    121     }
    122     private void Model_ThresholdsChanged(object sender, EventArgs e) {
    123       OnModelThresholdsChanged(e);
    124     }
    125 
    126     public void SetAccuracyMaximizingThresholds() {
    127       double[] classValues;
    128       double[] thresholds;
    129       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    130       AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    131 
    132       Model.SetThresholdsAndClassValues(thresholds, classValues);
    133     }
    134 
    135     public void SetClassDistibutionCutPointThresholds() {
    136       double[] classValues;
    137       double[] thresholds;
    138       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    139       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    140 
    141       Model.SetThresholdsAndClassValues(thresholds, classValues);
    142     }
    143 
    144     protected override void OnModelChanged(EventArgs e) {
    145       base.OnModelChanged(e);
    146       SetAccuracyMaximizingThresholds();
    147       RecalculateResults();
    148     }
    149 
    150     protected override void OnProblemDataChanged(EventArgs e) {
    151       base.OnProblemDataChanged(e);
    152       SetAccuracyMaximizingThresholds();
    153       RecalculateResults();
    154     }
    155     protected virtual void OnModelThresholdsChanged(EventArgs e) {
    156       base.OnModelChanged(e);
    157       RecalculateResults();
    158     }
    159 
    160     public IEnumerable<double> EstimatedValues {
     80    public override IEnumerable<double> EstimatedValues {
    16181      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    16282    }
    163 
    164     public IEnumerable<double> EstimatedTrainingValues {
     83    public override IEnumerable<double> EstimatedTrainingValues {
    16584      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    16685    }
    167 
    168     public IEnumerable<double> EstimatedTestValues {
     86    public override IEnumerable<double> EstimatedTestValues {
    16987      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    17088    }
    17189
    172     public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    173       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     90    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     91      var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys);
     92      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     93      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     94
     95      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     96        valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     97      }
     98
     99      return rows.Select(row => valueEvaluationCache[row]);
     100    }
     101
     102    protected override void OnModelChanged() {
     103      valueEvaluationCache.Clear();
     104      classValueEvaluationCache.Clear();
     105      base.OnModelChanged();
     106    }
     107    protected override void OnModelThresholdsChanged(System.EventArgs e) {
     108      classValueEvaluationCache.Clear();
     109      base.OnModelThresholdsChanged(e);
     110    }
     111    protected override void OnProblemDataChanged() {
     112      valueEvaluationCache.Clear();
     113      classValueEvaluationCache.Clear();
     114      base.OnProblemDataChanged();
    174115    }
    175116  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Clustering/ClusteringProblemData.cs

    r6228 r6760  
    9595      dataset.Name = Path.GetFileName(fileName);
    9696
    97       ClusteringProblemData problemData = new ClusteringProblemData(dataset, dataset.VariableNames);
     97      ClusteringProblemData problemData = new ClusteringProblemData(dataset, dataset.DoubleVariables);
    9898      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    9999      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Clustering/ClusteringSolution.cs

    r6184 r6760  
    4545    }
    4646
     47    protected override void RecalculateResults() {
     48    }
     49
    4750    #region IClusteringSolution Members
    4851
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblem.cs

    r5809 r6760  
    4848    public T ProblemData {
    4949      get { return ProblemDataParameter.Value; }
    50       protected set { ProblemDataParameter.Value = value; }
     50      protected set {
     51        ProblemDataParameter.Value = value;
     52      }
    5153    }
    5254    #endregion
    5355    protected DataAnalysisProblem(DataAnalysisProblem<T> original, Cloner cloner)
    5456      : base(original, cloner) {
     57      RegisterEventHandlers();
    5558    }
    5659    [StorableConstructor]
     
    5962      : base() {
    6063      Parameters.Add(new ValueParameter<T>(ProblemDataParameterName, ProblemDataParameterDescription));
     64      RegisterEventHandlers();
     65    }
     66
     67    [StorableHook(HookType.AfterDeserialization)]
     68    private void AfterDeserialization() {
     69      RegisterEventHandlers();
    6170    }
    6271
    6372    private void RegisterEventHandlers() {
    64       ProblemDataParameter.Value.Changed += new EventHandler(ProblemDataParameter_ValueChanged);
     73      ProblemDataParameter.ValueChanged += new EventHandler(ProblemDataParameter_ValueChanged);
     74      if (ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed);
    6575    }
     76
    6677    private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {
     78      ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed);
    6779      OnProblemDataChanged();
     80      OnReset();
     81    }
     82
     83    private void ProblemData_Changed(object sender, EventArgs e) {
    6884      OnReset();
    6985    }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblemData.cs

    r5847 r6760  
    3333  [StorableClass]
    3434  public abstract class DataAnalysisProblemData : ParameterizedNamedItem, IDataAnalysisProblemData {
    35     private const string DatasetParameterName = "Dataset";
    36     private const string InputVariablesParameterName = "InputVariables";
    37     private const string TrainingPartitionParameterName = "TrainingPartition";
    38     private const string TestPartitionParameterName = "TestPartition";
     35    protected const string DatasetParameterName = "Dataset";
     36    protected const string InputVariablesParameterName = "InputVariables";
     37    protected const string TrainingPartitionParameterName = "TrainingPartition";
     38    protected const string TestPartitionParameterName = "TestPartition";
    3939
    4040    #region parameter properites
     
    5353    #endregion
    5454
    55     #region propeties
     55    #region properties
     56    protected bool isEmpty = false;
     57    public bool IsEmpty {
     58      get { return isEmpty; }
     59    }
    5660    public Dataset Dataset {
    5761      get { return DatasetParameter.Value; }
     
    7175    }
    7276
    73     public IEnumerable<int> TrainingIndizes {
     77    public virtual IEnumerable<int> TrainingIndizes {
    7478      get {
    7579        return Enumerable.Range(TrainingPartition.Start, TrainingPartition.End - TrainingPartition.Start)
    76                          .Where(i => i >= 0 && i < Dataset.Rows && (i < TestPartition.Start || TestPartition.End <= i));
     80                         .Where(IsTrainingSample);
    7781      }
    7882    }
    79     public IEnumerable<int> TestIndizes {
     83    public virtual IEnumerable<int> TestIndizes {
    8084      get {
    8185        return Enumerable.Range(TestPartition.Start, TestPartition.End - TestPartition.Start)
    82            .Where(i => i >= 0 && i < Dataset.Rows);
     86           .Where(IsTestSample);
    8387      }
     88    }
     89
     90    public virtual bool IsTrainingSample(int index) {
     91      return index >= 0 && index < Dataset.Rows &&
     92        TrainingPartition.Start <= index && index < TrainingPartition.End &&
     93        (index < TestPartition.Start || TestPartition.End <= index);
     94    }
     95
     96    public virtual bool IsTestSample(int index) {
     97      return index >= 0 && index < Dataset.Rows &&
     98             TestPartition.Start <= index && index < TestPartition.End;
    8499    }
    85100    #endregion
    86101
    87     protected DataAnalysisProblemData(DataAnalysisProblemData original, Cloner cloner) : base(original, cloner) { }
     102    protected DataAnalysisProblemData(DataAnalysisProblemData original, Cloner cloner)
     103      : base(original, cloner) {
     104      isEmpty = original.isEmpty;
     105      RegisterEventHandlers();
     106    }
    88107    [StorableConstructor]
    89108    protected DataAnalysisProblemData(bool deserializing) : base(deserializing) { }
     109    [StorableHook(HookType.AfterDeserialization)]
     110    private void AfterDeserialization() {
     111      RegisterEventHandlers();
     112    }
    90113
    91114    protected DataAnalysisProblemData(Dataset dataset, IEnumerable<string> allowedInputVariables) {
     
    93116      if (allowedInputVariables == null) throw new ArgumentNullException("The allowedInputVariables must not be null.");
    94117
    95       if (allowedInputVariables.Except(dataset.VariableNames).Any())
    96         throw new ArgumentException("All allowed input variables must be present in the dataset.");
     118      if (allowedInputVariables.Except(dataset.DoubleVariables).Any())
     119        throw new ArgumentException("All allowed input variables must be present in the dataset and of type double.");
    97120
    98       var inputVariables = new CheckedItemList<StringValue>(dataset.VariableNames.Select(x => new StringValue(x)));
     121      var inputVariables = new CheckedItemList<StringValue>(dataset.DoubleVariables.Select(x => new StringValue(x)));
    99122      foreach (StringValue x in inputVariables)
    100123        inputVariables.SetItemCheckedState(x, allowedInputVariables.Contains(x.Value));
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisSolution.cs

    r5914 r6760  
    4848          if (value != null) {
    4949            this[ModelResultName].Value = value;
    50             OnModelChanged(EventArgs.Empty);
     50            OnModelChanged();
    5151          }
    5252        }
     
    5656    public IDataAnalysisProblemData ProblemData {
    5757      get { return (IDataAnalysisProblemData)this[ProblemDataResultName].Value; }
    58       protected set {
     58      set {
    5959        if (this[ProblemDataResultName].Value != value) {
    6060          if (value != null) {
     
    6262            this[ProblemDataResultName].Value = value;
    6363            ProblemData.Changed += new EventHandler(ProblemData_Changed);
    64             OnProblemDataChanged(EventArgs.Empty);
     64            OnProblemDataChanged();
    6565          }
    6666        }
     
    8080      name = ItemName;
    8181      description = ItemDescription;
    82       Add(new Result(ModelResultName, "The symbolic data analysis model.", model));
    83       Add(new Result(ProblemDataResultName, "The symbolic data analysis problem data.", problemData));
     82      Add(new Result(ModelResultName, "The data analysis model.", model));
     83      Add(new Result(ProblemDataResultName, "The data analysis problem data.", problemData));
    8484
    8585      problemData.Changed += new EventHandler(ProblemData_Changed);
    8686    }
    8787
     88    protected abstract void RecalculateResults();
     89
    8890    private void ProblemData_Changed(object sender, EventArgs e) {
    89       OnProblemDataChanged(e);
     91      OnProblemDataChanged();
    9092    }
    9193
    9294    public event EventHandler ModelChanged;
    93     protected virtual void OnModelChanged(EventArgs e) {
     95    protected virtual void OnModelChanged() {
     96      RecalculateResults();
    9497      var listeners = ModelChanged;
    95       if (listeners != null) listeners(this, e);
     98      if (listeners != null) listeners(this, EventArgs.Empty);
    9699    }
    97100
    98101    public event EventHandler ProblemDataChanged;
    99     protected virtual void OnProblemDataChanged(EventArgs e) {
     102    protected virtual void OnProblemDataChanged() {
     103      RecalculateResults();
    100104      var listeners = ProblemDataChanged;
    101       if (listeners != null) listeners(this, e);
     105      if (listeners != null) listeners(this, EventArgs.Empty);
    102106    }
    103107
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r5809 r6760  
    3434  public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
    3535
    36     [Storable]
    3736    private List<IRegressionModel> models;
    3837    public IEnumerable<IRegressionModel> Models {
    3938      get { return new List<IRegressionModel>(models); }
    4039    }
     40
     41    [Storable(Name = "Models")]
     42    private IEnumerable<IRegressionModel> StorableModels {
     43      get { return models; }
     44      set { models = value.ToList(); }
     45    }
     46
     47    #region backwards compatiblity 3.3.5
     48    [Storable(Name = "models", AllowOneWay = true)]
     49    private List<IRegressionModel> OldStorableModels {
     50      set { models = value; }
     51    }
     52    #endregion
     53
    4154    [StorableConstructor]
    4255    protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     
    4558      this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    4659    }
     60
     61    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    4762    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
    4863      : base() {
     
    5772
    5873    #region IRegressionEnsembleModel Members
     74
     75    public void Add(IRegressionModel model) {
     76      models.Add(model);
     77    }
     78    public void Remove(IRegressionModel model) {
     79      models.Remove(model);
     80    }
    5981
    6082    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) {
     
    79101    }
    80102
     103    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     104      return new RegressionEnsembleSolution(this.Models, problemData);
     105    }
     106    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
     107      return CreateRegressionSolution(problemData);
     108    }
     109
    81110    #endregion
    82111  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6184 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    27 using System;
    28 using HeuristicLab.Data;
    2930
    3031namespace HeuristicLab.Problems.DataAnalysis {
     
    3435  [StorableClass]
    3536  [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")]
    36   // [Creatable("Data Analysis")]
    37   public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
     37  [Creatable("Data Analysis - Ensembles")]
     38  public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
    3839    public new IRegressionEnsembleModel Model {
    3940      get { return (IRegressionEnsembleModel)base.Model; }
     41    }
     42
     43    public new RegressionEnsembleProblemData ProblemData {
     44      get { return (RegressionEnsembleProblemData)base.ProblemData; }
     45      set { base.ProblemData = value; }
     46    }
     47
     48    private readonly ItemCollection<IRegressionSolution> regressionSolutions;
     49    public IItemCollection<IRegressionSolution> RegressionSolutions {
     50      get { return regressionSolutions; }
    4051    }
    4152
     
    4657
    4758    [StorableConstructor]
    48     protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { }
    49     protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
     59    private RegressionEnsembleSolution(bool deserializing)
     60      : base(deserializing) {
     61      regressionSolutions = new ItemCollection<IRegressionSolution>();
     62    }
     63    [StorableHook(HookType.AfterDeserialization)]
     64    private void AfterDeserialization() {
     65      foreach (var model in Model.Models) {
     66        IRegressionProblemData problemData = (IRegressionProblemData) ProblemData.Clone();
     67        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     68        problemData.TrainingPartition.End = trainingPartitions[model].End;
     69        problemData.TestPartition.Start = testPartitions[model].Start;
     70        problemData.TestPartition.End = testPartitions[model].End;
     71
     72        regressionSolutions.Add(model.CreateRegressionSolution(problemData));
     73      }
     74      RegisterRegressionSolutionsEventHandler();
     75    }
     76
     77    private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
    5078      : base(original, cloner) {
    51     }
    52     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    53       : base(new RegressionEnsembleModel(models), problemData) {
    5479      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    5580      testPartitions = new Dictionary<IRegressionModel, IntRange>();
    56       foreach (var model in models) {
    57         trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone();
    58         testPartitions[model] = (IntRange)problemData.TestPartition.Clone();
    59       }
    60       RecalculateResults();
    61     }
     81      foreach (var pair in original.trainingPartitions) {
     82        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     83      }
     84      foreach (var pair in original.testPartitions) {
     85        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     86      }
     87
     88      regressionSolutions = cloner.Clone(original.regressionSolutions);
     89      RegisterRegressionSolutionsEventHandler();
     90    }
     91
     92    public RegressionEnsembleSolution()
     93      : base(new RegressionEnsembleModel(), RegressionEnsembleProblemData.EmptyProblemData) {
     94      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     95      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     96      regressionSolutions = new ItemCollection<IRegressionSolution>();
     97
     98      RegisterRegressionSolutionsEventHandler();
     99    }
     100
     101    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
     102      : this(models, problemData,
     103             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     104             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     105      ) { }
    62106
    63107    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    64       : base(new RegressionEnsembleModel(models), problemData) {
     108      : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    65109      this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    66110      this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
     111      this.regressionSolutions = new ItemCollection<IRegressionSolution>();
     112
     113      List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    67114      var modelEnumerator = models.GetEnumerator();
    68115      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    69116      var testPartitionEnumerator = testPartitions.GetEnumerator();
     117
    70118      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    71         this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone();
    72         this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone();
     119        var p = (IRegressionProblemData)problemData.Clone();
     120        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     121        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     122        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     123        p.TestPartition.End = testPartitionEnumerator.Current.End;
     124
     125        solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    73126      }
    74127      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     
    76129      }
    77130
    78       RecalculateResults();
    79     }
    80 
    81     private void RecalculateResults() {
    82       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    83       var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start,
    84         ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
    85       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes);
    86       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    87       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    88 
    89       OnlineCalculatorError errorState;
    90       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    91       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    92       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    93       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    94 
    95       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    96       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    97       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    98       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    99 
    100       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    101       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    102       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    103       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    104 
    105       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    106       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    107       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    108       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
     131      RegisterRegressionSolutionsEventHandler();
     132      regressionSolutions.AddRange(solutions);
    109133    }
    110134
     
    112136      return new RegressionEnsembleSolution(this, cloner);
    113137    }
    114 
     138    private void RegisterRegressionSolutionsEventHandler() {
     139      regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded);
     140      regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved);
     141      regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset);
     142    }
     143
     144    protected override void RecalculateResults() {
     145      CalculateResults();
     146    }
     147
     148    #region Evaluation
    115149    public override IEnumerable<double> EstimatedTrainingValues {
    116150      get {
    117         var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
     151        var rows = ProblemData.TrainingIndizes;
    118152        var estimatedValuesEnumerators = (from model in Model.Models
    119153                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    120154                                         .ToList();
    121155        var rowsEnumerator = rows.GetEnumerator();
     156        // aggregate to make sure that MoveNext is called for all enumerators
    122157        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    123158          int currentRow = rowsEnumerator.Current;
    124159
    125160          var selectedEnumerators = from pair in estimatedValuesEnumerators
    126                                     where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) ||
    127                                          (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End)
     161                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    128162                                    select pair.EstimatedValuesEnumerator;
    129163          yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
     
    134168    public override IEnumerable<double> EstimatedTestValues {
    135169      get {
     170        var rows = ProblemData.TestIndizes;
    136171        var estimatedValuesEnumerators = (from model in Model.Models
    137                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })
     172                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    138173                                         .ToList();
    139174        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     175        // aggregate to make sure that MoveNext is called for all enumerators
    140176        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    141177          int currentRow = rowsEnumerator.Current;
    142178
    143179          var selectedEnumerators = from pair in estimatedValuesEnumerators
    144                                     where testPartitions == null || !testPartitions.ContainsKey(pair.Model) ||
    145                                       (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End)
     180                                    where RowIsTestForModel(currentRow, pair.Model)
    146181                                    select pair.EstimatedValuesEnumerator;
    147182
     
    149184        }
    150185      }
     186    }
     187
     188    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
     189      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
     190              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
     191    }
     192
     193    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
     194      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     195              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
    151196    }
    152197
     
    168213
    169214    private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    170       return estimatedValues.Average();
    171     }
    172 
    173     //[Storable]
    174     //private string name;
    175     //public string Name {
    176     //  get {
    177     //    return name;
    178     //  }
    179     //  set {
    180     //    if (value != null && value != name) {
    181     //      var cancelEventArgs = new CancelEventArgs<string>(value);
    182     //      OnNameChanging(cancelEventArgs);
    183     //      if (cancelEventArgs.Cancel == false) {
    184     //        name = value;
    185     //        OnNamedChanged(EventArgs.Empty);
    186     //      }
    187     //    }
    188     //  }
    189     //}
    190 
    191     //public bool CanChangeName {
    192     //  get { return true; }
    193     //}
    194 
    195     //[Storable]
    196     //private string description;
    197     //public string Description {
    198     //  get {
    199     //    return description;
    200     //  }
    201     //  set {
    202     //    if (value != null && value != description) {
    203     //      description = value;
    204     //      OnDescriptionChanged(EventArgs.Empty);
    205     //    }
    206     //  }
    207     //}
    208 
    209     //public bool CanChangeDescription {
    210     //  get { return true; }
    211     //}
    212 
    213     //#region events
    214     //public event EventHandler<CancelEventArgs<string>> NameChanging;
    215     //private void OnNameChanging(CancelEventArgs<string> cancelEventArgs) {
    216     //  var listener = NameChanging;
    217     //  if (listener != null) listener(this, cancelEventArgs);
    218     //}
    219 
    220     //public event EventHandler NameChanged;
    221     //private void OnNamedChanged(EventArgs e) {
    222     //  var listener = NameChanged;
    223     //  if (listener != null) listener(this, e);
    224     //}
    225 
    226     //public event EventHandler DescriptionChanged;
    227     //private void OnDescriptionChanged(EventArgs e) {
    228     //  var listener = DescriptionChanged;
    229     //  if (listener != null) listener(this, e);
    230     //}
    231     // #endregion
     215      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     216    }
     217    #endregion
     218
     219    protected override void OnProblemDataChanged() {
     220      IRegressionProblemData problemData = new RegressionProblemData(ProblemData.Dataset,
     221                                                                     ProblemData.AllowedInputVariables,
     222                                                                     ProblemData.TargetVariable);
     223      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
     224      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
     225      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
     226      problemData.TestPartition.End = ProblemData.TestPartition.End;
     227
     228      foreach (var solution in RegressionSolutions) {
     229        if (solution is RegressionEnsembleSolution)
     230          solution.ProblemData = ProblemData;
     231        else
     232          solution.ProblemData = problemData;
     233      }
     234      foreach (var trainingPartition in trainingPartitions.Values) {
     235        trainingPartition.Start = ProblemData.TrainingPartition.Start;
     236        trainingPartition.End = ProblemData.TrainingPartition.End;
     237      }
     238      foreach (var testPartition in testPartitions.Values) {
     239        testPartition.Start = ProblemData.TestPartition.Start;
     240        testPartition.End = ProblemData.TestPartition.End;
     241      }
     242
     243      base.OnProblemDataChanged();
     244    }
     245
     246    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     247      regressionSolutions.AddRange(solutions);
     248    }
     249    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     250      regressionSolutions.RemoveRange(solutions);
     251    }
     252
     253    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     254      foreach (var solution in e.Items) AddRegressionSolution(solution);
     255      RecalculateResults();
     256    }
     257    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     258      foreach (var solution in e.Items) RemoveRegressionSolution(solution);
     259      RecalculateResults();
     260    }
     261    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     262      foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
     263      foreach (var solution in e.Items) AddRegressionSolution(solution);
     264      RecalculateResults();
     265    }
     266
     267    private void AddRegressionSolution(IRegressionSolution solution) {
     268      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     269      Model.Add(solution.Model);
     270      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     271      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     272    }
     273
     274    private void RemoveRegressionSolution(IRegressionSolution solution) {
     275      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     276      Model.Remove(solution.Model);
     277      trainingPartitions.Remove(solution.Model);
     278      testPartitions.Remove(solution.Model);
     279    }
    232280  }
    233281}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r5809 r6760  
    3333  [StorableClass]
    3434  [Item("RegressionProblemData", "Represents an item containing all data defining a regression problem.")]
    35   public sealed class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
    36     private const string TargetVariableParameterName = "TargetVariable";
     35  public class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
     36    protected const string TargetVariableParameterName = "TargetVariable";
    3737
    3838    #region default data
     
    6464          {0.83763905,  0.468046718}
    6565    };
    66     private static Dataset defaultDataset;
    67     private static IEnumerable<string> defaultAllowedInputVariables;
    68     private static string defaultTargetVariable;
     66    private static readonly Dataset defaultDataset;
     67    private static readonly IEnumerable<string> defaultAllowedInputVariables;
     68    private static readonly string defaultTargetVariable;
     69
     70    private static readonly RegressionProblemData emptyProblemData;
     71    public static RegressionProblemData EmptyProblemData {
     72      get { return emptyProblemData; }
     73    }
    6974
    7075    static RegressionProblemData() {
     
    7479      defaultAllowedInputVariables = new List<string>() { "x" };
    7580      defaultTargetVariable = "y";
     81
     82      var problemData = new RegressionProblemData();
     83      problemData.Parameters.Clear();
     84      problemData.Name = "Empty Regression ProblemData";
     85      problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded.";
     86      problemData.isEmpty = true;
     87
     88      problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset()));
     89      problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, ""));
     90      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     91      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     92      problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>()));
     93      emptyProblemData = problemData;
    7694    }
    7795    #endregion
    7896
    79     public IValueParameter<StringValue> TargetVariableParameter {
    80       get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     97    public ConstrainedValueParameter<StringValue> TargetVariableParameter {
     98      get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
    8199    }
    82100    public string TargetVariable {
     
    85103
    86104    [StorableConstructor]
    87     private RegressionProblemData(bool deserializing) : base(deserializing) { }
     105    protected RegressionProblemData(bool deserializing) : base(deserializing) { }
    88106    [StorableHook(HookType.AfterDeserialization)]
    89107    private void AfterDeserialization() {
     
    91109    }
    92110
    93 
    94     private RegressionProblemData(RegressionProblemData original, Cloner cloner)
     111    protected RegressionProblemData(RegressionProblemData original, Cloner cloner)
    95112      : base(original, cloner) {
    96113      RegisterParameterEvents();
    97114    }
    98     public override IDeepCloneable Clone(Cloner cloner) { return new RegressionProblemData(this, cloner); }
     115    public override IDeepCloneable Clone(Cloner cloner) {
     116      if (this == emptyProblemData) return emptyProblemData;
     117      return new RegressionProblemData(this, cloner);
     118    }
    99119
    100120    public RegressionProblemData()
     
    124144      dataset.Name = Path.GetFileName(fileName);
    125145
    126       RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.VariableNames.Skip(1), dataset.VariableNames.First());
     146      RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First());
    127147      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    128148      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs

    r6184 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    26 using HeuristicLab.Data;
    27 using HeuristicLab.Optimization;
    2825using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2926
     
    3330  /// </summary>
    3431  [StorableClass]
    35   public class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
    36     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    37     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    38     private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
    39     private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
    40     private const string TrainingRelativeErrorResultName = "Average relative error (training)";
    41     private const string TestRelativeErrorResultName = "Average relative error (test)";
    42     private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)";
    43     private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)";
     32  public abstract class RegressionSolution : RegressionSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    4434
    45     public new IRegressionModel Model {
    46       get { return (IRegressionModel)base.Model; }
    47       protected set { base.Model = value; }
     35    [StorableConstructor]
     36    protected RegressionSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
     40    protected RegressionSolution(RegressionSolution original, Cloner cloner)
     41      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
     43    }
     44    protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
     45      : base(model, problemData) {
     46      evaluationCache = new Dictionary<int, double>();
    4847    }
    4948
    50     public new IRegressionProblemData ProblemData {
    51       get { return (IRegressionProblemData)base.ProblemData; }
    52       protected set { base.ProblemData = value; }
     49    protected override void RecalculateResults() {
     50      CalculateResults();
    5351    }
    5452
    55     public double TrainingMeanSquaredError {
    56       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    57       protected set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     53    public override IEnumerable<double> EstimatedValues {
     54      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     55    }
     56    public override IEnumerable<double> EstimatedTrainingValues {
     57      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
     58    }
     59    public override IEnumerable<double> EstimatedTestValues {
     60      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    5861    }
    5962
    60     public double TestMeanSquaredError {
    61       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    62       protected set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     63    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     64      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     65      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     66      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     67
     68      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     69        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     70      }
     71
     72      return rows.Select(row => evaluationCache[row]);
    6373    }
    6474
    65     public double TrainingRSquared {
    66       get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
    67       protected set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
     75    protected override void OnProblemDataChanged() {
     76      evaluationCache.Clear();
     77      base.OnProblemDataChanged();
    6878    }
    6979
    70     public double TestRSquared {
    71       get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
    72       protected set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
    73     }
    74 
    75     public double TrainingRelativeError {
    76       get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
    77       protected set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
    78     }
    79 
    80     public double TestRelativeError {
    81       get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
    82       protected set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
    83     }
    84 
    85     public double TrainingNormalizedMeanSquaredError {
    86       get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
    87       protected set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    88     }
    89 
    90     public double TestNormalizedMeanSquaredError {
    91       get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
    92       protected set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    93     }
    94 
    95 
    96     [StorableConstructor]
    97     protected RegressionSolution(bool deserializing) : base(deserializing) { }
    98     protected RegressionSolution(RegressionSolution original, Cloner cloner)
    99       : base(original, cloner) {
    100     }
    101     public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
    102       : base(model, problemData) {
    103       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    104       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    105       Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    106       Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    107       Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue()));
    108       Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue()));
    109       Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue()));
    110       Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue()));
    111 
    112       RecalculateResults();
    113     }
    114 
    115     public override IDeepCloneable Clone(Cloner cloner) {
    116       return new RegressionSolution(this, cloner);
    117     }
    118 
    119     protected override void OnProblemDataChanged(EventArgs e) {
    120       base.OnProblemDataChanged(e);
    121       RecalculateResults();
    122     }
    123     protected override void OnModelChanged(EventArgs e) {
    124       base.OnModelChanged(e);
    125       RecalculateResults();
    126     }
    127 
    128     private void RecalculateResults() {
    129       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    130       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    131       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    132       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    133 
    134       OnlineCalculatorError errorState;
    135       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    136       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    137       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    138       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    139 
    140       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    141       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    142       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    143       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    144 
    145       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    146       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    147       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    148       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    149 
    150       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    151       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    152       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    153       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
    154     }
    155 
    156     public virtual IEnumerable<double> EstimatedValues {
    157       get {
    158         return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    159       }
    160     }
    161 
    162     public virtual IEnumerable<double> EstimatedTrainingValues {
    163       get {
    164         return GetEstimatedValues(ProblemData.TrainingIndizes);
    165       }
    166     }
    167 
    168     public virtual IEnumerable<double> EstimatedTestValues {
    169       get {
    170         return GetEstimatedValues(ProblemData.TestIndizes);
    171       }
    172     }
    173 
    174     public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    175       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     80    protected override void OnModelChanged() {
     81      evaluationCache.Clear();
     82      base.OnModelChanged();
    17683    }
    17784  }
Note: See TracChangeset for help on using the changeset viewer.