Changeset 6238 for trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
- Timestamp:
- 05/20/11 15:07:45 (13 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r6184 r6238 51 51 } 52 52 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 53 : base(new RegressionEnsembleModel(models), problemData) {53 : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) { 54 54 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 55 55 testPartitions = new Dictionary<IRegressionModel, IntRange>(); … … 62 62 63 63 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 64 : base(new RegressionEnsembleModel(models), problemData) {64 : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) { 65 65 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 66 66 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); … … 75 75 throw new ArgumentException(); 76 76 } 77 78 77 RecalculateResults(); 79 }80 81 private void RecalculateResults() {82 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values83 var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start,84 ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);85 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes);86 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values87 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);88 89 OnlineCalculatorError errorState;90 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);91 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;92 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);93 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;94 95 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);96 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;97 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);98 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;99 100 double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);101 TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;102 double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);103 TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;104 105 double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);106 TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;107 double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);108 TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;109 78 } 110 79 … … 115 84 public override IEnumerable<double> EstimatedTrainingValues { 116 85 get { 117 var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);86 var rows = ProblemData.TrainingIndizes; 118 87 var estimatedValuesEnumerators = (from model in Model.Models 119 88 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 120 89 .ToList(); 121 90 var rowsEnumerator = rows.GetEnumerator(); 91 // aggregate to make sure that MoveNext is called for all enumerators 122 92 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 123 93 int currentRow = rowsEnumerator.Current; … … 134 104 public override IEnumerable<double> EstimatedTestValues { 135 105 get { 106 var rows = ProblemData.TestIndizes; 136 107 var estimatedValuesEnumerators = (from model in Model.Models 137 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })108 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 138 109 .ToList(); 139 110 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 111 // aggregate to make sure that MoveNext is called for all enumerators 140 112 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 141 113 int currentRow = rowsEnumerator.Current; … … 168 140 169 141 private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) { 170 return estimatedValues. Average();142 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 171 143 } 172 144
Note: See TracChangeset
for help on using the changeset viewer.