Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/20/11 15:07:45 (14 years ago)
Author:
gkronber
Message:

#1450 adapted views for regression solution to work for ensembles of regression solutions as well.

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation
Files:
1 added
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblemData.cs

    r6236 r6238  
    7171    }
    7272
    73     public IEnumerable<int> TrainingIndizes {
     73    public virtual IEnumerable<int> TrainingIndizes {
    7474      get {
    7575        return Enumerable.Range(TrainingPartition.Start, TrainingPartition.End - TrainingPartition.Start)
     
    7777      }
    7878    }
    79     public IEnumerable<int> TestIndizes {
     79    public virtual IEnumerable<int> TestIndizes {
    8080      get {
    8181        return Enumerable.Range(TestPartition.Start, TestPartition.End - TestPartition.Start)
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6184 r6238  
    5151    }
    5252    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    53       : base(new RegressionEnsembleModel(models), problemData) {
     53      : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) {
    5454      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    5555      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     
    6262
    6363    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    64       : base(new RegressionEnsembleModel(models), problemData) {
     64      : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) {
    6565      this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    6666      this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
     
    7575        throw new ArgumentException();
    7676      }
    77 
    7877      RecalculateResults();
    79     }
    80 
    81     private void RecalculateResults() {
    82       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    83       var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start,
    84         ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
    85       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes);
    86       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    87       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    88 
    89       OnlineCalculatorError errorState;
    90       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    91       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    92       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    93       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    94 
    95       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    96       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    97       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    98       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    99 
    100       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    101       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    102       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    103       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    104 
    105       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    106       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    107       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    108       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
    10978    }
    11079
     
    11584    public override IEnumerable<double> EstimatedTrainingValues {
    11685      get {
    117         var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
     86        var rows = ProblemData.TrainingIndizes;
    11887        var estimatedValuesEnumerators = (from model in Model.Models
    11988                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    12089                                         .ToList();
    12190        var rowsEnumerator = rows.GetEnumerator();
     91        // aggregate to make sure that MoveNext is called for all enumerators
    12292        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    12393          int currentRow = rowsEnumerator.Current;
     
    134104    public override IEnumerable<double> EstimatedTestValues {
    135105      get {
     106        var rows = ProblemData.TestIndizes;
    136107        var estimatedValuesEnumerators = (from model in Model.Models
    137                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })
     108                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    138109                                         .ToList();
    139110        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     111        // aggregate to make sure that MoveNext is called for all enumerators
    140112        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    141113          int currentRow = rowsEnumerator.Current;
     
    168140
    169141    private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    170       return estimatedValues.Average();
     142      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
    171143    }
    172144
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r5809 r6238  
    3333  [StorableClass]
    3434  [Item("RegressionProblemData", "Represents an item containing all data defining a regression problem.")]
    35   public sealed class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
     35  public class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {
    3636    private const string TargetVariableParameterName = "TargetVariable";
    3737
     
    8585
    8686    [StorableConstructor]
    87     private RegressionProblemData(bool deserializing) : base(deserializing) { }
     87    protected RegressionProblemData(bool deserializing) : base(deserializing) { }
    8888    [StorableHook(HookType.AfterDeserialization)]
    8989    private void AfterDeserialization() {
     
    9292
    9393
    94     private RegressionProblemData(RegressionProblemData original, Cloner cloner)
     94    protected RegressionProblemData(RegressionProblemData original, Cloner cloner)
    9595      : base(original, cloner) {
    9696      RegisterParameterEvents();
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs

    r6184 r6238  
    5555    public double TrainingMeanSquaredError {
    5656      get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    57       protected set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     57      private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    5858    }
    5959
    6060    public double TestMeanSquaredError {
    6161      get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    62       protected set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     62      private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    6363    }
    6464
    6565    public double TrainingRSquared {
    6666      get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
    67       protected set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
     67      private set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
    6868    }
    6969
    7070    public double TestRSquared {
    7171      get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
    72       protected set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
     72      private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
    7373    }
    7474
    7575    public double TrainingRelativeError {
    7676      get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
    77       protected set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
     77      private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
    7878    }
    7979
    8080    public double TestRelativeError {
    8181      get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
    82       protected set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
     82      private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
    8383    }
    8484
    8585    public double TrainingNormalizedMeanSquaredError {
    8686      get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
    87       protected set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
     87      private set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    8888    }
    8989
    9090    public double TestNormalizedMeanSquaredError {
    9191      get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
    92       protected set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
     92      private set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    9393    }
    9494
     
    126126    }
    127127
    128     private void RecalculateResults() {
     128    protected void RecalculateResults() {
    129129      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    130130      IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
Note: See TracChangeset for help on using the changeset viewer.