- Timestamp:
- 03/15/16 15:07:59 (9 years ago)
- Location:
- trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r13701 r13704 46 46 } 47 47 48 private List<double> modelWeights; 49 public IEnumerable<double> ModelWeights { 50 get { return modelWeights; } 51 } 52 53 [Storable(Name = "ModelWeights")] 54 private IEnumerable<double> StorableModelWeights { 55 get { return modelWeights; } 56 set { modelWeights = value.ToList(); } 57 } 58 48 59 [Storable] 49 60 private bool averageModelEstimates = true; … … 53 64 if (averageModelEstimates != value) { 54 65 averageModelEstimates = value; 55 On AverageModelEstimatesChanged();66 OnChanged(); 56 67 } 57 68 } … … 64 75 } 65 76 #endregion 77 78 [StorableHook(HookType.AfterDeserialization)] 79 private void AfterDeserialization() { 80 // BackwardsCompatibility 3.3.14 81 #region Backwards compatible code, remove with 3.4 82 if (modelWeights == null || !modelWeights.Any()) 83 modelWeights = new List<double>(models.Select(m => 1.0)); 84 #endregion 85 } 66 86 67 87 [StorableConstructor] … … 70 90 : base(original, cloner) { 71 91 this.models = original.Models.Select(cloner.Clone).ToList(); 92 this.modelWeights = new List<double>(original.ModelWeights); 72 93 this.averageModelEstimates = original.averageModelEstimates; 73 94 } … … 77 98 78 99 public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { } 79 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) 100 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { } 101 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights) 80 102 : base() { 81 103 this.name = ItemName; 82 104 this.description = ItemDescription; 105 106 83 107 this.models = new List<IRegressionModel>(models); 108 this.modelWeights = new List<double>(modelWeights); 84 109 } 85 110 86 111 #region IRegressionEnsembleModel Members 87 112 public void Add(IRegressionModel model) { 113 Add(model, 1.0); 114 } 115 public void Add(IRegressionModel model, double weight) { 88 116 models.Add(model); 117 modelWeights.Add(weight); 118 OnChanged(); 119 } 120 121 public void AddRange(IEnumerable<IRegressionModel> models) { 122 AddRange(models, models.Select(m => 1.0)); 123 } 124 public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) { 125 this.models.AddRange(models); 126 modelWeights.AddRange(weights); 127 OnChanged(); 89 128 } 90 129 91 130 public void Remove(IRegressionModel model) { 92 models.Remove(model); 131 var index = models.IndexOf(model); 132 models.RemoveAt(index); 133 modelWeights.RemoveAt(index); 134 OnChanged(); 135 } 136 public void RemoveRange(IEnumerable<IRegressionModel> models) { 137 foreach (var model in models) { 138 var index = this.models.IndexOf(model); 139 this.models.RemoveAt(index); 140 modelWeights.RemoveAt(index); 141 } 142 OnChanged(); 143 } 144 145 public double GetModelWeight(IRegressionModel model) { 146 var index = models.IndexOf(model); 147 return modelWeights[index]; 148 } 149 public void SetModelWeight(IRegressionModel model, double weight) { 150 var index = models.IndexOf(model); 151 modelWeights[index] = weight; 152 OnChanged(); 93 153 } 94 154 … … 127 187 } 128 188 129 public event EventHandler AverageModelEstimatesChanged;130 private void On AverageModelEstimatesChanged() {131 var handler = AverageModelEstimatesChanged;189 public event EventHandler Changed; 190 private void OnChanged() { 191 var handler = Changed; 132 192 if (handler != null) 133 193 handler(this, EventArgs.Empty); -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r13702 r13704 79 79 } 80 80 } 81 82 RegisterModelEvents(); 81 83 RegisterRegressionSolutionsEventHandler(); 82 84 } … … 98 100 99 101 regressionSolutions = cloner.Clone(original.regressionSolutions); 102 RegisterModelEvents(); 100 103 RegisterRegressionSolutionsEventHandler(); 101 104 } … … 107 110 regressionSolutions = new ItemCollection<IRegressionSolution>(); 108 111 112 RegisterModelEvents(); 109 113 RegisterRegressionSolutionsEventHandler(); 110 114 } … … 133 137 134 138 RecalculateResults(); 139 RegisterModelEvents(); 135 140 RegisterRegressionSolutionsEventHandler(); 136 141 } … … 139 144 public override IDeepCloneable Clone(Cloner cloner) { 140 145 return new RegressionEnsembleSolution(this, cloner); 146 } 147 148 private void RegisterModelEvents() { 149 Model.Changed += Model_Changed; 141 150 } 142 151 private void RegisterRegressionSolutionsEventHandler() { … … 155 164 var rows = ProblemData.TrainingIndices; 156 165 var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys); 166 157 167 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 158 168 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); … … 236 246 } 237 247 238 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 239 regressionSolutions.AddRange(solutions); 248 private void Model_Changed(object sender, EventArgs e) { 249 var modelSet = new HashSet<IRegressionModel>(Model.Models); 250 foreach (var model in Model.Models) { 251 if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition); 252 if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition); 253 } 254 foreach (var model in trainingPartitions.Keys) { 255 if (modelSet.Contains(model)) continue; 256 trainingPartitions.Remove(model); 257 testPartitions.Remove(model); 258 } 240 259 241 260 trainingEvaluationCache.Clear(); 242 261 testEvaluationCache.Clear(); 243 262 evaluationCache.Clear(); 263 264 OnModelChanged(); 265 } 266 267 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 268 regressionSolutions.AddRange(solutions); 244 269 } 245 270 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 246 271 regressionSolutions.RemoveRange(solutions); 247 248 trainingEvaluationCache.Clear();249 testEvaluationCache.Clear();250 evaluationCache.Clear();251 272 } 252 273 253 274 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 254 foreach (var solution in e.Items) AddRegressionSolution(solution); 255 RecalculateResults(); 275 foreach (var solution in e.Items) { 276 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 277 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 278 } 279 Model.AddRange(e.Items.Select(s => s.Model)); 256 280 } 257 281 private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 258 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 259 RecalculateResults(); 282 foreach (var solution in e.Items) { 283 trainingPartitions.Remove(solution.Model); 284 testPartitions.Remove(solution.Model); 285 } 286 Model.RemoveRange(e.Items.Select(s => s.Model)); 260 287 } 261 288 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 262 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 263 foreach (var solution in e.Items) AddRegressionSolution(solution); 264 RecalculateResults(); 265 } 266 267 private void AddRegressionSolution(IRegressionSolution solution) { 268 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 269 Model.Add(solution.Model); 270 271 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 272 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 273 274 trainingEvaluationCache.Clear(); 275 testEvaluationCache.Clear(); 276 evaluationCache.Clear(); 277 } 278 279 private void RemoveRegressionSolution(IRegressionSolution solution) { 280 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 281 Model.Remove(solution.Model); 282 283 trainingPartitions.Remove(solution.Model); 284 testPartitions.Remove(solution.Model); 285 286 trainingEvaluationCache.Clear(); 287 testEvaluationCache.Clear(); 288 evaluationCache.Clear(); 289 foreach (var solution in e.OldItems) { 290 trainingPartitions.Remove(solution.Model); 291 testPartitions.Remove(solution.Model); 292 } 293 Model.RemoveRange(e.OldItems.Select(s => s.Model)); 294 295 foreach (var solution in e.Items) { 296 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 297 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 298 } 299 Model.AddRange(e.Items.Select(s => s.Model)); 289 300 } 290 301 } -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Regression/IRegressionEnsembleModel.cs
r13700 r13704 25 25 public interface IRegressionEnsembleModel : IRegressionModel { 26 26 void Add(IRegressionModel model); 27 void Add(IRegressionModel model, double weight); 28 void AddRange(IEnumerable<IRegressionModel> models); 29 void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights); 30 27 31 void Remove(IRegressionModel model); 32 void RemoveRange(IEnumerable<IRegressionModel> models); 28 33 29 34 IEnumerable<IRegressionModel> Models { get; } 35 IEnumerable<double> ModelWeights { get; } 36 37 double GetModelWeight(IRegressionModel model); 38 void SetModelWeight(IRegressionModel model, double weight); 30 39 31 40 bool AverageModelEstimates { get; set; } 32 event EventHandler AverageModelEstimatesChanged; 41 42 event EventHandler Changed; 33 43 34 44 IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows);
Note: See TracChangeset
for help on using the changeset viewer.