Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/01/11 17:48:53 (13 years ago)
Author:
mkommend
Message:

#1479: Integrated trunk changes.

Location:
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
Files:
4 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r5809 r6618  
    3434  public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
    3535
    36     [Storable]
    3736    private List<IRegressionModel> models;
    3837    public IEnumerable<IRegressionModel> Models {
    3938      get { return new List<IRegressionModel>(models); }
    4039    }
     40
     41    [Storable(Name = "Models")]
     42    private IEnumerable<IRegressionModel> StorableModels {
     43      get { return models; }
     44      set { models = value.ToList(); }
     45    }
     46
     47    #region backwards compatiblity 3.3.5
     48    [Storable(Name = "models", AllowOneWay = true)]
     49    private List<IRegressionModel> OldStorableModels {
     50      set { models = value; }
     51    }
     52    #endregion
     53
    4154    [StorableConstructor]
    4255    protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     
    5770
    5871    #region IRegressionEnsembleModel Members
     72
     73    public void Add(IRegressionModel model) {
     74      models.Add(model);
     75    }
     76    public void Remove(IRegressionModel model) {
     77      models.Remove(model);
     78    }
    5979
    6080    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) {
     
    7999    }
    80100
     101    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     102      return new RegressionEnsembleSolution(this.Models, problemData);
     103    }
     104    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
     105      return CreateRegressionSolution(problemData);
     106    }
     107
    81108    #endregion
    82109  }
  • branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6377 r6618  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Collections;
    2426using HeuristicLab.Common;
    2527using HeuristicLab.Core;
     28using HeuristicLab.Data;
    2629using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    27 using System;
    28 using HeuristicLab.Data;
    2930
    3031namespace HeuristicLab.Problems.DataAnalysis {
     
    3536  [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")]
    3637  // [Creatable("Data Analysis")]
    37   public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
     38  public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {
    3839    public new IRegressionEnsembleModel Model {
    3940      get { return (IRegressionEnsembleModel)base.Model; }
     41    }
     42
     43    private readonly ItemCollection<IRegressionSolution> regressionSolutions;
     44    public IItemCollection<IRegressionSolution> RegressionSolutions {
     45      get { return regressionSolutions; }
    4046    }
    4147
     
    4652
    4753    [StorableConstructor]
    48     protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { }
    49     protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
     54    private RegressionEnsembleSolution(bool deserializing)
     55      : base(deserializing) {
     56      regressionSolutions = new ItemCollection<IRegressionSolution>();
     57    }
     58    [StorableHook(HookType.AfterDeserialization)]
     59    private void AfterDeserialization() {
     60      foreach (var model in Model.Models) {
     61        IRegressionProblemData problemData = (IRegressionProblemData)ProblemData.Clone();
     62        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
     63        problemData.TrainingPartition.End = trainingPartitions[model].End;
     64        problemData.TestPartition.Start = testPartitions[model].Start;
     65        problemData.TestPartition.End = testPartitions[model].End;
     66
     67        regressionSolutions.Add(model.CreateRegressionSolution(problemData));
     68      }
     69      RegisterRegressionSolutionsEventHandler();
     70    }
     71
     72    private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
    5073      : base(original, cloner) {
    5174      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     
    5780        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
    5881      }
    59       RecalculateResults();
     82
     83      regressionSolutions = cloner.Clone(original.regressionSolutions);
     84      RegisterRegressionSolutionsEventHandler();
    6085    }
    6186
    6287    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    63       : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) {
    64       trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    65       testPartitions = new Dictionary<IRegressionModel, IntRange>();
    66       foreach (var model in models) {
    67         trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone();
    68         testPartitions[model] = (IntRange)problemData.TestPartition.Clone();
    69       }
    70       RecalculateResults();
    71     }
     88      : this(models, problemData,
     89             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
     90             models.Select(m => (IntRange)problemData.TestPartition.Clone())
     91      ) { }
    7292
    7393    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    74       : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) {
     94      : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    7595      this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    7696      this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
     97      this.regressionSolutions = new ItemCollection<IRegressionSolution>();
     98
     99      List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    77100      var modelEnumerator = models.GetEnumerator();
    78101      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    79102      var testPartitionEnumerator = testPartitions.GetEnumerator();
     103
    80104      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    81         this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone();
    82         this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone();
     105        var p = (IRegressionProblemData)problemData.Clone();
     106        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
     107        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
     108        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
     109        p.TestPartition.End = testPartitionEnumerator.Current.End;
     110
     111        solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    83112      }
    84113      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
    85114        throw new ArgumentException();
    86115      }
    87       RecalculateResults();
     116
     117      RegisterRegressionSolutionsEventHandler();
     118      regressionSolutions.AddRange(solutions);
    88119    }
    89120
     
    91122      return new RegressionEnsembleSolution(this, cloner);
    92123    }
    93 
     124    private void RegisterRegressionSolutionsEventHandler() {
     125      regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded);
     126      regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved);
     127      regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset);
     128    }
     129
     130    protected override void RecalculateResults() {
     131      CalculateResults();
     132    }
     133
     134    #region Evaluation
    94135    public override IEnumerable<double> EstimatedTrainingValues {
    95136      get {
     
    160201      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
    161202    }
     203    #endregion
     204
     205    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     206      solutions.OfType<RegressionEnsembleSolution>().SelectMany(ensemble => ensemble.RegressionSolutions);
     207      regressionSolutions.AddRange(solutions);
     208    }
     209    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     210      regressionSolutions.RemoveRange(solutions);
     211    }
     212
     213    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     214      foreach (var solution in e.Items) AddRegressionSolution(solution);
     215      RecalculateResults();
     216    }
     217    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     218      foreach (var solution in e.Items) RemoveRegressionSolution(solution);
     219      RecalculateResults();
     220    }
     221    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
     222      foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
     223      foreach (var solution in e.Items) AddRegressionSolution(solution);
     224      RecalculateResults();
     225    }
     226
     227    private void AddRegressionSolution(IRegressionSolution solution) {
     228      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
     229      Model.Add(solution.Model);
     230      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
     231      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
     232    }
     233
     234    private void RemoveRegressionSolution(IRegressionSolution solution) {
     235      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
     236      Model.Remove(solution.Model);
     237      trainingPartitions.Remove(solution.Model);
     238      testPartitions.Remove(solution.Model);
     239    }
    162240  }
    163241}
  • branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r6238 r6618  
    7777    #endregion
    7878
    79     public IValueParameter<StringValue> TargetVariableParameter {
    80       get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     79    public ConstrainedValueParameter<StringValue> TargetVariableParameter {
     80      get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
    8181    }
    8282    public string TargetVariable {
  • branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs

    r6415 r6618  
    2323using System.Linq;
    2424using HeuristicLab.Common;
    25 using HeuristicLab.Data;
    26 using HeuristicLab.Optimization;
    2725using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2826
     
    3230  /// </summary>
    3331  [StorableClass]
    34   public class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
    35     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    36     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    37     private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
    38     private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
    39     private const string TrainingRelativeErrorResultName = "Average relative error (training)";
    40     private const string TestRelativeErrorResultName = "Average relative error (test)";
    41     private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)";
    42     private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)";
    43 
    44     public new IRegressionModel Model {
    45       get { return (IRegressionModel)base.Model; }
    46       protected set { base.Model = value; }
    47     }
    48 
    49     public new IRegressionProblemData ProblemData {
    50       get { return (IRegressionProblemData)base.ProblemData; }
    51       protected set { base.ProblemData = value; }
    52     }
    53 
    54     public double TrainingMeanSquaredError {
    55       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    56       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    57     }
    58 
    59     public double TestMeanSquaredError {
    60       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    61       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    62     }
    63 
    64     public double TrainingRSquared {
    65       get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
    66       private set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
    67     }
    68 
    69     public double TestRSquared {
    70       get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
    71       private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
    72     }
    73 
    74     public double TrainingRelativeError {
    75       get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
    76       private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
    77     }
    78 
    79     public double TestRelativeError {
    80       get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
    81       private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
    82     }
    83 
    84     public double TrainingNormalizedMeanSquaredError {
    85       get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
    86       private set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    87     }
    88 
    89     public double TestNormalizedMeanSquaredError {
    90       get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
    91       private set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    92     }
    93 
     32  public abstract class RegressionSolution : RegressionSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    9434
    9535    [StorableConstructor]
    96     protected RegressionSolution(bool deserializing) : base(deserializing) { }
     36    protected RegressionSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
    9740    protected RegressionSolution(RegressionSolution original, Cloner cloner)
    9841      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
    9943    }
    100     public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
     44    protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
    10145      : base(model, problemData) {
    102       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    103       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    104       Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    105       Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    106       Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue()));
    107       Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue()));
    108       Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue()));
    109       Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue()));
    110 
    111       CalculateResults();
    112     }
    113 
    114     public override IDeepCloneable Clone(Cloner cloner) {
    115       return new RegressionSolution(this, cloner);
     46      evaluationCache = new Dictionary<int, double>();
    11647    }
    11748
     
    12051    }
    12152
    122     private void CalculateResults() {
    123       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    124       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    125       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    126       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    127 
    128       OnlineCalculatorError errorState;
    129       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    130       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    131       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    132       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    133 
    134       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    135       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    136       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    137       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    138 
    139       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    140       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    141       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    142       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    143 
    144       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    145       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    146       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    147       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
     53    public override IEnumerable<double> EstimatedValues {
     54      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     55    }
     56    public override IEnumerable<double> EstimatedTrainingValues {
     57      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
     58    }
     59    public override IEnumerable<double> EstimatedTestValues {
     60      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    14861    }
    14962
    150     public virtual IEnumerable<double> EstimatedValues {
    151       get {
    152         return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
     63    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     64      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     65      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     66      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     67
     68      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     69        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    15370      }
     71
     72      return rows.Select(row => evaluationCache[row]);
    15473    }
    15574
    156     public virtual IEnumerable<double> EstimatedTrainingValues {
    157       get {
    158         return GetEstimatedValues(ProblemData.TrainingIndizes);
    159       }
     75    protected override void OnProblemDataChanged() {
     76      evaluationCache.Clear();
     77      base.OnProblemDataChanged();
    16078    }
    16179
    162     public virtual IEnumerable<double> EstimatedTestValues {
    163       get {
    164         return GetEstimatedValues(ProblemData.TestIndizes);
    165       }
    166     }
    167 
    168     public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    169       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     80    protected override void OnModelChanged() {
     81      evaluationCache.Clear();
     82      base.OnModelChanged();
    17083    }
    17184  }
Note: See TracChangeset for help on using the changeset viewer.