Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/08/16 14:40:02 (8 years ago)
Author:
gkronber
Message:

#2434: merged trunk changes r12934:14026 from trunk to branch

Location:
branches/crossvalidation-2434
Files:
7 edited
2 copied

Legend:

Unmodified
Added
Removed
  • branches/crossvalidation-2434

  • branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis

  • branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/ConstantRegressionModel.cs

    r12509 r14029  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
    2425using HeuristicLab.Common;
    2526using HeuristicLab.Core;
     27using HeuristicLab.Data;
    2628using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2729
     
    2931  [StorableClass]
    3032  [Item("Constant Regression Model", "A model that always returns the same constant value regardless of the presented input data.")]
    31   public class ConstantRegressionModel : NamedItem, IRegressionModel {
     33  [Obsolete]
     34  public class ConstantRegressionModel : RegressionModel, IStringConvertibleValue {
     35    public override IEnumerable<string> VariablesUsedForPrediction { get { return Enumerable.Empty<string>(); } }
     36
    3237    [Storable]
    33     protected double constant;
     38    private double constant;
    3439    public double Constant {
    3540      get { return constant; }
     41      // setter not implemented because manipulation of the constant is not allowed
    3642    }
    3743
     
    4248      this.constant = original.constant;
    4349    }
     50
    4451    public override IDeepCloneable Clone(Cloner cloner) { return new ConstantRegressionModel(this, cloner); }
    4552
    46     public ConstantRegressionModel(double constant)
    47       : base() {
     53    public ConstantRegressionModel(double constant, string targetVariable)
     54      : base(targetVariable) {
    4855      this.name = ItemName;
    4956      this.description = ItemDescription;
    5057      this.constant = constant;
     58      this.ReadOnly = true; // changing a constant regression model is not supported
    5159    }
    5260
    53     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     61    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    5462      return rows.Select(row => Constant);
    5563    }
    5664
    57     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    58       return new ConstantRegressionSolution(this, new RegressionProblemData(problemData));
     65    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     66      return new ConstantRegressionSolution(new ConstantModel(constant, TargetVariable), new RegressionProblemData(problemData));
    5967    }
     68
     69    public override string ToString() {
     70      return string.Format("Constant: {0}", GetValue());
     71    }
     72
     73    #region IStringConvertibleValue
     74    public bool ReadOnly { get; private set; }
     75    public bool Validate(string value, out string errorMessage) {
     76      throw new NotSupportedException(); // changing a constant regression model is not supported
     77    }
     78
     79    public string GetValue() {
     80      return string.Format("{0:E4}", constant);
     81    }
     82
     83    public bool SetValue(string value) {
     84      throw new NotSupportedException(); // changing a constant regression model is not supported
     85    }
     86
     87    public event EventHandler ValueChanged;
     88    #endregion
    6089  }
    6190}
  • branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/ConstantRegressionSolution.cs

    r12012 r14029  
    2828  [Item(Name = "Constant Regression Solution", Description = "Represents a constant regression solution (model + data).")]
    2929  public class ConstantRegressionSolution : RegressionSolution {
    30     public new ConstantRegressionModel Model {
    31       get { return (ConstantRegressionModel)base.Model; }
     30    public new ConstantModel Model {
     31      get { return (ConstantModel)base.Model; }
    3232      set { base.Model = value; }
    3333    }
     
    3636    protected ConstantRegressionSolution(bool deserializing) : base(deserializing) { }
    3737    protected ConstantRegressionSolution(ConstantRegressionSolution original, Cloner cloner) : base(original, cloner) { }
    38     public ConstantRegressionSolution(ConstantRegressionModel model, IRegressionProblemData problemData)
     38    public ConstantRegressionSolution(ConstantModel model, IRegressionProblemData problemData)
    3939      : base(model, problemData) {
    4040      RecalculateResults();
  • branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r12509 r14029  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    3233  [StorableClass]
    3334  [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
    34   public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
     35  public sealed class RegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel {
     36    public override IEnumerable<string> VariablesUsedForPrediction {
     37      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
     38    }
    3539
    3640    private List<IRegressionModel> models;
     
    4549    }
    4650
     51    private List<double> modelWeights;
     52    public IEnumerable<double> ModelWeights {
     53      get { return modelWeights; }
     54    }
     55
     56    [Storable(Name = "ModelWeights")]
     57    private IEnumerable<double> StorableModelWeights {
     58      get { return modelWeights; }
     59      set { modelWeights = value.ToList(); }
     60    }
     61
     62    [Storable]
     63    private bool averageModelEstimates = true;
     64    public bool AverageModelEstimates {
     65      get { return averageModelEstimates; }
     66      set {
     67        if (averageModelEstimates != value) {
     68          averageModelEstimates = value;
     69          OnChanged();
     70        }
     71      }
     72    }
     73
    4774    #region backwards compatiblity 3.3.5
    4875    [Storable(Name = "models", AllowOneWay = true)]
     
    5279    #endregion
    5380
     81    [StorableHook(HookType.AfterDeserialization)]
     82    private void AfterDeserialization() {
     83      // BackwardsCompatibility 3.3.14
     84      #region Backwards compatible code, remove with 3.4
     85      if (modelWeights == null || !modelWeights.Any())
     86        modelWeights = new List<double>(models.Select(m => 1.0));
     87      #endregion
     88    }
     89
    5490    [StorableConstructor]
    55     protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
    56     protected RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
     91    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     92    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
    5793      : base(original, cloner) {
    58       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
     94      this.models = original.Models.Select(cloner.Clone).ToList();
     95      this.modelWeights = new List<double>(original.ModelWeights);
     96      this.averageModelEstimates = original.averageModelEstimates;
     97    }
     98    public override IDeepCloneable Clone(Cloner cloner) {
     99      return new RegressionEnsembleModel(this, cloner);
    59100    }
    60101
    61102    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    62     public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
    63       : base() {
     103    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { }
     104    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
     105      : base(string.Empty) {
    64106      this.name = ItemName;
    65107      this.description = ItemDescription;
     108
    66109      this.models = new List<IRegressionModel>(models);
    67     }
    68 
    69     public override IDeepCloneable Clone(Cloner cloner) {
    70       return new RegressionEnsembleModel(this, cloner);
    71     }
    72 
    73     #region IRegressionEnsembleModel Members
     110      this.modelWeights = new List<double>(modelWeights);
     111
     112      if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable;
     113    }
    74114
    75115    public void Add(IRegressionModel model) {
     116      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
     117      Add(model, 1.0);
     118    }
     119    public void Add(IRegressionModel model, double weight) {
     120      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable;
     121
    76122      models.Add(model);
    77     }
     123      modelWeights.Add(weight);
     124      OnChanged();
     125    }
     126
     127    public void AddRange(IEnumerable<IRegressionModel> models) {
     128      AddRange(models, models.Select(m => 1.0));
     129    }
     130    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
     131      if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = models.First().TargetVariable;
     132
     133      this.models.AddRange(models);
     134      modelWeights.AddRange(weights);
     135      OnChanged();
     136    }
     137
    78138    public void Remove(IRegressionModel model) {
    79       models.Remove(model);
    80     }
    81 
     139      var index = models.IndexOf(model);
     140      models.RemoveAt(index);
     141      modelWeights.RemoveAt(index);
     142
     143      if (!models.Any()) TargetVariable = string.Empty;
     144      OnChanged();
     145    }
     146    public void RemoveRange(IEnumerable<IRegressionModel> models) {
     147      foreach (var model in models) {
     148        var index = this.models.IndexOf(model);
     149        this.models.RemoveAt(index);
     150        modelWeights.RemoveAt(index);
     151      }
     152
     153      if (!models.Any()) TargetVariable = string.Empty;
     154      OnChanged();
     155    }
     156
     157    public double GetModelWeight(IRegressionModel model) {
     158      var index = models.IndexOf(model);
     159      return modelWeights[index];
     160    }
     161    public void SetModelWeight(IRegressionModel model, double weight) {
     162      var index = models.IndexOf(model);
     163      modelWeights[index] = weight;
     164      OnChanged();
     165    }
     166
     167    #region evaluation
    82168    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    83169      var estimatedValuesEnumerators = (from model in models
    84                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    85                                        .ToList();
     170                                        let weight = GetModelWeight(model)
     171                                        select model.GetEstimatedValues(dataset, rows).Select(e => weight * e)
     172                                        .GetEnumerator()).ToList();
    86173
    87174      while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
     
    91178    }
    92179
     180    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     181      double weightsSum = modelWeights.Sum();
     182      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
     183                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();
     184
     185      if (AverageModelEstimates)
     186        return summedEstimates.Select(v => v / weightsSum);
     187      else
     188        return summedEstimates;
     189
     190    }
     191
     192    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
     193      var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
     194      var rowsEnumerator = rows.GetEnumerator();
     195
     196      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
     197        var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator();
     198        int currentRow = rowsEnumerator.Current;
     199        double weightsSum = 0.0;
     200        double filteredEstimatesSum = 0.0;
     201
     202        for (int m = 0; m < models.Count; m++) {
     203          estimatedValueEnumerator.MoveNext();
     204          var model = models[m];
     205          if (!modelSelectionPredicate(currentRow, model)) continue;
     206
     207          filteredEstimatesSum += estimatedValueEnumerator.Current;
     208          weightsSum += modelWeights[m];
     209        }
     210
     211        if (AverageModelEstimates)
     212          yield return filteredEstimatesSum / weightsSum;
     213        else
     214          yield return filteredEstimatesSum;
     215      }
     216    }
     217
    93218    #endregion
    94219
    95     #region IRegressionModel Members
    96 
    97     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    98       foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
    99         yield return estimatedValuesVector.Average();
    100       }
    101     }
    102 
    103     public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    104       return new RegressionEnsembleSolution(this.Models, new RegressionEnsembleProblemData(problemData));
    105     }
    106     IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
    107       return CreateRegressionSolution(problemData);
    108     }
    109 
    110     #endregion
     220    public event EventHandler Changed;
     221    private void OnChanged() {
     222      var handler = Changed;
     223      if (handler != null)
     224        handler(this, EventArgs.Empty);
     225    }
     226
     227
     228    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     229      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
     230    }
    111231  }
    112232}
  • branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r12820 r14029  
    7979        }
    8080      }
     81
     82      RegisterModelEvents();
    8183      RegisterRegressionSolutionsEventHandler();
    8284    }
     
    9395      }
    9496
     97      evaluationCache = new Dictionary<int, double>(original.ProblemData.Dataset.Rows);
    9598      trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count());
    9699      testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count());
    97100
    98101      regressionSolutions = cloner.Clone(original.regressionSolutions);
     102      RegisterModelEvents();
    99103      RegisterRegressionSolutionsEventHandler();
    100104    }
     
    106110      regressionSolutions = new ItemCollection<IRegressionSolution>();
    107111
     112      RegisterModelEvents();
    108113      RegisterRegressionSolutionsEventHandler();
    109114    }
    110115
    111116    public RegressionEnsembleSolution(IRegressionProblemData problemData)
    112       : this(Enumerable.Empty<IRegressionModel>(), problemData) {
    113     }
    114 
    115     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    116       : this(models, problemData,
    117              models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
    118              models.Select(m => (IntRange)problemData.TestPartition.Clone())
    119       ) { }
    120 
    121     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    122       : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    123       this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    124       this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
    125       this.regressionSolutions = new ItemCollection<IRegressionSolution>();
    126 
    127       List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    128       var modelEnumerator = models.GetEnumerator();
    129       var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    130       var testPartitionEnumerator = testPartitions.GetEnumerator();
    131 
    132       while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    133         var p = (IRegressionProblemData)problemData.Clone();
    134         p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
    135         p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
    136         p.TestPartition.Start = testPartitionEnumerator.Current.Start;
    137         p.TestPartition.End = testPartitionEnumerator.Current.End;
    138 
    139         solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    140       }
    141       if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
    142         throw new ArgumentException();
    143       }
    144 
     117      : this(new RegressionEnsembleModel(), problemData) {
     118    }
     119
     120    public RegressionEnsembleSolution(IRegressionEnsembleModel model, IRegressionProblemData problemData)
     121      : base(model, new RegressionEnsembleProblemData(problemData)) {
     122      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     123      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     124      regressionSolutions = new ItemCollection<IRegressionSolution>();
     125
     126      evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows);
    145127      trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count());
    146128      testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count());
    147129
     130
     131      var solutions = model.Models.Select(m => m.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()));
     132      foreach (var solution in solutions) {
     133        regressionSolutions.Add(solution);
     134        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     135        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     136      }
     137
     138      RecalculateResults();
     139      RegisterModelEvents();
    148140      RegisterRegressionSolutionsEventHandler();
    149       regressionSolutions.AddRange(solutions);
    150     }
     141    }
     142
    151143
    152144    public override IDeepCloneable Clone(Cloner cloner) {
    153145      return new RegressionEnsembleSolution(this, cloner);
     146    }
     147
     148    private void RegisterModelEvents() {
     149      Model.Changed += Model_Changed;
    154150    }
    155151    private void RegisterRegressionSolutionsEventHandler() {
     
    168164        var rows = ProblemData.TrainingIndices;
    169165        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
     166
    170167        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    171         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     168        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
    172169
    173170        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    184181        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
    185182        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    186         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     183        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator();
    187184
    188185        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    193190      }
    194191    }
    195 
    196     private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
    197       var estimatedValuesEnumerators = (from model in Model.Models
    198                                         select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    199                                        .ToList();
    200       var rowsEnumerator = rows.GetEnumerator();
    201       // aggregate to make sure that MoveNext is called for all enumerators
    202       while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    203         int currentRow = rowsEnumerator.Current;
    204 
    205         var selectedEnumerators = from pair in estimatedValuesEnumerators
    206                                   where modelSelectionPredicate(currentRow, pair.Model)
    207                                   select pair.EstimatedValuesEnumerator;
    208 
    209         yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    210       }
    211     }
    212 
    213192    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
    214193      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
    215194              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
    216195    }
    217 
    218196    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
    219197      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     
    224202      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
    225203      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    226       var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate)
    227                               select AggregateEstimatedValues(xs))
    228                              .GetEnumerator();
     204      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
    229205
    230206      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    235211    }
    236212
    237     public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    238       if (!Model.Models.Any()) yield break;
    239       var estimatedValuesEnumerators = (from model in Model.Models
    240                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    241                                        .ToList();
    242 
    243       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    244         yield return from enumerator in estimatedValuesEnumerators
    245                      select enumerator.Current;
    246       }
    247     }
    248 
    249     private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    250       return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     213    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) {
     214      return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows);
    251215    }
    252216    #endregion
     
    282246    }
    283247
    284     public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    285       regressionSolutions.AddRange(solutions);
     248    private void Model_Changed(object sender, EventArgs e) {
     249      var modelSet = new HashSet<IRegressionModel>(Model.Models);
     250      foreach (var model in Model.Models) {
     251        if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition);
     252        if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition);
     253      }
     254      foreach (var model in trainingPartitions.Keys) {
     255        if (modelSet.Contains(model)) continue;
     256        trainingPartitions.Remove(model);
     257        testPartitions.Remove(model);
     258      }
    286259
    287260      trainingEvaluationCache.Clear();
    288261      testEvaluationCache.Clear();
    289262      evaluationCache.Clear();
     263
     264      OnModelChanged();
     265    }
     266
     267    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     268      regressionSolutions.AddRange(solutions);
    290269    }
    291270    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    292271      regressionSolutions.RemoveRange(solutions);
    293 
    294       trainingEvaluationCache.Clear();
    295       testEvaluationCache.Clear();
    296       evaluationCache.Clear();
    297272    }
    298273
    299274    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    300       foreach (var solution in e.Items) AddRegressionSolution(solution);
    301       RecalculateResults();
     275      foreach (var solution in e.Items) {
     276        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     277        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     278      }
     279      Model.AddRange(e.Items.Select(s => s.Model));
    302280    }
    303281    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    304       foreach (var solution in e.Items) RemoveRegressionSolution(solution);
    305       RecalculateResults();
     282      foreach (var solution in e.Items) {
     283        trainingPartitions.Remove(solution.Model);
     284        testPartitions.Remove(solution.Model);
     285      }
     286      Model.RemoveRange(e.Items.Select(s => s.Model));
    306287    }
    307288    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    308       foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
    309       foreach (var solution in e.Items) AddRegressionSolution(solution);
    310       RecalculateResults();
    311     }
    312 
    313     private void AddRegressionSolution(IRegressionSolution solution) {
    314       if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
    315       Model.Add(solution.Model);
    316       trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    317       testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    318 
    319       trainingEvaluationCache.Clear();
    320       testEvaluationCache.Clear();
    321       evaluationCache.Clear();
    322     }
    323 
    324     private void RemoveRegressionSolution(IRegressionSolution solution) {
    325       if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
    326       Model.Remove(solution.Model);
    327       trainingPartitions.Remove(solution.Model);
    328       testPartitions.Remove(solution.Model);
    329 
    330       trainingEvaluationCache.Clear();
    331       testEvaluationCache.Clear();
    332       evaluationCache.Clear();
     289      foreach (var solution in e.OldItems) {
     290        trainingPartitions.Remove(solution.Model);
     291        testPartitions.Remove(solution.Model);
     292      }
     293      Model.RemoveRange(e.OldItems.Select(s => s.Model));
     294
     295      foreach (var solution in e.Items) {
     296        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     297        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     298      }
     299      Model.AddRange(e.Items.Select(s => s.Model));
    333300    }
    334301  }
  • branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r12509 r14029  
    110110    }
    111111
     112    public IEnumerable<double> TargetVariableValues {
     113      get { return Dataset.GetDoubleValues(TargetVariable); }
     114    }
     115    public IEnumerable<double> TargetVariableTrainingValues {
     116      get { return Dataset.GetDoubleValues(TargetVariable, TrainingIndices); }
     117    }
     118    public IEnumerable<double> TargetVariableTestValues {
     119      get { return Dataset.GetDoubleValues(TargetVariable, TestIndices); }
     120    }
     121
     122
    112123    [StorableConstructor]
    113124    protected RegressionProblemData(bool deserializing) : base(deserializing) { }
Note: See TracChangeset for help on using the changeset viewer.