- Timestamp:
- 07/02/16 09:02:09 (8 years ago)
- Location:
- stable
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
stable
- Property svn:mergeinfo changed
/trunk/sources merged: 13697-13698,13700-13702,13704-13705,13711,13715
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Problems.DataAnalysis merged: 13697-13698,13700-13702,13704-13705,13715
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r12702 r13976 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; … … 32 33 [StorableClass] 33 34 [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")] 34 public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel {35 public sealed class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel { 35 36 36 37 private List<IRegressionModel> models; … … 45 46 } 46 47 48 private List<double> modelWeights; 49 public IEnumerable<double> ModelWeights { 50 get { return modelWeights; } 51 } 52 53 [Storable(Name = "ModelWeights")] 54 private IEnumerable<double> StorableModelWeights { 55 get { return modelWeights; } 56 set { modelWeights = value.ToList(); } 57 } 58 59 [Storable] 60 private bool averageModelEstimates = true; 61 public bool AverageModelEstimates { 62 get { return averageModelEstimates; } 63 set { 64 if (averageModelEstimates != value) { 65 averageModelEstimates = value; 66 OnChanged(); 67 } 68 } 69 } 70 47 71 #region backwards compatiblity 3.3.5 48 72 [Storable(Name = "models", AllowOneWay = true)] … … 52 76 #endregion 53 77 78 [StorableHook(HookType.AfterDeserialization)] 79 private void AfterDeserialization() { 80 // BackwardsCompatibility 3.3.14 81 #region Backwards compatible code, remove with 3.4 82 if (modelWeights == null || !modelWeights.Any()) 83 modelWeights = new List<double>(models.Select(m => 1.0)); 84 #endregion 85 } 86 54 87 [StorableConstructor] 55 pr otectedRegressionEnsembleModel(bool deserializing) : base(deserializing) { }56 pr otectedRegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)88 private RegressionEnsembleModel(bool deserializing) : base(deserializing) { } 89 private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner) 57 90 : base(original, cloner) { 58 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 91 this.models = 
original.Models.Select(cloner.Clone).ToList(); 92 this.modelWeights = new List<double>(original.ModelWeights); 93 this.averageModelEstimates = original.averageModelEstimates; 94 } 95 public override IDeepCloneable Clone(Cloner cloner) { 96 return new RegressionEnsembleModel(this, cloner); 59 97 } 60 98 61 99 public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { } 62 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) 100 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { } 101 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights) 63 102 : base() { 64 103 this.name = ItemName; 65 104 this.description = ItemDescription; 105 106 66 107 this.models = new List<IRegressionModel>(models); 67 } 68 69 public override IDeepCloneable Clone(Cloner cloner) { 70 return new RegressionEnsembleModel(this, cloner); 71 } 72 73 #region IRegressionEnsembleModel Members 108 this.modelWeights = new List<double>(modelWeights); 109 } 74 110 75 111 public void Add(IRegressionModel model) { 112 Add(model, 1.0); 113 } 114 public void Add(IRegressionModel model, double weight) { 76 115 models.Add(model); 77 } 116 modelWeights.Add(weight); 117 OnChanged(); 118 } 119 120 public void AddRange(IEnumerable<IRegressionModel> models) { 121 AddRange(models, models.Select(m => 1.0)); 122 } 123 public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) { 124 this.models.AddRange(models); 125 modelWeights.AddRange(weights); 126 OnChanged(); 127 } 128 78 129 public void Remove(IRegressionModel model) { 79 models.Remove(model); 80 } 81 130 var index = models.IndexOf(model); 131 models.RemoveAt(index); 132 modelWeights.RemoveAt(index); 133 OnChanged(); 134 } 135 public void RemoveRange(IEnumerable<IRegressionModel> models) { 136 foreach (var model in models) { 137 var index = this.models.IndexOf(model); 138 
this.models.RemoveAt(index); 139 modelWeights.RemoveAt(index); 140 } 141 OnChanged(); 142 } 143 144 public double GetModelWeight(IRegressionModel model) { 145 var index = models.IndexOf(model); 146 return modelWeights[index]; 147 } 148 public void SetModelWeight(IRegressionModel model, double weight) { 149 var index = models.IndexOf(model); 150 modelWeights[index] = weight; 151 OnChanged(); 152 } 153 154 #region evaluation 82 155 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 83 156 var estimatedValuesEnumerators = (from model in models 84 select model.GetEstimatedValues(dataset, rows).GetEnumerator()) 85 .ToList(); 157 let weight = GetModelWeight(model) 158 select model.GetEstimatedValues(dataset, rows).Select(e => weight * e) 159 .GetEnumerator()).ToList(); 86 160 87 161 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { … … 91 165 } 92 166 167 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 168 double weightsSum = modelWeights.Sum(); 169 var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows) 170 select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum(); 171 172 if (AverageModelEstimates) 173 return summedEstimates.Select(v => v / weightsSum); 174 else 175 return summedEstimates; 176 177 } 178 179 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) { 180 var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator(); 181 var rowsEnumerator = rows.GetEnumerator(); 182 183 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) { 184 var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator(); 185 int currentRow = rowsEnumerator.Current; 186 double weightsSum = 0.0; 187 double filteredEstimatesSum = 0.0; 188 189 for (int m = 0; m < models.Count; m++) { 190 
estimatedValueEnumerator.MoveNext(); 191 var model = models[m]; 192 if (!modelSelectionPredicate(currentRow, model)) continue; 193 194 filteredEstimatesSum += estimatedValueEnumerator.Current; 195 weightsSum += modelWeights[m]; 196 } 197 198 if (AverageModelEstimates) 199 yield return filteredEstimatesSum / weightsSum; 200 else 201 yield return filteredEstimatesSum; 202 } 203 } 204 93 205 #endregion 94 206 95 #region IRegressionModel Members96 97 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {98 foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) {99 yield return estimatedValuesVector.Average();100 101 } 207 public event EventHandler Changed; 208 private void OnChanged() { 209 var handler = Changed; 210 if (handler != null) 211 handler(this, EventArgs.Empty); 212 } 213 102 214 103 215 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { 104 return new RegressionEnsembleSolution(this .Models, new RegressionEnsembleProblemData(problemData));216 return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData)); 105 217 } 106 218 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 107 219 return CreateRegressionSolution(problemData); 108 220 } 109 110 #endregion111 221 } 112 222 } -
stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r13049 r13976 79 79 } 80 80 } 81 82 RegisterModelEvents(); 81 83 RegisterRegressionSolutionsEventHandler(); 82 84 } … … 93 95 } 94 96 97 evaluationCache = new Dictionary<int, double>(original.ProblemData.Dataset.Rows); 95 98 trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count()); 96 99 testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count()); 97 100 98 101 regressionSolutions = cloner.Clone(original.regressionSolutions); 102 RegisterModelEvents(); 99 103 RegisterRegressionSolutionsEventHandler(); 100 104 } … … 106 110 regressionSolutions = new ItemCollection<IRegressionSolution>(); 107 111 112 RegisterModelEvents(); 108 113 RegisterRegressionSolutionsEventHandler(); 109 114 } 110 115 111 116 public RegressionEnsembleSolution(IRegressionProblemData problemData) 112 : this(Enumerable.Empty<IRegressionModel>(), problemData) { 113 } 114 115 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 116 : this(models, problemData, 117 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 118 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 119 ) { } 120 121 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 122 : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) { 123 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 124 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); 125 this.regressionSolutions = new ItemCollection<IRegressionSolution>(); 126 127 List<IRegressionSolution> solutions = new List<IRegressionSolution>(); 128 var modelEnumerator = models.GetEnumerator(); 129 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 130 var 
testPartitionEnumerator = testPartitions.GetEnumerator(); 131 132 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 133 var p = (IRegressionProblemData)problemData.Clone(); 134 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 135 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 136 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 137 p.TestPartition.End = testPartitionEnumerator.Current.End; 138 139 solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p)); 140 } 141 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 142 throw new ArgumentException(); 143 } 144 117 : this(new RegressionEnsembleModel(), problemData) { 118 } 119 120 public RegressionEnsembleSolution(IRegressionEnsembleModel model, IRegressionProblemData problemData) 121 : base(model, new RegressionEnsembleProblemData(problemData)) { 122 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 123 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 124 regressionSolutions = new ItemCollection<IRegressionSolution>(); 125 126 evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows); 145 127 trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count()); 146 128 testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count()); 147 129 130 131 var solutions = model.Models.Select(m => m.CreateRegressionSolution((IRegressionProblemData)problemData.Clone())); 132 foreach (var solution in solutions) { 133 regressionSolutions.Add(solution); 134 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 135 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 136 } 137 138 RecalculateResults(); 139 RegisterModelEvents(); 148 140 RegisterRegressionSolutionsEventHandler(); 149 
regressionSolutions.AddRange(solutions);150 } 141 } 142 151 143 152 144 public override IDeepCloneable Clone(Cloner cloner) { 153 145 return new RegressionEnsembleSolution(this, cloner); 146 } 147 148 private void RegisterModelEvents() { 149 Model.Changed += Model_Changed; 154 150 } 155 151 private void RegisterRegressionSolutionsEventHandler() { … … 168 164 var rows = ProblemData.TrainingIndices; 169 165 var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys); 166 170 167 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 171 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();168 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); 172 169 173 170 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 184 181 var rowsToEvaluate = rows.Except(testEvaluationCache.Keys); 185 182 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 186 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();183 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator(); 187 184 188 185 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 193 190 } 194 191 } 195 196 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {197 var estimatedValuesEnumerators = (from model in Model.Models198 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })199 .ToList();200 var rowsEnumerator = rows.GetEnumerator();201 // aggregate to make sure that MoveNext is called for all enumerators202 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, 
(acc, b) => acc & b)) {203 int currentRow = rowsEnumerator.Current;204 205 var selectedEnumerators = from pair in estimatedValuesEnumerators206 where modelSelectionPredicate(currentRow, pair.Model)207 select pair.EstimatedValuesEnumerator;208 209 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));210 }211 }212 213 192 private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) { 214 193 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 215 194 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 216 195 } 217 218 196 private bool RowIsTestForModel(int currentRow, IRegressionModel model) { 219 197 return testPartitions == null || !testPartitions.ContainsKey(model) || … … 224 202 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 225 203 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 226 var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate) 227 select AggregateEstimatedValues(xs)) 228 .GetEnumerator(); 204 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 229 205 230 206 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 235 211 } 236 212 237 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 238 if (!Model.Models.Any()) yield break; 239 var estimatedValuesEnumerators = (from model in Model.Models 240 select model.GetEstimatedValues(dataset, rows).GetEnumerator()) 241 .ToList(); 242 243 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 244 yield return from enumerator in estimatedValuesEnumerators 245 select enumerator.Current; 246 } 247 } 248 249 private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) { 250 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 213 public IEnumerable<IEnumerable<double>> 
GetEstimatedValueVectors(IEnumerable<int> rows) { 214 return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows); 251 215 } 252 216 #endregion … … 282 246 } 283 247 284 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 285 regressionSolutions.AddRange(solutions); 248 private void Model_Changed(object sender, EventArgs e) { 249 var modelSet = new HashSet<IRegressionModel>(Model.Models); 250 foreach (var model in Model.Models) { 251 if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition); 252 if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition); 253 } 254 foreach (var model in trainingPartitions.Keys) { 255 if (modelSet.Contains(model)) continue; 256 trainingPartitions.Remove(model); 257 testPartitions.Remove(model); 258 } 286 259 287 260 trainingEvaluationCache.Clear(); 288 261 testEvaluationCache.Clear(); 289 262 evaluationCache.Clear(); 263 264 OnModelChanged(); 265 } 266 267 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 268 regressionSolutions.AddRange(solutions); 290 269 } 291 270 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 292 271 regressionSolutions.RemoveRange(solutions); 293 294 trainingEvaluationCache.Clear();295 testEvaluationCache.Clear();296 evaluationCache.Clear();297 272 } 298 273 299 274 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 300 foreach (var solution in e.Items) AddRegressionSolution(solution); 301 RecalculateResults(); 275 foreach (var solution in e.Items) { 276 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 277 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 278 } 279 Model.AddRange(e.Items.Select(s => s.Model)); 302 280 } 303 281 private void regressionSolutions_ItemsRemoved(object sender, 
CollectionItemsChangedEventArgs<IRegressionSolution> e) { 304 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 305 RecalculateResults(); 282 foreach (var solution in e.Items) { 283 trainingPartitions.Remove(solution.Model); 284 testPartitions.Remove(solution.Model); 285 } 286 Model.RemoveRange(e.Items.Select(s => s.Model)); 306 287 } 307 288 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 308 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 309 foreach (var solution in e.Items) AddRegressionSolution(solution); 310 RecalculateResults(); 311 } 312 313 private void AddRegressionSolution(IRegressionSolution solution) { 314 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 315 Model.Add(solution.Model); 316 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 317 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 318 319 trainingEvaluationCache.Clear(); 320 testEvaluationCache.Clear(); 321 evaluationCache.Clear(); 322 } 323 324 private void RemoveRegressionSolution(IRegressionSolution solution) { 325 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 326 Model.Remove(solution.Model); 327 trainingPartitions.Remove(solution.Model); 328 testPartitions.Remove(solution.Model); 329 330 trainingEvaluationCache.Clear(); 331 testEvaluationCache.Clear(); 332 evaluationCache.Clear(); 289 foreach (var solution in e.OldItems) { 290 trainingPartitions.Remove(solution.Model); 291 testPartitions.Remove(solution.Model); 292 } 293 Model.RemoveRange(e.OldItems.Select(s => s.Model)); 294 295 foreach (var solution in e.Items) { 296 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 297 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 298 } 299 Model.AddRange(e.Items.Select(s => s.Model)); 333 300 } 334 301 }
Note: See TracChangeset for help on using the changeset viewer.