Changeset 15280 for branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
- Timestamp:
- 07/23/17 00:52:14 (7 years ago)
- Location:
- branches/Async
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/Async
- Property svn:mergeinfo changed
-
branches/Async/HeuristicLab.Problems.DataAnalysis
-
branches/Async/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r12509 r15280 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; … … 32 33 [StorableClass] 33 34 [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")] 34 public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel { 35 public sealed class RegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel { 36 public override IEnumerable<string> VariablesUsedForPrediction { 37 get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 38 } 35 39 36 40 private List<IRegressionModel> models; … … 45 49 } 46 50 51 private List<double> modelWeights; 52 public IEnumerable<double> ModelWeights { 53 get { return modelWeights; } 54 } 55 56 [Storable(Name = "ModelWeights")] 57 private IEnumerable<double> StorableModelWeights { 58 get { return modelWeights; } 59 set { modelWeights = value.ToList(); } 60 } 61 62 [Storable] 63 private bool averageModelEstimates = true; 64 public bool AverageModelEstimates { 65 get { return averageModelEstimates; } 66 set { 67 if (averageModelEstimates != value) { 68 averageModelEstimates = value; 69 OnChanged(); 70 } 71 } 72 } 73 47 74 #region backwards compatiblity 3.3.5 48 75 [Storable(Name = "models", AllowOneWay = true)] … … 52 79 #endregion 53 80 81 [StorableHook(HookType.AfterDeserialization)] 82 private void AfterDeserialization() { 83 // BackwardsCompatibility 3.3.14 84 #region Backwards compatible code, remove with 3.4 85 if (modelWeights == null || !modelWeights.Any()) 86 modelWeights = new List<double>(models.Select(m => 1.0)); 87 #endregion 88 } 89 54 90 [StorableConstructor] 55 pr otectedRegressionEnsembleModel(bool deserializing) : base(deserializing) { }56 pr otectedRegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)91 private RegressionEnsembleModel(bool deserializing) : base(deserializing) { } 92 private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner) 57 93 : base(original, cloner) { 58 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 94 this.models = original.Models.Select(cloner.Clone).ToList(); 95 this.modelWeights = new List<double>(original.ModelWeights); 96 this.averageModelEstimates = original.averageModelEstimates; 97 } 98 public override IDeepCloneable Clone(Cloner cloner) { 99 return new RegressionEnsembleModel(this, cloner); 59 100 } 60 101 61 102 public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { } 62 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) 63 : base() { 103 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { } 104 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights) 105 : base(string.Empty) { 64 106 this.name = ItemName; 65 107 this.description = ItemDescription; 108 66 109 this.models = new List<IRegressionModel>(models); 67 } 68 69 public override IDeepCloneable Clone(Cloner cloner) { 70 return new RegressionEnsembleModel(this, cloner); 71 } 72 73 #region IRegressionEnsembleModel Members 110 this.modelWeights = new List<double>(modelWeights); 111 112 if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable; 113 } 74 114 75 115 public void Add(IRegressionModel model) { 116 if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable; 117 Add(model, 1.0); 118 } 119 public void Add(IRegressionModel model, double weight) { 120 if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable; 121 76 122 models.Add(model); 77 } 123 modelWeights.Add(weight); 124 OnChanged(); 125 } 126 127 public void AddRange(IEnumerable<IRegressionModel> models) { 128 AddRange(models, models.Select(m => 1.0)); 129 } 130 public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) { 131 if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = models.First().TargetVariable; 132 133 this.models.AddRange(models); 134 modelWeights.AddRange(weights); 135 OnChanged(); 136 } 137 78 138 public void Remove(IRegressionModel model) { 79 models.Remove(model); 80 } 81 139 var index = models.IndexOf(model); 140 models.RemoveAt(index); 141 modelWeights.RemoveAt(index); 142 143 if (!models.Any()) TargetVariable = string.Empty; 144 OnChanged(); 145 } 146 public void RemoveRange(IEnumerable<IRegressionModel> models) { 147 foreach (var model in models) { 148 var index = this.models.IndexOf(model); 149 this.models.RemoveAt(index); 150 modelWeights.RemoveAt(index); 151 } 152 153 if (!models.Any()) TargetVariable = string.Empty; 154 OnChanged(); 155 } 156 157 public double GetModelWeight(IRegressionModel model) { 158 var index = models.IndexOf(model); 159 return modelWeights[index]; 160 } 161 public void SetModelWeight(IRegressionModel model, double weight) { 162 var index = models.IndexOf(model); 163 modelWeights[index] = weight; 164 OnChanged(); 165 } 166 167 #region evaluation 82 168 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 83 169 var estimatedValuesEnumerators = (from model in models 84 select model.GetEstimatedValues(dataset, rows).GetEnumerator()) 85 .ToList(); 170 let weight = GetModelWeight(model) 171 select model.GetEstimatedValues(dataset, rows).Select(e => weight * e) 172 .GetEnumerator()).ToList(); 86 173 87 174 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { … … 91 178 } 92 179 180 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 181 double weightsSum = modelWeights.Sum(); 182 var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows) 183 select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum(); 184 185 if (AverageModelEstimates) 186 return summedEstimates.Select(v => v / weightsSum); 187 else 188 return summedEstimates; 189 190 } 191 192 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) { 193 var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator(); 194 var rowsEnumerator = rows.GetEnumerator(); 195 196 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) { 197 var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator(); 198 int currentRow = rowsEnumerator.Current; 199 double weightsSum = 0.0; 200 double filteredEstimatesSum = 0.0; 201 202 for (int m = 0; m < models.Count; m++) { 203 estimatedValueEnumerator.MoveNext(); 204 var model = models[m]; 205 if (!modelSelectionPredicate(currentRow, model)) continue; 206 207 filteredEstimatesSum += estimatedValueEnumerator.Current; 208 weightsSum += modelWeights[m]; 209 } 210 211 if (AverageModelEstimates) 212 yield return filteredEstimatesSum / weightsSum; 213 else 214 yield return filteredEstimatesSum; 215 } 216 } 217 93 218 #endregion 94 219 95 #region IRegressionModel Members 96 97 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 98 foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) { 99 yield return estimatedValuesVector.Average(); 100 } 101 } 102 103 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { 104 return new RegressionEnsembleSolution(this.Models, new RegressionEnsembleProblemData(problemData)); 105 } 106 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 107 return CreateRegressionSolution(problemData); 108 } 109 110 #endregion 220 public event EventHandler Changed; 221 private void OnChanged() { 222 var handler = Changed; 223 if (handler != null) 224 handler(this, EventArgs.Empty); 225 } 226 227 228 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 229 return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData)); 230 } 111 231 } 112 232 }
Note: See TracChangeset
for help on using the changeset viewer.