Changeset 13704


Ignore:
Timestamp:
03/15/16 15:07:59 (5 years ago)
Author:
mkommend
Message:

#2590: Added model weights for ensembles.

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r13701 r13704  
    4646    }
    4747
     48    private List<double> modelWeights;
     49    public IEnumerable<double> ModelWeights {
     50      get { return modelWeights; }
     51    }
     52
     53    [Storable(Name = "ModelWeights")]
     54    private IEnumerable<double> StorableModelWeights {
     55      get { return modelWeights; }
     56      set { modelWeights = value.ToList(); }
     57    }
     58
    4859    [Storable]
    4960    private bool averageModelEstimates = true;
     
    5364        if (averageModelEstimates != value) {
    5465          averageModelEstimates = value;
    55           OnAverageModelEstimatesChanged();
     66          OnChanged();
    5667        }
    5768      }
     
    6475    }
    6576    #endregion
     77
     78    [StorableHook(HookType.AfterDeserialization)]
     79    private void AfterDeserialization() {
     80      // BackwardsCompatibility 3.3.14
     81      #region Backwards compatible code, remove with 3.4
     82      if (modelWeights == null || !modelWeights.Any())
     83        modelWeights = new List<double>(models.Select(m => 1.0));
     84      #endregion
     85    }
    6686
    6787    [StorableConstructor]
     
    7090      : base(original, cloner) {
    7191      this.models = original.Models.Select(cloner.Clone).ToList();
     92      this.modelWeights = new List<double>(original.ModelWeights);
    7293      this.averageModelEstimates = original.averageModelEstimates;
    7394    }
     
    7798
    7899    public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { }
    79     public RegressionEnsembleModel(IEnumerable<IRegressionModel> models)
     100    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { }
     101    public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights)
    80102      : base() {
    81103      this.name = ItemName;
    82104      this.description = ItemDescription;
     105
     106
    83107      this.models = new List<IRegressionModel>(models);
     108      this.modelWeights = new List<double>(modelWeights);
    84109    }
    85110
    86111    #region IRegressionEnsembleModel Members
    87112    public void Add(IRegressionModel model) {
     113      Add(model, 1.0);
     114    }
     115    public void Add(IRegressionModel model, double weight) {
    88116      models.Add(model);
     117      modelWeights.Add(weight);
     118      OnChanged();
     119    }
     120
     121    public void AddRange(IEnumerable<IRegressionModel> models) {
     122      AddRange(models, models.Select(m => 1.0));
     123    }
     124    public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) {
     125      this.models.AddRange(models);
     126      modelWeights.AddRange(weights);
     127      OnChanged();
    89128    }
    90129
    91130    public void Remove(IRegressionModel model) {
    92       models.Remove(model);
     131      var index = models.IndexOf(model);
     132      models.RemoveAt(index);
     133      modelWeights.RemoveAt(index);
     134      OnChanged();
     135    }
     136    public void RemoveRange(IEnumerable<IRegressionModel> models) {
     137      foreach (var model in models) {
     138        var index = this.models.IndexOf(model);
     139        this.models.RemoveAt(index);
     140        modelWeights.RemoveAt(index);
     141      }
     142      OnChanged();
     143    }
     144
     145    public double GetModelWeight(IRegressionModel model) {
     146      var index = models.IndexOf(model);
     147      return modelWeights[index];
     148    }
     149    public void SetModelWeight(IRegressionModel model, double weight) {
     150      var index = models.IndexOf(model);
     151      modelWeights[index] = weight;
     152      OnChanged();
    93153    }
    94154
     
    127187    }
    128188
    129     public event EventHandler AverageModelEstimatesChanged;
    130     private void OnAverageModelEstimatesChanged() {
    131       var handler = AverageModelEstimatesChanged;
     189    public event EventHandler Changed;
     190    private void OnChanged() {
     191      var handler = Changed;
    132192      if (handler != null)
    133193        handler(this, EventArgs.Empty);
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r13702 r13704  
    7979        }
    8080      }
     81
     82      RegisterModelEvents();
    8183      RegisterRegressionSolutionsEventHandler();
    8284    }
     
    98100
    99101      regressionSolutions = cloner.Clone(original.regressionSolutions);
     102      RegisterModelEvents();
    100103      RegisterRegressionSolutionsEventHandler();
    101104    }
     
    107110      regressionSolutions = new ItemCollection<IRegressionSolution>();
    108111
     112      RegisterModelEvents();
    109113      RegisterRegressionSolutionsEventHandler();
    110114    }
     
    133137
    134138      RecalculateResults();
     139      RegisterModelEvents();
    135140      RegisterRegressionSolutionsEventHandler();
    136141    }
     
    139144    public override IDeepCloneable Clone(Cloner cloner) {
    140145      return new RegressionEnsembleSolution(this, cloner);
     146    }
     147
     148    private void RegisterModelEvents() {
     149      Model.Changed += Model_Changed;
    141150    }
    142151    private void RegisterRegressionSolutionsEventHandler() {
     
    155164        var rows = ProblemData.TrainingIndices;
    156165        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
     166
    157167        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    158168        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     
    236246    }
    237247
    238     public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    239       regressionSolutions.AddRange(solutions);
     248    private void Model_Changed(object sender, EventArgs e) {
     249      var modelSet = new HashSet<IRegressionModel>(Model.Models);
     250      foreach (var model in Model.Models) {
     251        if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition);
     252        if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition);
     253      }
     254      foreach (var model in trainingPartitions.Keys) {
     255        if (modelSet.Contains(model)) continue;
     256        trainingPartitions.Remove(model);
     257        testPartitions.Remove(model);
     258      }
    240259
    241260      trainingEvaluationCache.Clear();
    242261      testEvaluationCache.Clear();
    243262      evaluationCache.Clear();
     263
     264      OnModelChanged();
     265    }
     266
     267    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     268      regressionSolutions.AddRange(solutions);
    244269    }
    245270    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    246271      regressionSolutions.RemoveRange(solutions);
    247 
    248       trainingEvaluationCache.Clear();
    249       testEvaluationCache.Clear();
    250       evaluationCache.Clear();
    251272    }
    252273
    253274    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    254       foreach (var solution in e.Items) AddRegressionSolution(solution);
    255       RecalculateResults();
     275      foreach (var solution in e.Items) {
     276        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     277        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     278      }
     279      Model.AddRange(e.Items.Select(s => s.Model));
    256280    }
    257281    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    258       foreach (var solution in e.Items) RemoveRegressionSolution(solution);
    259       RecalculateResults();
     282      foreach (var solution in e.Items) {
     283        trainingPartitions.Remove(solution.Model);
     284        testPartitions.Remove(solution.Model);
     285      }
     286      Model.RemoveRange(e.Items.Select(s => s.Model));
    260287    }
    261288    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    262       foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
    263       foreach (var solution in e.Items) AddRegressionSolution(solution);
    264       RecalculateResults();
    265     }
    266 
    267     private void AddRegressionSolution(IRegressionSolution solution) {
    268       if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
    269       Model.Add(solution.Model);
    270 
    271       trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    272       testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    273 
    274       trainingEvaluationCache.Clear();
    275       testEvaluationCache.Clear();
    276       evaluationCache.Clear();
    277     }
    278 
    279     private void RemoveRegressionSolution(IRegressionSolution solution) {
    280       if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
    281       Model.Remove(solution.Model);
    282 
    283       trainingPartitions.Remove(solution.Model);
    284       testPartitions.Remove(solution.Model);
    285 
    286       trainingEvaluationCache.Clear();
    287       testEvaluationCache.Clear();
    288       evaluationCache.Clear();
     289      foreach (var solution in e.OldItems) {
     290        trainingPartitions.Remove(solution.Model);
     291        testPartitions.Remove(solution.Model);
     292      }
     293      Model.RemoveRange(e.OldItems.Select(s => s.Model));
     294
     295      foreach (var solution in e.Items) {
     296        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     297        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     298      }
     299      Model.AddRange(e.Items.Select(s => s.Model));
    289300    }
    290301  }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Regression/IRegressionEnsembleModel.cs

    r13700 r13704  
    2525  public interface IRegressionEnsembleModel : IRegressionModel {
    2626    void Add(IRegressionModel model);
     27    void Add(IRegressionModel model, double weight);
     28    void AddRange(IEnumerable<IRegressionModel> models);
     29    void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights);
     30
    2731    void Remove(IRegressionModel model);
     32    void RemoveRange(IEnumerable<IRegressionModel> models);
    2833
    2934    IEnumerable<IRegressionModel> Models { get; }
     35    IEnumerable<double> ModelWeights { get; }
     36
     37    double GetModelWeight(IRegressionModel model);
     38    void SetModelWeight(IRegressionModel model, double weight);
    3039
    3140    bool AverageModelEstimates { get; set; }
    32     event EventHandler AverageModelEstimatesChanged;
     41
     42    event EventHandler Changed;
    3343
    3444    IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows);
Note: See TracChangeset for help on using the changeset viewer.