Changeset 8508 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
- Timestamp:
- 08/20/12 17:24:14 (12 years ago)
- Location:
- branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Problems.DataAnalysis merged: 7921,7969,8113,8121,8126,8139,8151-8153,8167,8174,8246,8355
- Property svn:mergeinfo changed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r7866 r8508 37 37 [Creatable("Data Analysis - Ensembles")] 38 38 public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution { 39 private readonly Dictionary<int, double> trainingEvaluationCache = new Dictionary<int, double>(); 40 private readonly Dictionary<int, double> testEvaluationCache = new Dictionary<int, double>(); 41 39 42 public new IRegressionEnsembleModel Model { 40 43 get { return (IRegressionEnsembleModel)base.Model; } … … 52 55 53 56 [Storable] 54 private Dictionary<IRegressionModel, IntRange> trainingPartitions;57 private readonly Dictionary<IRegressionModel, IntRange> trainingPartitions; 55 58 [Storable] 56 private Dictionary<IRegressionModel, IntRange> testPartitions;59 private readonly Dictionary<IRegressionModel, IntRange> testPartitions; 57 60 58 61 [StorableConstructor] … … 86 89 } 87 90 91 trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count()); 92 testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count()); 93 88 94 regressionSolutions = cloner.Clone(original.regressionSolutions); 89 95 RegisterRegressionSolutionsEventHandler(); … … 133 139 } 134 140 141 trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count()); 142 testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count()); 143 135 144 RegisterRegressionSolutionsEventHandler(); 136 145 regressionSolutions.AddRange(solutions); … … 153 162 public override IEnumerable<double> EstimatedTrainingValues { 154 163 get { 155 var rows = ProblemData.TrainingIndizes; 156 var estimatedValuesEnumerators = (from model in Model.Models 157 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 158 .ToList(); 159 var rowsEnumerator = rows.GetEnumerator(); 160 // aggregate to make sure that MoveNext is called for all enumerators 161 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 162 int currentRow = rowsEnumerator.Current; 163 164 var selectedEnumerators = from pair in estimatedValuesEnumerators 165 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 166 select pair.EstimatedValuesEnumerator; 167 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); 164 var rows = ProblemData.TrainingIndices; 165 var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys); 166 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 167 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); 168 169 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 170 trainingEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 168 171 } 172 173 return rows.Select(row => trainingEvaluationCache[row]); 169 174 } 170 175 } … … 172 177 public override IEnumerable<double> EstimatedTestValues { 173 178 get { 174 var rows = ProblemData.TestIndizes; 175 var estimatedValuesEnumerators = (from model in Model.Models 176 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 177 .ToList(); 178 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 179 // aggregate to make sure that MoveNext is called for all enumerators 180 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 181 int currentRow = rowsEnumerator.Current; 182 183 var selectedEnumerators = from pair in estimatedValuesEnumerators 184 where RowIsTestForModel(currentRow, pair.Model) 185 select pair.EstimatedValuesEnumerator; 186 187 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); 179 var rows = ProblemData.TestIndices; 180 var rowsToEvaluate = rows.Except(testEvaluationCache.Keys); 181 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 182 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator(); 183 184 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 185 testEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 188 186 } 187 188 return rows.Select(row => testEvaluationCache[row]); 189 } 190 } 191 192 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) { 193 var estimatedValuesEnumerators = (from model in Model.Models 194 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 195 .ToList(); 196 var rowsEnumerator = rows.GetEnumerator(); 197 // aggregate to make sure that MoveNext is called for all enumerators 198 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 199 int currentRow = rowsEnumerator.Current; 200 201 var selectedEnumerators = from pair in estimatedValuesEnumerators 202 where modelSelectionPredicate(currentRow, pair.Model) 203 select pair.EstimatedValuesEnumerator; 204 205 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); 189 206 } 190 207 } … … 201 218 202 219 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 203 return from xs in GetEstimatedValueVectors(ProblemData.Dataset, rows) 204 select AggregateEstimatedValues(xs); 220 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 221 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 222 var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate) 223 select AggregateEstimatedValues(xs)) 224 .GetEnumerator(); 225 226 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 227 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 228 } 229 230 return rows.Select(row => evaluationCache[row]); 205 231 } 206 232 … … 223 249 224 250 protected override void OnProblemDataChanged() { 251 trainingEvaluationCache.Clear(); 252 testEvaluationCache.Clear(); 253 evaluationCache.Clear(); 225 254 IRegressionProblemData problemData = new RegressionProblemData(ProblemData.Dataset, 226 255 ProblemData.AllowedInputVariables, … … 251 280 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 252 281 regressionSolutions.AddRange(solutions); 282 283 trainingEvaluationCache.Clear(); 284 testEvaluationCache.Clear(); 285 evaluationCache.Clear(); 253 286 } 254 287 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 255 288 regressionSolutions.RemoveRange(solutions); 289 290 trainingEvaluationCache.Clear(); 291 testEvaluationCache.Clear(); 292 evaluationCache.Clear(); 256 293 } 257 294 … … 275 312 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 276 313 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 314 315 trainingEvaluationCache.Clear(); 316 testEvaluationCache.Clear(); 317 evaluationCache.Clear(); 277 318 } 278 319 … … 282 323 trainingPartitions.Remove(solution.Model); 283 324 testPartitions.Remove(solution.Model); 325 326 trainingEvaluationCache.Clear(); 327 testEvaluationCache.Clear(); 328 evaluationCache.Clear(); 284 329 } 285 330 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs
r7866 r8508 95 95 #endregion 96 96 97 public ConstrainedValueParameter<StringValue> TargetVariableParameter {98 get { return ( ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }97 public IConstrainedValueParameter<StringValue> TargetVariableParameter { 98 get { return (IConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 99 99 } 100 100 public string TargetVariable { -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs
r7866 r8508 55 55 } 56 56 public override IEnumerable<double> EstimatedTrainingValues { 57 get { return GetEstimatedValues(ProblemData.TrainingIndi zes); }57 get { return GetEstimatedValues(ProblemData.TrainingIndices); } 58 58 } 59 59 public override IEnumerable<double> EstimatedTestValues { 60 get { return GetEstimatedValues(ProblemData.TestIndi zes); }60 get { return GetEstimatedValues(ProblemData.TestIndices); } 61 61 } 62 62 -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolutionBase.cs
r7866 r8508 138 138 OnlineCalculatorError errorState; 139 139 Add(new Result(TrainingMeanAbsoluteErrorResultName, "Mean of absolute errors of the model on the training partition", new DoubleValue())); 140 double trainingMAE = OnlineMeanAbsoluteErrorCalculator.Calculate(EstimatedTrainingValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndi zes), out errorState);140 double trainingMAE = OnlineMeanAbsoluteErrorCalculator.Calculate(EstimatedTrainingValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices), out errorState); 141 141 TrainingMeanAbsoluteError = errorState == OnlineCalculatorError.None ? trainingMAE : double.NaN; 142 142 } … … 145 145 OnlineCalculatorError errorState; 146 146 Add(new Result(TestMeanAbsoluteErrorResultName, "Mean of absolute errors of the model on the test partition", new DoubleValue())); 147 double testMAE = OnlineMeanAbsoluteErrorCalculator.Calculate(EstimatedTestValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndi zes), out errorState);147 double testMAE = OnlineMeanAbsoluteErrorCalculator.Calculate(EstimatedTestValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices), out errorState); 148 148 TestMeanAbsoluteError = errorState == OnlineCalculatorError.None ? testMAE : double.NaN; 149 149 } … … 152 152 OnlineCalculatorError errorState; 153 153 Add(new Result(TrainingMeanErrorResultName, "Mean of errors of the model on the training partition", new DoubleValue())); 154 double trainingME = OnlineMeanErrorCalculator.Calculate(EstimatedTrainingValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndi zes), out errorState);154 double trainingME = OnlineMeanErrorCalculator.Calculate(EstimatedTrainingValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices), out errorState); 155 155 TrainingMeanError = errorState == OnlineCalculatorError.None ? trainingME : double.NaN; 156 156 } … … 158 158 OnlineCalculatorError errorState; 159 159 Add(new Result(TestMeanErrorResultName, "Mean of errors of the model on the test partition", new DoubleValue())); 160 double testME = OnlineMeanErrorCalculator.Calculate(EstimatedTestValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndi zes), out errorState);160 double testME = OnlineMeanErrorCalculator.Calculate(EstimatedTestValues, ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices), out errorState); 161 161 TestMeanError = errorState == OnlineCalculatorError.None ? testME : double.NaN; 162 162 } … … 166 166 protected void CalculateResults() { 167 167 IEnumerable<double> estimatedTrainingValues = EstimatedTrainingValues; // cache values 168 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndi zes);168 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices); 169 169 IEnumerable<double> estimatedTestValues = EstimatedTestValues; // cache values 170 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndi zes);170 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices); 171 171 172 172 OnlineCalculatorError errorState;
Note: See TracChangeset
for help on using the changeset viewer.