Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp: 07/23/17 00:52:14 (7 years ago)
Author: abeham
Message:

#2258: merged r13329:14000 from trunk into branch

Location: branches/Async
Files: 21 edited, 4 copied

Legend:

Unmodified
Added
Removed
  • branches/Async

  • branches/Async/HeuristicLab.Problems.DataAnalysis

  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs

    r12509 r15280  
    3232  [StorableClass]
    3333  [Item("ClassificationEnsembleModel", "A classification model that contains an ensemble of multiple classification models")]
    34   public class ClassificationEnsembleModel : NamedItem, IClassificationEnsembleModel {
     34  public class ClassificationEnsembleModel : ClassificationModel, IClassificationEnsembleModel {
     35    public override IEnumerable<string> VariablesUsedForPrediction {
     36      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
     37    }
    3538
    3639    [Storable]
     
    4952    public ClassificationEnsembleModel() : this(Enumerable.Empty<IClassificationModel>()) { }
    5053    public ClassificationEnsembleModel(IEnumerable<IClassificationModel> models)
    51       : base() {
     54      : base(string.Empty) {
    5255      this.name = ItemName;
    5356      this.description = ItemDescription;
    5457      this.models = new List<IClassificationModel>(models);
     58
     59      if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable;
    5560    }
    5661
     
    5964    }
    6065
    61     #region IClassificationEnsembleModel Members
    6266    public void Add(IClassificationModel model) {
     67      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
    6368      models.Add(model);
    6469    }
    6570    public void Remove(IClassificationModel model) {
    6671      models.Remove(model);
     72      if (!models.Any()) TargetVariable = string.Empty;
    6773    }
    6874
     
    7884    }
    7985
    80     #endregion
    8186
    82     #region IClassificationModel Members
    83 
    84     public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
     87    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    8588      foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) {
    8689        // return the class which is most often occuring
     
    9497    }
    9598
    96     IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
     99    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    97100      return new ClassificationEnsembleSolution(models, new ClassificationEnsembleProblemData(problemData));
    98101    }
    99     #endregion
     102
     103
    100104  }
    101105}
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationPerformanceMeasures.cs

    r13101 r15280  
    153153      Add(new Result(TestF1ScoreResultName, "The F1 score of the model on the test partition.", new DoubleValue()));
    154154      Add(new Result(TestMatthewsCorrelationResultName, "The Matthews correlation value of the model on the test partition.", new DoubleValue()));
     155
     156      Reset();
     157    }
     158
     159
     160    public void Reset() {
    155161      TrainingTruePositiveRate = double.NaN;
    156162      TrainingTrueNegativeRate = double.NaN;
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblem.cs

    r12504 r15280  
    3535    public override IDeepCloneable Clone(Cloner cloner) { return new ClassificationProblem(this, cloner); }
    3636
    37     public ClassificationProblem()
    38       : base() {
    39       ProblemData = new ClassificationProblemData();
    40     }
     37    public ClassificationProblem() : base(new ClassificationProblemData()) { }
    4138  }
    4239}
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs

    r12509 r15280  
    283283    private void AfterDeserialization() {
    284284      RegisterParameterEvents();
     285
     286      classNamesCache = new List<string>();
     287      for (int i = 0; i < ClassNamesParameter.Value.Rows; i++)
     288        classNamesCache.Add(ClassNamesParameter.Value[i, 0]);
     289
    285290      // BackwardsCompatibility3.4
    286291      #region Backwards compatible code, remove with 3.5
     
    297302      : base(original, cloner) {
    298303      RegisterParameterEvents();
     304      classNamesCache = new List<string>();
     305      for (int i = 0; i < ClassNamesParameter.Value.Rows; i++)
     306        classNamesCache.Add(ClassNamesParameter.Value[i, 0]);
    299307    }
    300308    public override IDeepCloneable Clone(Cloner cloner) {
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionBase.cs

    r13102 r15280  
    129129      TestNormalizedGiniCoefficient = testNormalizedGini;
    130130
     131      ClassificationPerformanceMeasures.Reset();
     132
    131133      trainingPerformanceCalculator.Calculate(originalTrainingClassValues, estimatedTrainingClassValues);
    132134      if (trainingPerformanceCalculator.ErrorState == OnlineCalculatorError.None)
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r12509 r15280  
    3333  [StorableClass]
    3434  [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")]
    35   public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
     35  public class DiscriminantFunctionClassificationModel : ClassificationModel, IDiscriminantFunctionClassificationModel {
     36    public override IEnumerable<string> VariablesUsedForPrediction {
     37      get { return model.VariablesUsedForPrediction; }
     38    }
     39
    3640    [Storable]
    3741    private IRegressionModel model;
     
    7377
    7478    public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator)
    75       : base() {
     79      : base(model.TargetVariable) {
    7680      this.name = ItemName;
    7781      this.description = ItemDescription;
     82
    7883      this.model = model;
    7984      this.classValues = new double[0];
     
    115120    }
    116121
    117     public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
     122    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    118123      if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model.");
    119124      foreach (var x in GetEstimatedValues(dataset, rows)) {
     
    135140    #endregion
    136141
    137     public virtual IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData) {
     142    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
     143      return CreateDiscriminantFunctionClassificationSolution(problemData);
     144    }
     145    public virtual IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(
     146      IClassificationProblemData problemData) {
    138147      return new DiscriminantFunctionClassificationSolution(this, new ClassificationProblemData(problemData));
    139     }
    140 
    141     public virtual IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    142       return CreateDiscriminantFunctionClassificationSolution(problemData);
    143148    }
    144149  }
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Clustering/ClusteringProblem.cs

    r12504 r15280  
    3333    public override IDeepCloneable Clone(Cloner cloner) { return new ClusteringProblem(this, cloner); }
    3434
    35     public ClusteringProblem()
    36       : base() {
    37       ProblemData = new ClusteringProblemData();
    38     }
     35    public ClusteringProblem() : base(new ClusteringProblemData()) { }
    3936  }
    4037}
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/ConstantModel.cs

    r13154 r15280  
    3131  [StorableClass]
    3232  [Item("Constant Model", "A model that always returns the same constant value regardless of the presented input data.")]
    33   public class ConstantModel : NamedItem, IRegressionModel, IClassificationModel, ITimeSeriesPrognosisModel, IStringConvertibleValue {
     33  public class ConstantModel : RegressionModel, IClassificationModel, ITimeSeriesPrognosisModel, IStringConvertibleValue {
     34    public override IEnumerable<string> VariablesUsedForPrediction { get { return Enumerable.Empty<string>(); } }
     35
     36
    3437    [Storable]
    35     private double constant;
     38    private readonly double constant;
    3639    public double Constant {
    3740      get { return constant; }
     
    4548      this.constant = original.constant;
    4649    }
     50
    4751    public override IDeepCloneable Clone(Cloner cloner) { return new ConstantModel(this, cloner); }
    4852
    49     public ConstantModel(double constant)
    50       : base() {
     53    public ConstantModel(double constant, string targetVariable)
     54      : base(targetVariable) {
    5155      this.name = ItemName;
    5256      this.description = ItemDescription;
     
    5559    }
    5660
    57     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     61    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    5862      return rows.Select(row => Constant);
    5963    }
     
    6569    }
    6670
    67     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     71    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    6872      return new ConstantRegressionSolution(this, new RegressionProblemData(problemData));
    6973    }
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblem.cs

    r12012 r15280  
    4949    public T ProblemData {
    5050      get { return ProblemDataParameter.Value; }
    51       protected set {
    52         ProblemDataParameter.Value = value;
    53       }
     51      set { ProblemDataParameter.Value = value; }
    5452    }
    5553    #endregion
     
    6058    [StorableConstructor]
    6159    protected DataAnalysisProblem(bool deserializing) : base(deserializing) { }
    62     public DataAnalysisProblem()
    63       : base() {
    64       Parameters.Add(new ValueParameter<T>(ProblemDataParameterName, ProblemDataParameterDescription));
    65       RegisterEventHandlers();
    66     }
    6760
    6861    protected DataAnalysisProblem(T problemData)
    69       : this() {
    70       ProblemData = problemData;
     62      : base() {
     63      Parameters.Add(new ValueParameter<T>(ProblemDataParameterName, ProblemDataParameterDescription, problemData));
     64      RegisterEventHandlers();
    7165    }
    7266
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblemData.cs

    r12509 r15280  
    8080    }
    8181
     82    public virtual IEnumerable<int> AllIndices {
     83      get { return Enumerable.Range(0, Dataset.Rows); }
     84    }
    8285    public virtual IEnumerable<int> TrainingIndices {
    8386      get {
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/ConstantRegressionModel.cs

    r13100 r15280  
    3232  [Item("Constant Regression Model", "A model that always returns the same constant value regardless of the presented input data.")]
    3333  [Obsolete]
    34   public class ConstantRegressionModel : NamedItem, IRegressionModel, IStringConvertibleValue {
     34  public class ConstantRegressionModel : RegressionModel, IStringConvertibleValue {
     35    public override IEnumerable<string> VariablesUsedForPrediction { get { return Enumerable.Empty<string>(); } }
     36
    3537    [Storable]
    3638    private double constant;
     
    4648      this.constant = original.constant;
    4749    }
     50
    4851    public override IDeepCloneable Clone(Cloner cloner) { return new ConstantRegressionModel(this, cloner); }
    4952
    50     public ConstantRegressionModel(double constant)
    51       : base() {
     53    public ConstantRegressionModel(double constant, string targetVariable)
     54      : base(targetVariable) {
    5255      this.name = ItemName;
    5356      this.description = ItemDescription;
     
    5659    }
    5760
    58     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     61    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    5962      return rows.Select(row => Constant);
    6063    }
    6164
    62     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    63       return new ConstantRegressionSolution(new ConstantModel(constant), new RegressionProblemData(problemData));
     65    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     66      return new ConstantRegressionSolution(new ConstantModel(constant, TargetVariable), new RegressionProblemData(problemData));
    6467    }
    6568
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r12509 r15280  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    3233  [StorableClass]
    3334  [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
    34   public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
     35  public sealed class RegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel {
     36    public override IEnumerable<string> VariablesUsedForPrediction {
     37      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
     38    }
    3539
    3640    private List<IRegressionModel> models;
     
    4549    }
    4650
     51    private List<double> modelWeights;
     52    public IEnumerable<double> ModelWeights {
     53      get { return modelWeights; }
     54    }
     55
     56    [Storable(Name = "ModelWeights")]
     57    private IEnumerable<double> StorableModelWeights {
     58      get { return modelWeights; }
     59      set { modelWeights = value.ToList(); }
     60    }
     61
     62    [Storable]
     63    private bool averageModelEstimates = true;
     64    public bool AverageModelEstimates {
     65      get { return averageModelEstimates; }
     66      set {
     67        if (averageModelEstimates != value) {
     68          averageModelEstimates = value;
     69          OnChanged();
     70        }
     71      }
     72    }
     73
    4774    #region backwards compatiblity 3.3.5
    4875    [Storable(Name = "models", AllowOneWay = true)]
     
    5279    #endregion
    5380
     81    [StorableHook(HookType.AfterDeserialization)]
     82    private void AfterDeserialization() {
     83      // BackwardsCompatibility 3.3.14
     84      #region Backwards compatible code, remove with 3.4
     85      if (modelWeights == null || !modelWeights.Any())
     86        modelWeights = new List<double>(models.Select(m => 1.0));
     87      #endregion
     88    }
     89
    5490    [StorableConstructor]
    55     protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
    56     protected RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
     91    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     92    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
    5793      : base(original, cloner) {
    58       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
     94      this.models = original.Models.Select(cloner.Clone).ToList();
     95      this.modelWeights = new List<double>(original.ModelWeights);
     96      this.averageModelEstimates = original.averageModelEstimates;
     97    }
     98    public override IDeepCloneable Clone(Cloner cloner) {
     99      return new RegressionEnsembleModel(this, cloner);
    59100    }
    60101
    61102    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    62     public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
    63       : base() {
     103    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { }
     104    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
     105      : base(string.Empty) {
    64106      this.name = ItemName;
    65107      this.description = ItemDescription;
     108
    66109      this.models = new List<IRegressionModel>(models);
    67     }
    68 
    69     public override IDeepCloneable Clone(Cloner cloner) {
    70       return new RegressionEnsembleModel(this, cloner);
    71     }
    72 
    73     #region IRegressionEnsembleModel Members
     110      this.modelWeights = new List<double>(modelWeights);
     111
     112      if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable;
     113    }
    74114
    75115    public void Add(IRegressionModel model) {
     116      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
     117      Add(model, 1.0);
     118    }
     119    public void Add(IRegressionModel model, double weight) {
     120      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
     121
    76122      models.Add(model);
    77     }
     123      modelWeights.Add(weight);
     124      OnChanged();
     125    }
     126
     127    public void AddRange(IEnumerable<IRegressionModel> models) {
     128      AddRange(models, models.Select(m => 1.0));
     129    }
     130    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
     131      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = models.First().TargetVariable;
     132
     133      this.models.AddRange(models);
     134      modelWeights.AddRange(weights);
     135      OnChanged();
     136    }
     137
    78138    public void Remove(IRegressionModel model) {
    79       models.Remove(model);
    80     }
    81 
     139      var index = models.IndexOf(model);
     140      models.RemoveAt(index);
     141      modelWeights.RemoveAt(index);
     142
     143      if (!models.Any()) TargetVariable = string.Empty;
     144      OnChanged();
     145    }
     146    public void RemoveRange(IEnumerable<IRegressionModel> models) {
     147      foreach (var model in models) {
     148        var index = this.models.IndexOf(model);
     149        this.models.RemoveAt(index);
     150        modelWeights.RemoveAt(index);
     151      }
     152
     153      if (!models.Any()) TargetVariable = string.Empty;
     154      OnChanged();
     155    }
     156
     157    public double GetModelWeight(IRegressionModel model) {
     158      var index = models.IndexOf(model);
     159      return modelWeights[index];
     160    }
     161    public void SetModelWeight(IRegressionModel model, double weight) {
     162      var index = models.IndexOf(model);
     163      modelWeights[index] = weight;
     164      OnChanged();
     165    }
     166
     167    #region evaluation
    82168    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    83169      var estimatedValuesEnumerators = (from model in models
    84                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    85                                        .ToList();
     170                                        let weight = GetModelWeight(model)
     171                                        select model.GetEstimatedValues(dataset, rows).Select(e => weight * e)
     172                                        .GetEnumerator()).ToList();
    86173
    87174      while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
     
    91178    }
    92179
     180    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     181      double weightsSum = modelWeights.Sum();
     182      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
     183                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();
     184
     185      if (AverageModelEstimates)
     186        return summedEstimates.Select(v => v / weightsSum);
     187      else
     188        return summedEstimates;
     189
     190    }
     191
     192    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
     193      var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
     194      var rowsEnumerator = rows.GetEnumerator();
     195
     196      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
     197        var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator();
     198        int currentRow = rowsEnumerator.Current;
     199        double weightsSum = 0.0;
     200        double filteredEstimatesSum = 0.0;
     201
     202        for (int m = 0; m < models.Count; m++) {
     203          estimatedValueEnumerator.MoveNext();
     204          var model = models[m];
     205          if (!modelSelectionPredicate(currentRow, model)) continue;
     206
     207          filteredEstimatesSum += estimatedValueEnumerator.Current;
     208          weightsSum += modelWeights[m];
     209        }
     210
     211        if (AverageModelEstimates)
     212          yield return filteredEstimatesSum / weightsSum;
     213        else
     214          yield return filteredEstimatesSum;
     215      }
     216    }
     217
    93218    #endregion
    94219
    95     #region IRegressionModel Members
    96 
    97     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    98       foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
    99         yield return estimatedValuesVector.Average();
    100       }
    101     }
    102 
    103     public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    104       return new RegressionEnsembleSolution(this.Models, new RegressionEnsembleProblemData(problemData));
    105     }
    106     IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
    107       return CreateRegressionSolution(problemData);
    108     }
    109 
    110     #endregion
     220    public event EventHandler Changed;
     221    private void OnChanged() {
     222      var handler = Changed;
     223      if (handler != null)
     224        handler(this, EventArgs.Empty);
     225    }
     226
     227
     228    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     229      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
     230    }
    111231  }
    112232}
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r12816 r15280  
    7979        }
    8080      }
     81
     82      RegisterModelEvents();
    8183      RegisterRegressionSolutionsEventHandler();
    8284    }
     
    9395      }
    9496
     97      evaluationCache = new Dictionary<int, double>(original.ProblemData.Dataset.Rows);
    9598      trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count());
    9699      testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count());
    97100
    98101      regressionSolutions = cloner.Clone(original.regressionSolutions);
     102      RegisterModelEvents();
    99103      RegisterRegressionSolutionsEventHandler();
    100104    }
     
    106110      regressionSolutions = new ItemCollection<IRegressionSolution>();
    107111
     112      RegisterModelEvents();
    108113      RegisterRegressionSolutionsEventHandler();
    109114    }
    110115
    111116    public RegressionEnsembleSolution(IRegressionProblemData problemData)
    112       : this(Enumerable.Empty<IRegressionModel>(), problemData) {
    113     }
    114 
    115     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    116       : this(models, problemData,
    117              models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
    118              models.Select(m => (IntRange)problemData.TestPartition.Clone())
    119       ) { }
    120 
    121     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    122       : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    123       this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    124       this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
    125       this.regressionSolutions = new ItemCollection<IRegressionSolution>();
    126 
    127       List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    128       var modelEnumerator = models.GetEnumerator();
    129       var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    130       var testPartitionEnumerator = testPartitions.GetEnumerator();
    131 
    132       while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    133         var p = (IRegressionProblemData)problemData.Clone();
    134         p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
    135         p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
    136         p.TestPartition.Start = testPartitionEnumerator.Current.Start;
    137         p.TestPartition.End = testPartitionEnumerator.Current.End;
    138 
    139         solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    140       }
    141       if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
    142         throw new ArgumentException();
    143       }
    144 
     117      : this(new RegressionEnsembleModel(), problemData) {
     118    }
     119
     120    public RegressionEnsembleSolution(IRegressionEnsembleModel model, IRegressionProblemData problemData)
     121      : base(model, new RegressionEnsembleProblemData(problemData)) {
     122      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     123      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     124      regressionSolutions = new ItemCollection<IRegressionSolution>();
     125
     126      evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows);
    145127      trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count());
    146128      testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count());
    147129
     130
     131      var solutions = model.Models.Select(m => m.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()));
     132      foreach (var solution in solutions) {
     133        regressionSolutions.Add(solution);
     134        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     135        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     136      }
     137
     138      RecalculateResults();
     139      RegisterModelEvents();
    148140      RegisterRegressionSolutionsEventHandler();
    149       regressionSolutions.AddRange(solutions);
    150     }
     141    }
     142
    151143
    152144    public override IDeepCloneable Clone(Cloner cloner) {
    153145      return new RegressionEnsembleSolution(this, cloner);
     146    }
     147
     148    private void RegisterModelEvents() {
     149      Model.Changed += Model_Changed;
    154150    }
    155151    private void RegisterRegressionSolutionsEventHandler() {
     
    168164        var rows = ProblemData.TrainingIndices;
    169165        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
     166
    170167        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    171         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     168        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
    172169
    173170        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    184181        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
    185182        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    186         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     183        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator();
    187184
    188185        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    193190      }
    194191    }
    195 
    196     private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
    197       var estimatedValuesEnumerators = (from model in Model.Models
    198                                         select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    199                                        .ToList();
    200       var rowsEnumerator = rows.GetEnumerator();
    201       // aggregate to make sure that MoveNext is called for all enumerators
    202       while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    203         int currentRow = rowsEnumerator.Current;
    204 
    205         var selectedEnumerators = from pair in estimatedValuesEnumerators
    206                                   where modelSelectionPredicate(currentRow, pair.Model)
    207                                   select pair.EstimatedValuesEnumerator;
    208 
    209         yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    210       }
    211     }
    212 
    213192    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
    214193      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
    215194              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
    216195    }
    217 
    218196    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
    219197      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     
    224202      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
    225203      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    226       var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate)
    227                               select AggregateEstimatedValues(xs))
    228                              .GetEnumerator();
     204      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
    229205
    230206      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    235211    }
    236212
    237     public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    238       if (!Model.Models.Any()) yield break;
    239       var estimatedValuesEnumerators = (from model in Model.Models
    240                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    241                                        .ToList();
    242 
    243       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    244         yield return from enumerator in estimatedValuesEnumerators
    245                      select enumerator.Current;
    246       }
    247     }
    248 
    249     private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    250       return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     213    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) {
     214      return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows);
    251215    }
    252216    #endregion
     
    282246    }
    283247
    284     public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    285       regressionSolutions.AddRange(solutions);
     248    private void Model_Changed(object sender, EventArgs e) {
     249      var modelSet = new HashSet<IRegressionModel>(Model.Models);
     250      foreach (var model in Model.Models) {
     251        if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition);
     252        if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition);
     253      }
     254      foreach (var model in trainingPartitions.Keys) {
     255        if (modelSet.Contains(model)) continue;
     256        trainingPartitions.Remove(model);
     257        testPartitions.Remove(model);
     258      }
    286259
    287260      trainingEvaluationCache.Clear();
    288261      testEvaluationCache.Clear();
    289262      evaluationCache.Clear();
     263
     264      OnModelChanged();
     265    }
     266
     267    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     268      regressionSolutions.AddRange(solutions);
    290269    }
    291270    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    292271      regressionSolutions.RemoveRange(solutions);
    293 
    294       trainingEvaluationCache.Clear();
    295       testEvaluationCache.Clear();
    296       evaluationCache.Clear();
    297272    }
    298273
    299274    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    300       foreach (var solution in e.Items) AddRegressionSolution(solution);
    301       RecalculateResults();
     275      foreach (var solution in e.Items) {
     276        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     277        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     278      }
     279      Model.AddRange(e.Items.Select(s => s.Model));
    302280    }
    303281    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    304       foreach (var solution in e.Items) RemoveRegressionSolution(solution);
    305       RecalculateResults();
     282      foreach (var solution in e.Items) {
     283        trainingPartitions.Remove(solution.Model);
     284        testPartitions.Remove(solution.Model);
     285      }
     286      Model.RemoveRange(e.Items.Select(s => s.Model));
    306287    }
    307288    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    308       foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
    309       foreach (var solution in e.Items) AddRegressionSolution(solution);
    310       RecalculateResults();
    311     }
    312 
    313     private void AddRegressionSolution(IRegressionSolution solution) {
    314       if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
    315       Model.Add(solution.Model);
    316       trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    317       testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    318 
    319       trainingEvaluationCache.Clear();
    320       testEvaluationCache.Clear();
    321       evaluationCache.Clear();
    322     }
    323 
    324     private void RemoveRegressionSolution(IRegressionSolution solution) {
    325       if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
    326       Model.Remove(solution.Model);
    327       trainingPartitions.Remove(solution.Model);
    328       testPartitions.Remove(solution.Model);
    329 
    330       trainingEvaluationCache.Clear();
    331       testEvaluationCache.Clear();
    332       evaluationCache.Clear();
     289      foreach (var solution in e.OldItems) {
     290        trainingPartitions.Remove(solution.Model);
     291        testPartitions.Remove(solution.Model);
     292      }
     293      Model.RemoveRange(e.OldItems.Select(s => s.Model));
     294
     295      foreach (var solution in e.Items) {
     296        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     297        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     298      }
     299      Model.AddRange(e.Items.Select(s => s.Model));
    333300    }
    334301  }
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r12509 r15280  
    110110    }
    111111
     112    public IEnumerable<double> TargetVariableValues {
     113      get { return Dataset.GetDoubleValues(TargetVariable); }
     114    }
     115    public IEnumerable<double> TargetVariableTrainingValues {
     116      get { return Dataset.GetDoubleValues(TargetVariable, TrainingIndices); }
     117    }
     118    public IEnumerable<double> TargetVariableTestValues {
     119      get { return Dataset.GetDoubleValues(TargetVariable, TestIndices); }
     120    }
     121
     122
    112123    [StorableConstructor]
    113124    protected RegressionProblemData(bool deserializing) : base(deserializing) { }
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/TimeSeriesPrognosis/Models/ConstantTimeSeriesPrognosisModel.cs

    r13100 r15280  
    3939    }
    4040
    41     public ConstantTimeSeriesPrognosisModel(double constant) : base(constant) { }
     41    public ConstantTimeSeriesPrognosisModel(double constant, string targetVariable) : base(constant, targetVariable) { }
    4242
    4343    public IEnumerable<IEnumerable<double>> GetPrognosedValues(IDataset dataset, IEnumerable<int> rows, IEnumerable<int> horizons) {
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/TimeSeriesPrognosis/Models/TimeSeriesPrognosisAutoRegressiveModel.cs

    r12509 r15280  
    3030  [StorableClass]
    3131  [Item("Autoregressive TimeSeries Model", "A linear autoregressive time series model used to predict future values.")]
    32   public class TimeSeriesPrognosisAutoRegressiveModel : NamedItem, ITimeSeriesPrognosisModel {
     32  public class TimeSeriesPrognosisAutoRegressiveModel : RegressionModel, ITimeSeriesPrognosisModel {
     33    public override IEnumerable<string> VariablesUsedForPrediction {
     34      get { return new[] { TargetVariable }; }
     35    }
     36
    3337    [Storable]
    3438    public double[] Phi { get; private set; }
    3539    [Storable]
    3640    public double Constant { get; private set; }
    37     [Storable]
    38     public string TargetVariable { get; private set; }
    3941
    4042    public int TimeOffset { get { return Phi.Length; } }
     
    4648      this.Phi = (double[])original.Phi.Clone();
    4749      this.Constant = original.Constant;
    48       this.TargetVariable = original.TargetVariable;
    4950    }
    5051    public override IDeepCloneable Clone(Cloner cloner) {
     
    5253    }
    5354    public TimeSeriesPrognosisAutoRegressiveModel(string targetVariable, double[] phi, double constant)
    54       : base("AR(1) Model") {
     55      : base(targetVariable, "AR(1) Model") {
    5556      Phi = (double[])phi.Clone();
    5657      Constant = constant;
    57       TargetVariable = targetVariable;
    5858    }
    5959
     
    9191    }
    9292
    93     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     93    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    9494      var targetVariables = dataset.GetReadOnlyDoubleValues(TargetVariable);
    9595      foreach (int row in rows) {
     
    111111      return new TimeSeriesPrognosisSolution(this, new TimeSeriesPrognosisProblemData(problemData));
    112112    }
    113     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     113    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    114114      throw new NotSupportedException();
    115115    }
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/TimeSeriesPrognosis/TimeSeriesPrognosisResults.cs

    r13100 r15280  
    373373      //mean model
    374374      double trainingMean = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).Average();
    375       var meanModel = new ConstantModel(trainingMean);
     375      var meanModel = new ConstantModel(trainingMean, problemData.TargetVariable);
    376376
    377377      //AR1 model
     
    395395      PrognosisTrainingMeanAbsoluteError = errorState == OnlineCalculatorError.None ? trainingMAE : double.NaN;
    396396      double trainingR = OnlinePearsonsRCalculator.Calculate(originalTrainingValues, estimatedTrainingValues, out errorState);
    397       PrognosisTrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR*trainingR : double.NaN;
     397      PrognosisTrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR * trainingR : double.NaN;
    398398      double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(originalTrainingValues, estimatedTrainingValues, out errorState);
    399399      PrognosisTrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
     
    431431      PrognosisTestMeanAbsoluteError = errorState == OnlineCalculatorError.None ? testMAE : double.NaN;
    432432      double testR = OnlinePearsonsRCalculator.Calculate(originalTestValues, estimatedTestValues, out errorState);
    433       PrognosisTestRSquared = errorState == OnlineCalculatorError.None ? testR*testR : double.NaN;
     433      PrognosisTestRSquared = errorState == OnlineCalculatorError.None ? testR * testR : double.NaN;
    434434      double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(originalTestValues, estimatedTestValues, out errorState);
    435435      PrognosisTestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
     
    448448        //mean model
    449449        double trainingMean = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).Average();
    450         var meanModel = new ConstantModel(trainingMean);
     450        var meanModel = new ConstantModel(trainingMean, problemData.TargetVariable);
    451451
    452452        //AR1 model
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/TimeSeriesPrognosis/TimeSeriesPrognosisSolutionBase.cs

    r13100 r15280  
    150150      OnlineCalculatorError errorState;
    151151      double trainingMean = ProblemData.TrainingIndices.Any() ? ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average() : double.NaN;
    152       var meanModel = new ConstantModel(trainingMean);
     152      var meanModel = new ConstantModel(trainingMean,ProblemData.TargetVariable);
    153153
    154154      double alpha, beta;
  • branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Transformations/Transformation.cs

    r12612 r15280  
    3131
    3232  [Item("Transformation", "Represents the base class for a transformation.")]
     33  [StorableClass]
    3334  public abstract class Transformation : ParameterizedNamedItem, ITransformation {
    3435    protected const string ColumnParameterName = "Column";
Note: See TracChangeset for help on using the changeset viewer.