Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/11 13:59:25 (13 years ago)
Author:
epitzer
Message:

#1530 integrate changes from trunk

Location:
branches/PersistenceSpeedUp
Files:
6 edited
2 copied

Legend:

Unmodified
Added
Removed
  • branches/PersistenceSpeedUp

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis

  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r5809 r6760  
    3434  public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
    3535
    36     [Storable]
    3736    private List<IRegressionModel> models;
    3837    public IEnumerable<IRegressionModel> Models {
    3938      get { return new List<IRegressionModel>(models); }
    4039    }
     40
     41    [Storable(Name = "Models")]
     42    private IEnumerable<IRegressionModel> StorableModels {
     43      get { return models; }
     44      set { models = value.ToList(); }
     45    }
     46
     47    #region backwards compatiblity 3.3.5
     48    [Storable(Name = "models", AllowOneWay = true)]
     49    private List<IRegressionModel> OldStorableModels {
     50      set { models = value; }
     51    }
     52    #endregion
     53
    4154    [StorableConstructor]
    4255    protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     
    4558      this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
    4659    }
     60
     61    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    4762    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
    4863      : base() {
     
    5772
    5873    #region IRegressionEnsembleModel Members
     74
     75    public void Add(IRegressionModel model) {
     76      models.Add(model);
     77    }
     78    public void Remove(IRegressionModel model) {
     79      models.Remove(model);
     80    }
    5981
    6082    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) {
     
    79101    }
    80102
     103    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     104      return new RegressionEnsembleSolution(this.Models, problemData);
     105    }
     106    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
     107      return CreateRegressionSolution(problemData);
     108    }
     109
    81110    #endregion
    82111  }
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6184 r6760  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    27 using System;
    28 using HeuristicLab.Data;
    2930
    3031namespace HeuristicLab.Problems.DataAnalysis {
     
    3435  [StorableClass]
    3536  [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")]
    36   // [Creatable("Data Analysis")]
    37   public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
     37  [Creatable("Data Analysis - Ensembles")]
     38  public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
    3839    public new IRegressionEnsembleModel Model {
    3940      get { return (IRegressionEnsembleModel)base.Model; }
     41    }
     42
     43    public new RegressionEnsembleProblemData ProblemData {
     44      get { return (RegressionEnsembleProblemData)base.ProblemData; }
     45      set { base.ProblemData = value; }
     46    }
     47
     48    private readonly ItemCollection<IRegressionSolution> regressionSolutions;
     49    public IItemCollection<IRegressionSolution> RegressionSolutions {
     50      get { return regressionSolutions; }
    4051    }
    4152
     
    4657
    4758    [StorableConstructor]
    48     protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { }
    49     protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
     59    private RegressionEnsembleSolution(bool deserializing)
     60      : base(deserializing) {
     61      regressionSolutions = new ItemCollection<IRegressionSolution>();
     62    }
     63    [StorableHook(HookType.AfterDeserialization)]
     64    private void AfterDeserialization() {
     65      foreach (var model in Model.Models) {
     66        IRegressionProblemData problemData = (IRegressionProblemData) ProblemData.Clone();
     67        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     68        problemData.TrainingPartition.End = trainingPartitions[model].End;
     69        problemData.TestPartition.Start = testPartitions[model].Start;
     70        problemData.TestPartition.End = testPartitions[model].End;
     71
     72        regressionSolutions.Add(model.CreateRegressionSolution(problemData));
     73      }
     74      RegisterRegressionSolutionsEventHandler();
     75    }
     76
     77    private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
    5078      : base(original, cloner) {
    51     }
    52     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    53       : base(new RegressionEnsembleModel(models), problemData) {
    5479      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    5580      testPartitions = new Dictionary<IRegressionModel, IntRange>();
    56       foreach (var model in models) {
    57         trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone();
    58         testPartitions[model] = (IntRange)problemData.TestPartition.Clone();
    59       }
    60       RecalculateResults();
    61     }
     81      foreach (var pair in original.trainingPartitions) {
     82        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     83      }
     84      foreach (var pair in original.testPartitions) {
     85        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
     86      }
     87
     88      regressionSolutions = cloner.Clone(original.regressionSolutions);
     89      RegisterRegressionSolutionsEventHandler();
     90    }
     91
     92    public RegressionEnsembleSolution()
     93      : base(new RegressionEnsembleModel(), RegressionEnsembleProblemData.EmptyProblemData) {
     94      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     95      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     96      regressionSolutions = new ItemCollection<IRegressionSolution>();
     97
     98      RegisterRegressionSolutionsEventHandler();
     99    }
     100
     101    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
     102      : this(models, problemData,
     103             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     104             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     105      ) { }
    62106
    63107    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    64       : base(new RegressionEnsembleModel(models), problemData) {
     108      : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    65109      this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    66110      this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
     111      this.regressionSolutions = new ItemCollection<IRegressionSolution>();
     112
     113      List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    67114      var modelEnumerator = models.GetEnumerator();
    68115      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    69116      var testPartitionEnumerator = testPartitions.GetEnumerator();
     117
    70118      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    71         this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone();
    72         this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone();
     119        var p = (IRegressionProblemData)problemData.Clone();
     120        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     121        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     122        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     123        p.TestPartition.End = testPartitionEnumerator.Current.End;
     124
     125        solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    73126      }
    74127      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     
    76129      }
    77130
    78       RecalculateResults();
    79     }
    80 
    81     private void RecalculateResults() {
    82       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    83       var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start,
    84         ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
    85       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes);
    86       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    87       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    88 
    89       OnlineCalculatorError errorState;
    90       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    91       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    92       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    93       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    94 
    95       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    96       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    97       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    98       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    99 
    100       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    101       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    102       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    103       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    104 
    105       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    106       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    107       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    108       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
     131      RegisterRegressionSolutionsEventHandler();
     132      regressionSolutions.AddRange(solutions);
    109133    }
    110134
     
    112136      return new RegressionEnsembleSolution(this, cloner);
    113137    }
    114 
     138    private void RegisterRegressionSolutionsEventHandler() {
     139      regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded);
     140      regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved);
     141      regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset);
     142    }
     143
     144    protected override void RecalculateResults() {
     145      CalculateResults();
     146    }
     147
     148    #region Evaluation
    115149    public override IEnumerable<double> EstimatedTrainingValues {
    116150      get {
    117         var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
     151        var rows = ProblemData.TrainingIndizes;
    118152        var estimatedValuesEnumerators = (from model in Model.Models
    119153                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    120154                                         .ToList();
    121155        var rowsEnumerator = rows.GetEnumerator();
     156        // aggregate to make sure that MoveNext is called for all enumerators
    122157        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    123158          int currentRow = rowsEnumerator.Current;
    124159
    125160          var selectedEnumerators = from pair in estimatedValuesEnumerators
    126                                     where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) ||
    127                                          (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End)
     161                                    where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model)
    128162                                    select pair.EstimatedValuesEnumerator;
    129163          yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
     
    134168    public override IEnumerable<double> EstimatedTestValues {
    135169      get {
     170        var rows = ProblemData.TestIndizes;
    136171        var estimatedValuesEnumerators = (from model in Model.Models
    137                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })
     172                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    138173                                         .ToList();
    139174        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     175        // aggregate to make sure that MoveNext is called for all enumerators
    140176        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    141177          int currentRow = rowsEnumerator.Current;
    142178
    143179          var selectedEnumerators = from pair in estimatedValuesEnumerators
    144                                     where testPartitions == null || !testPartitions.ContainsKey(pair.Model) ||
    145                                       (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End)
     180                                    where RowIsTestForModel(currentRow, pair.Model)
    146181                                    select pair.EstimatedValuesEnumerator;
    147182
     
    149184        }
    150185      }
     186    }
     187
     188    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
     189      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
     190              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
     191    }
     192
     193    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
     194      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     195              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
    151196    }
    152197
     
    168213
    169214    private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    170       return estimatedValues.Average();
    171     }
    172 
    173     //[Storable]
    174     //private string name;
    175     //public string Name {
    176     //  get {
    177     //    return name;
    178     //  }
    179     //  set {
    180     //    if (value != null && value != name) {
    181     //      var cancelEventArgs = new CancelEventArgs<string>(value);
    182     //      OnNameChanging(cancelEventArgs);
    183     //      if (cancelEventArgs.Cancel == false) {
    184     //        name = value;
    185     //        OnNamedChanged(EventArgs.Empty);
    186     //      }
    187     //    }
    188     //  }
    189     //}
    190 
    191     //public bool CanChangeName {
    192     //  get { return true; }
    193     //}
    194 
    195     //[Storable]
    196     //private string description;
    197     //public string Description {
    198     //  get {
    199     //    return description;
    200     //  }
    201     //  set {
    202     //    if (value != null && value != description) {
    203     //      description = value;
    204     //      OnDescriptionChanged(EventArgs.Empty);
    205     //    }
    206     //  }
    207     //}
    208 
    209     //public bool CanChangeDescription {
    210     //  get { return true; }
    211     //}
    212 
    213     //#region events
    214     //public event EventHandler<CancelEventArgs<string>> NameChanging;
    215     //private void OnNameChanging(CancelEventArgs<string> cancelEventArgs) {
    216     //  var listener = NameChanging;
    217     //  if (listener != null) listener(this, cancelEventArgs);
    218     //}
    219 
    220     //public event EventHandler NameChanged;
    221     //private void OnNamedChanged(EventArgs e) {
    222     //  var listener = NameChanged;
    223     //  if (listener != null) listener(this, e);
    224     //}
    225 
    226     //public event EventHandler DescriptionChanged;
    227     //private void OnDescriptionChanged(EventArgs e) {
    228     //  var listener = DescriptionChanged;
    229     //  if (listener != null) listener(this, e);
    230     //}
    231     // #endregion
     215      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     216    }
     217    #endregion
     218
     219    protected override void OnProblemDataChanged() {
     220      IRegressionProblemData problemData = new RegressionProblemData(ProblemData.Dataset,
     221                                                                     ProblemData.AllowedInputVariables,
     222                                                                     ProblemData.TargetVariable);
     223      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
     224      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
     225      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
     226      problemData.TestPartition.End = ProblemData.TestPartition.End;
     227
     228      foreach (var solution in RegressionSolutions) {
     229        if (solution is RegressionEnsembleSolution)
     230          solution.ProblemData = ProblemData;
     231        else
     232          solution.ProblemData = problemData;
     233      }
     234      foreach (var trainingPartition in trainingPartitions.Values) {
     235        trainingPartition.Start = ProblemData.TrainingPartition.Start;
     236        trainingPartition.End = ProblemData.TrainingPartition.End;
     237      }
     238      foreach (var testPartition in testPartitions.Values) {
     239        testPartition.Start = ProblemData.TestPartition.Start;
     240        testPartition.End = ProblemData.TestPartition.End;
     241      }
     242
     243      base.OnProblemDataChanged();
     244    }
     245
     246    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     247      regressionSolutions.AddRange(solutions);
     248    }
     249    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     250      regressionSolutions.RemoveRange(solutions);
     251    }
     252
     253    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     254      foreach (var solution in e.Items) AddRegressionSolution(solution);
     255      RecalculateResults();
     256    }
     257    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     258      foreach (var solution in e.Items) RemoveRegressionSolution(solution);
     259      RecalculateResults();
     260    }
     261    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     262      foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
     263      foreach (var solution in e.Items) AddRegressionSolution(solution);
     264      RecalculateResults();
     265    }
     266
     267    private void AddRegressionSolution(IRegressionSolution solution) {
     268      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     269      Model.Add(solution.Model);
     270      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     271      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     272    }
     273
     274    private void RemoveRegressionSolution(IRegressionSolution solution) {
     275      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     276      Model.Remove(solution.Model);
     277      trainingPartitions.Remove(solution.Model);
     278      testPartitions.Remove(solution.Model);
     279    }
    232280  }
    233281}
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r5809 r6760  
    3333  [StorableClass]
    3434  [Item("RegressionProblemData", "Represents an item containing all data defining a regression problem.")]
    35   public sealed class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
    36     private const string TargetVariableParameterName = "TargetVariable";
     35  public class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
     36    protected const string TargetVariableParameterName = "TargetVariable";
    3737
    3838    #region default data
     
    6464          {0.83763905,  0.468046718}
    6565    };
    66     private static Dataset defaultDataset;
    67     private static IEnumerable<string> defaultAllowedInputVariables;
    68     private static string defaultTargetVariable;
     66    private static readonly Dataset defaultDataset;
     67    private static readonly IEnumerable<string> defaultAllowedInputVariables;
     68    private static readonly string defaultTargetVariable;
     69
     70    private static readonly RegressionProblemData emptyProblemData;
     71    public static RegressionProblemData EmptyProblemData {
     72      get { return emptyProblemData; }
     73    }
    6974
    7075    static RegressionProblemData() {
     
    7479      defaultAllowedInputVariables = new List<string>() { "x" };
    7580      defaultTargetVariable = "y";
     81
     82      var problemData = new RegressionProblemData();
     83      problemData.Parameters.Clear();
     84      problemData.Name = "Empty Regression ProblemData";
     85      problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded.";
     86      problemData.isEmpty = true;
     87
     88      problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset()));
     89      problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, ""));
     90      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     91      problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly()));
     92      problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>()));
     93      emptyProblemData = problemData;
    7694    }
    7795    #endregion
    7896
    79     public IValueParameter<StringValue> TargetVariableParameter {
    80       get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     97    public ConstrainedValueParameter<StringValue> TargetVariableParameter {
     98      get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
    8199    }
    82100    public string TargetVariable {
     
    85103
    86104    [StorableConstructor]
    87     private RegressionProblemData(bool deserializing) : base(deserializing) { }
     105    protected RegressionProblemData(bool deserializing) : base(deserializing) { }
    88106    [StorableHook(HookType.AfterDeserialization)]
    89107    private void AfterDeserialization() {
     
    91109    }
    92110
    93 
    94     private RegressionProblemData(RegressionProblemData original, Cloner cloner)
     111    protected RegressionProblemData(RegressionProblemData original, Cloner cloner)
    95112      : base(original, cloner) {
    96113      RegisterParameterEvents();
    97114    }
    98     public override IDeepCloneable Clone(Cloner cloner) { return new RegressionProblemData(this, cloner); }
     115    public override IDeepCloneable Clone(Cloner cloner) {
     116      if (this == emptyProblemData) return emptyProblemData;
     117      return new RegressionProblemData(this, cloner);
     118    }
    99119
    100120    public RegressionProblemData()
     
    124144      dataset.Name = Path.GetFileName(fileName);
    125145
    126       RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.VariableNames.Skip(1), dataset.VariableNames.First());
     146      RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First());
    127147      problemData.Name = "Data imported from " + Path.GetFileName(fileName);
    128148      return problemData;
  • branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs

    r6184 r6760  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    26 using HeuristicLab.Data;
    27 using HeuristicLab.Optimization;
    2825using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2926
     
    3330  /// </summary>
    3431  [StorableClass]
    35   public class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
    36     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    37     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    38     private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
    39     private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
    40     private const string TrainingRelativeErrorResultName = "Average relative error (training)";
    41     private const string TestRelativeErrorResultName = "Average relative error (test)";
    42     private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)";
    43     private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)";
     32  public abstract class RegressionSolution : RegressionSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    4434
    45     public new IRegressionModel Model {
    46       get { return (IRegressionModel)base.Model; }
    47       protected set { base.Model = value; }
     35    [StorableConstructor]
     36    protected RegressionSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
     40    protected RegressionSolution(RegressionSolution original, Cloner cloner)
     41      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
     43    }
     44    protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
     45      : base(model, problemData) {
     46      evaluationCache = new Dictionary<int, double>();
    4847    }
    4948
    50     public new IRegressionProblemData ProblemData {
    51       get { return (IRegressionProblemData)base.ProblemData; }
    52       protected set { base.ProblemData = value; }
     49    protected override void RecalculateResults() {
     50      CalculateResults();
    5351    }
    5452
    55     public double TrainingMeanSquaredError {
    56       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    57       protected set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     53    public override IEnumerable<double> EstimatedValues {
     54      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     55    }
     56    public override IEnumerable<double> EstimatedTrainingValues {
     57      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
     58    }
     59    public override IEnumerable<double> EstimatedTestValues {
     60      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    5861    }
    5962
    60     public double TestMeanSquaredError {
    61       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    62       protected set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     63    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     64      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     65      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     66      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     67
     68      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     69        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     70      }
     71
     72      return rows.Select(row => evaluationCache[row]);
    6373    }
    6474
    65     public double TrainingRSquared {
    66       get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
    67       protected set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
     75    protected override void OnProblemDataChanged() {
     76      evaluationCache.Clear();
     77      base.OnProblemDataChanged();
    6878    }
    6979
    70     public double TestRSquared {
    71       get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
    72       protected set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
    73     }
    74 
    75     public double TrainingRelativeError {
    76       get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
    77       protected set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
    78     }
    79 
    80     public double TestRelativeError {
    81       get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
    82       protected set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
    83     }
    84 
    85     public double TrainingNormalizedMeanSquaredError {
    86       get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
    87       protected set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    88     }
    89 
    90     public double TestNormalizedMeanSquaredError {
    91       get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
    92       protected set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    93     }
    94 
    95 
    96     [StorableConstructor]
    97     protected RegressionSolution(bool deserializing) : base(deserializing) { }
    98     protected RegressionSolution(RegressionSolution original, Cloner cloner)
    99       : base(original, cloner) {
    100     }
    101     public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
    102       : base(model, problemData) {
    103       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    104       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    105       Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    106       Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    107       Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue()));
    108       Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue()));
    109       Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue()));
    110       Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue()));
    111 
    112       RecalculateResults();
    113     }
    114 
    115     public override IDeepCloneable Clone(Cloner cloner) {
    116       return new RegressionSolution(this, cloner);
    117     }
    118 
    119     protected override void OnProblemDataChanged(EventArgs e) {
    120       base.OnProblemDataChanged(e);
    121       RecalculateResults();
    122     }
    123     protected override void OnModelChanged(EventArgs e) {
    124       base.OnModelChanged(e);
    125       RecalculateResults();
    126     }
    127 
    128     private void RecalculateResults() {
    129       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    130       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    131       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    132       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    133 
    134       OnlineCalculatorError errorState;
    135       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    136       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    137       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    138       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    139 
    140       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    141       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    142       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    143       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    144 
    145       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    146       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    147       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    148       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    149 
    150       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    151       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    152       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    153       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
    154     }
    155 
    156     public virtual IEnumerable<double> EstimatedValues {
    157       get {
    158         return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    159       }
    160     }
    161 
    162     public virtual IEnumerable<double> EstimatedTrainingValues {
    163       get {
    164         return GetEstimatedValues(ProblemData.TrainingIndizes);
    165       }
    166     }
    167 
    168     public virtual IEnumerable<double> EstimatedTestValues {
    169       get {
    170         return GetEstimatedValues(ProblemData.TestIndizes);
    171       }
    172     }
    173 
    174     public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    175       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     80    protected override void OnModelChanged() {
     81      evaluationCache.Clear();
     82      base.OnModelChanged();
    17683    }
    17784  }
Note: See TracChangeset for help on using the changeset viewer.