Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/02/16 09:02:09 (8 years ago)
Author:
gkronber
Message:

#2590 merged r13697:13698, r13700:13702, r13704:13705, r13711, r13715 from trunk to stable

Location:
stable
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Problems.DataAnalysis

  • stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r12702 r13976  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    3233  [StorableClass]
    3334  [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")]
    34   public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
     35  public sealed class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {
    3536
    3637    private List<IRegressionModel> models;
     
    4546    }
    4647
     48    private List<double> modelWeights;
     49    public IEnumerable<double> ModelWeights {
     50      get { return modelWeights; }
     51    }
     52
     53    [Storable(Name = "ModelWeights")]
     54    private IEnumerable<double> StorableModelWeights {
     55      get { return modelWeights; }
     56      set { modelWeights = value.ToList(); }
     57    }
     58
     59    [Storable]
     60    private bool averageModelEstimates = true;
     61    public bool AverageModelEstimates {
     62      get { return averageModelEstimates; }
     63      set {
     64        if (averageModelEstimates != value) {
     65          averageModelEstimates = value;
     66          OnChanged();
     67        }
     68      }
     69    }
     70
    4771    #region backwards compatiblity 3.3.5
    4872    [Storable(Name = "models", AllowOneWay = true)]
     
    5276    #endregion
    5377
     78    [StorableHook(HookType.AfterDeserialization)]
     79    private void AfterDeserialization() {
     80      // BackwardsCompatibility 3.3.14
     81      #region Backwards compatible code, remove with 3.4
     82      if (modelWeights == null || !modelWeights.Any())
     83        modelWeights = new List<double>(models.Select(m => 1.0));
     84      #endregion
     85    }
     86
    5487    [StorableConstructor]
    55     protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
    56     protected RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
     88    private RegressionEnsembleModel(bool deserializing) : base(deserializing) { }
     89    private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)
    5790      : base(original, cloner) {
    58       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
     91      this.models = original.Models.Select(cloner.Clone).ToList();
     92      this.modelWeights = new List<double>(original.ModelWeights);
     93      this.averageModelEstimates = original.averageModelEstimates;
     94    }
     95    public override IDeepCloneable Clone(Cloner cloner) {
     96      return new RegressionEnsembleModel(this, cloner);
    5997    }
    6098
    6199    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    62     public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
     100    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { }
     101    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
    63102      : base() {
    64103      this.name = ItemName;
    65104      this.description = ItemDescription;
     105
     106
    66107      this.models = new List<IRegressionModel>(models);
    67     }
    68 
    69     public override IDeepCloneable Clone(Cloner cloner) {
    70       return new RegressionEnsembleModel(this, cloner);
    71     }
    72 
    73     #region IRegressionEnsembleModel Members
     108      this.modelWeights = new List<double>(modelWeights);
     109    }
    74110
    75111    public void Add(IRegressionModel model) {
     112      Add(model, 1.0);
     113    }
     114    public void Add(IRegressionModel model, double weight) {
    76115      models.Add(model);
    77     }
     116      modelWeights.Add(weight);
     117      OnChanged();
     118    }
     119
     120    public void AddRange(IEnumerable<IRegressionModel> models) {
     121      AddRange(models, models.Select(m => 1.0));
     122    }
     123    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
     124      this.models.AddRange(models);
     125      modelWeights.AddRange(weights);
     126      OnChanged();
     127    }
     128
    78129    public void Remove(IRegressionModel model) {
    79       models.Remove(model);
    80     }
    81 
     130      var index = models.IndexOf(model);
     131      models.RemoveAt(index);
     132      modelWeights.RemoveAt(index);
     133      OnChanged();
     134    }
     135    public void RemoveRange(IEnumerable<IRegressionModel> models) {
     136      foreach (var model in models) {
     137        var index = this.models.IndexOf(model);
     138        this.models.RemoveAt(index);
     139        modelWeights.RemoveAt(index);
     140      }
     141      OnChanged();
     142    }
     143
     144    public double GetModelWeight(IRegressionModel model) {
     145      var index = models.IndexOf(model);
     146      return modelWeights[index];
     147    }
     148    public void SetModelWeight(IRegressionModel model, double weight) {
     149      var index = models.IndexOf(model);
     150      modelWeights[index] = weight;
     151      OnChanged();
     152    }
     153
     154    #region evaluation
    82155    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    83156      var estimatedValuesEnumerators = (from model in models
    84                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    85                                        .ToList();
     157                                        let weight = GetModelWeight(model)
     158                                        select model.GetEstimatedValues(dataset, rows).Select(e => weight * e)
     159                                        .GetEnumerator()).ToList();
    86160
    87161      while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
     
    91165    }
    92166
     167    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     168      double weightsSum = modelWeights.Sum();
     169      var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)
     170                            select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum();
     171
     172      if (AverageModelEstimates)
     173        return summedEstimates.Select(v => v / weightsSum);
     174      else
     175        return summedEstimates;
     176
     177    }
     178
     179    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
     180      var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator();
     181      var rowsEnumerator = rows.GetEnumerator();
     182
     183      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) {
     184        var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator();
     185        int currentRow = rowsEnumerator.Current;
     186        double weightsSum = 0.0;
     187        double filteredEstimatesSum = 0.0;
     188
     189        for (int m = 0; m < models.Count; m++) {
     190          estimatedValueEnumerator.MoveNext();
     191          var model = models[m];
     192          if (!modelSelectionPredicate(currentRow, model)) continue;
     193
     194          filteredEstimatesSum += estimatedValueEnumerator.Current;
     195          weightsSum += modelWeights[m];
     196        }
     197
     198        if (AverageModelEstimates)
     199          yield return filteredEstimatesSum / weightsSum;
     200        else
     201          yield return filteredEstimatesSum;
     202      }
     203    }
     204
    93205    #endregion
    94206
    95     #region IRegressionModel Members
    96 
    97     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    98       foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {
    99         yield return estimatedValuesVector.Average();
    100       }
    101     }
     207    public event EventHandler Changed;
     208    private void OnChanged() {
     209      var handler = Changed;
     210      if (handler != null)
     211        handler(this, EventArgs.Empty);
     212    }
     213
    102214
    103215    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    104       return new RegressionEnsembleSolution(this.Models, new RegressionEnsembleProblemData(problemData));
     216      return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData));
    105217    }
    106218    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
    107219      return CreateRegressionSolution(problemData);
    108220    }
    109 
    110     #endregion
    111221  }
    112222}
  • stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r13049 r13976  
    7979        }
    8080      }
     81
     82      RegisterModelEvents();
    8183      RegisterRegressionSolutionsEventHandler();
    8284    }
     
    9395      }
    9496
     97      evaluationCache = new Dictionary<int, double>(original.ProblemData.Dataset.Rows);
    9598      trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count());
    9699      testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count());
    97100
    98101      regressionSolutions = cloner.Clone(original.regressionSolutions);
     102      RegisterModelEvents();
    99103      RegisterRegressionSolutionsEventHandler();
    100104    }
     
    106110      regressionSolutions = new ItemCollection<IRegressionSolution>();
    107111
     112      RegisterModelEvents();
    108113      RegisterRegressionSolutionsEventHandler();
    109114    }
    110115
    111116    public RegressionEnsembleSolution(IRegressionProblemData problemData)
    112       : this(Enumerable.Empty<IRegressionModel>(), problemData) {
    113     }
    114 
    115     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    116       : this(models, problemData,
    117              models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
    118              models.Select(m => (IntRange)problemData.TestPartition.Clone())
    119       ) { }
    120 
    121     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    122       : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    123       this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    124       this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
    125       this.regressionSolutions = new ItemCollection<IRegressionSolution>();
    126 
    127       List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    128       var modelEnumerator = models.GetEnumerator();
    129       var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    130       var testPartitionEnumerator = testPartitions.GetEnumerator();
    131 
    132       while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    133         var p = (IRegressionProblemData)problemData.Clone();
    134         p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
    135         p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
    136         p.TestPartition.Start = testPartitionEnumerator.Current.Start;
    137         p.TestPartition.End = testPartitionEnumerator.Current.End;
    138 
    139         solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    140       }
    141       if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
    142         throw new ArgumentException();
    143       }
    144 
     117      : this(new RegressionEnsembleModel(), problemData) {
     118    }
     119
     120    public RegressionEnsembleSolution(IRegressionEnsembleModel model, IRegressionProblemData problemData)
     121      : base(model, new RegressionEnsembleProblemData(problemData)) {
     122      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     123      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     124      regressionSolutions = new ItemCollection<IRegressionSolution>();
     125
     126      evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows);
    145127      trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count());
    146128      testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count());
    147129
     130
     131      var solutions = model.Models.Select(m => m.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()));
     132      foreach (var solution in solutions) {
     133        regressionSolutions.Add(solution);
     134        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     135        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     136      }
     137
     138      RecalculateResults();
     139      RegisterModelEvents();
    148140      RegisterRegressionSolutionsEventHandler();
    149       regressionSolutions.AddRange(solutions);
    150     }
     141    }
     142
    151143
    152144    public override IDeepCloneable Clone(Cloner cloner) {
    153145      return new RegressionEnsembleSolution(this, cloner);
     146    }
     147
     148    private void RegisterModelEvents() {
     149      Model.Changed += Model_Changed;
    154150    }
    155151    private void RegisterRegressionSolutionsEventHandler() {
     
    168164        var rows = ProblemData.TrainingIndices;
    169165        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
     166
    170167        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    171         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     168        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
    172169
    173170        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    184181        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
    185182        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    186         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     183        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator();
    187184
    188185        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    193190      }
    194191    }
    195 
    196     private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
    197       var estimatedValuesEnumerators = (from model in Model.Models
    198                                         select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    199                                        .ToList();
    200       var rowsEnumerator = rows.GetEnumerator();
    201       // aggregate to make sure that MoveNext is called for all enumerators
    202       while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    203         int currentRow = rowsEnumerator.Current;
    204 
    205         var selectedEnumerators = from pair in estimatedValuesEnumerators
    206                                   where modelSelectionPredicate(currentRow, pair.Model)
    207                                   select pair.EstimatedValuesEnumerator;
    208 
    209         yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    210       }
    211     }
    212 
    213192    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
    214193      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
    215194              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
    216195    }
    217 
    218196    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
    219197      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     
    224202      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
    225203      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    226       var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate)
    227                               select AggregateEstimatedValues(xs))
    228                              .GetEnumerator();
     204      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
    229205
    230206      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    235211    }
    236212
    237     public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    238       if (!Model.Models.Any()) yield break;
    239       var estimatedValuesEnumerators = (from model in Model.Models
    240                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    241                                        .ToList();
    242 
    243       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    244         yield return from enumerator in estimatedValuesEnumerators
    245                      select enumerator.Current;
    246       }
    247     }
    248 
    249     private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    250       return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     213    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) {
     214      return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows);
    251215    }
    252216    #endregion
     
    282246    }
    283247
    284     public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    285       regressionSolutions.AddRange(solutions);
     248    private void Model_Changed(object sender, EventArgs e) {
     249      var modelSet = new HashSet<IRegressionModel>(Model.Models);
     250      foreach (var model in Model.Models) {
     251        if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition);
     252        if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition);
     253      }
     254      foreach (var model in trainingPartitions.Keys) {
     255        if (modelSet.Contains(model)) continue;
     256        trainingPartitions.Remove(model);
     257        testPartitions.Remove(model);
     258      }
    286259
    287260      trainingEvaluationCache.Clear();
    288261      testEvaluationCache.Clear();
    289262      evaluationCache.Clear();
     263
     264      OnModelChanged();
     265    }
     266
     267    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     268      regressionSolutions.AddRange(solutions);
    290269    }
    291270    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    292271      regressionSolutions.RemoveRange(solutions);
    293 
    294       trainingEvaluationCache.Clear();
    295       testEvaluationCache.Clear();
    296       evaluationCache.Clear();
    297272    }
    298273
    299274    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    300       foreach (var solution in e.Items) AddRegressionSolution(solution);
    301       RecalculateResults();
     275      foreach (var solution in e.Items) {
     276        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     277        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     278      }
     279      Model.AddRange(e.Items.Select(s => s.Model));
    302280    }
    303281    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    304       foreach (var solution in e.Items) RemoveRegressionSolution(solution);
    305       RecalculateResults();
     282      foreach (var solution in e.Items) {
     283        trainingPartitions.Remove(solution.Model);
     284        testPartitions.Remove(solution.Model);
     285      }
     286      Model.RemoveRange(e.Items.Select(s => s.Model));
    306287    }
    307288    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    308       foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
    309       foreach (var solution in e.Items) AddRegressionSolution(solution);
    310       RecalculateResults();
    311     }
    312 
    313     private void AddRegressionSolution(IRegressionSolution solution) {
    314       if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
    315       Model.Add(solution.Model);
    316       trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    317       testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    318 
    319       trainingEvaluationCache.Clear();
    320       testEvaluationCache.Clear();
    321       evaluationCache.Clear();
    322     }
    323 
    324     private void RemoveRegressionSolution(IRegressionSolution solution) {
    325       if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
    326       Model.Remove(solution.Model);
    327       trainingPartitions.Remove(solution.Model);
    328       testPartitions.Remove(solution.Model);
    329 
    330       trainingEvaluationCache.Clear();
    331       testEvaluationCache.Clear();
    332       evaluationCache.Clear();
     289      foreach (var solution in e.OldItems) {
     290        trainingPartitions.Remove(solution.Model);
     291        testPartitions.Remove(solution.Model);
     292      }
     293      Model.RemoveRange(e.OldItems.Select(s => s.Model));
     294
     295      foreach (var solution in e.Items) {
     296        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     297        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     298      }
     299      Model.AddRange(e.Items.Select(s => s.Model));
    333300    }
    334301  }
Note: See TracChangeset for help on using the changeset viewer.