- Timestamp:
- 03/17/16 17:48:36 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r13705 r13715 109 109 } 110 110 111 #region IRegressionEnsembleModel Members112 111 public void Add(IRegressionModel model) { 113 112 Add(model, 1.0); … … 153 152 } 154 153 154 #region evaluation 155 155 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 156 156 var estimatedValuesEnumerators = (from model in models … … 165 165 } 166 166 167 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 168 double weightsSum = modelWeights.Sum(); 169 var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows) 170 select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum(); 171 172 if (AverageModelEstimates) 173 return summedEstimates.Select(v => v / weightsSum); 174 else 175 return summedEstimates; 176 177 } 178 167 179 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) { 168 180 var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator(); 169 181 var rowsEnumerator = rows.GetEnumerator(); 170 182 171 // aggregate to make sure that MoveNext is called for all enumerators172 183 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) { 184 var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator(); 173 185 int currentRow = rowsEnumerator.Current; 174 175 var filteredEstimates = models.Zip(estimatedValuesEnumerators.Current, (m, e) => new { Model = m, EstimatedValue = e }) 176 .Where(f => modelSelectionPredicate(currentRow, f.Model)) 177 .Select(f => f.EstimatedValue).DefaultIfEmpty(double.NaN); 178 179 yield return AggregateEstimatedValues(filteredEstimates); 180 } 181 } 182 183 private double AggregateEstimatedValues(IEnumerable<double> estimatedValuesVector) { 184 if (AverageModelEstimates) 185 return estimatedValuesVector.Average(); 186 else 187 return estimatedValuesVector.Sum(); 188 } 186 double weightsSum = 0.0; 187 double filteredEstimatesSum = 0.0; 188 189 for (int m = 0; m < models.Count; m++) { 190 estimatedValueEnumerator.MoveNext(); 191 var model = models[m]; 192 if (!modelSelectionPredicate(currentRow, model)) continue; 193 194 filteredEstimatesSum += estimatedValueEnumerator.Current; 195 weightsSum += modelWeights[m]; 196 } 197 198 if (AverageModelEstimates) 199 yield return filteredEstimatesSum / weightsSum; 200 else 201 yield return filteredEstimatesSum; 202 } 203 } 204 205 #endregion 189 206 190 207 public event EventHandler Changed; … … 194 211 handler(this, EventArgs.Empty); 195 212 } 196 #endregion 197 198 #region IRegressionModel Members 199 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 200 foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) { 201 yield return AggregateEstimatedValues(estimatedValuesVector.DefaultIfEmpty(double.NaN)); 202 } 203 } 213 204 214 205 215 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { … … 209 219 return CreateRegressionSolution(problemData); 210 220 } 211 #endregion212 221 } 213 222 }
Note: See TracChangeset
for help on using the changeset viewer.