Changeset 7729 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation
- Timestamp:
- 04/16/12 09:14:38 (13 years ago)
- Location:
- branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators
- Files:
-
- 2 added
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
r7562 r7729 59 59 60 60 IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler); 61 IEnumerable<IDictionary<I ClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);61 IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler); 62 62 63 63 return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values }) … … 65 65 } 66 66 67 protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<I ClassificationSolution, double> estimatedValues) {67 protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IDiscriminantFunctionClassificationSolution, double> estimatedValues) { 68 68 return AggregateEstimatedClassValues(estimatedClassValues); 69 69 } 70 70 71 protected IEnumerable<IDictionary<I ClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {71 protected IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 72 72 var estimatedValuesEnumerators = (from solution in solutions 73 select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimated ClassValues(dataset, rows).GetEnumerator() })73 select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() }) 74 74 .ToList(); 75 75 … … 79 79 where handler(enumerator.Solution.ProblemData, rowEnumerator.Current) 80 80 select enumerator) 81 .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);81 .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current); 82 82 } 83 83 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
r7562 r7729 47 47 protected double[] classValues; 48 48 49 /// <summary>50 ///51 /// </summary>52 /// <param name="discriminantSolutions"></param>53 /// <returns>median instead of weights, because it doesn't use any weights</returns>54 49 protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 55 List<List<double>> estimatedValues = new List<List<double>>(); 56 List<List<double>> estimatedClassValues = new List<List<double>>(); 57 58 List<IClassificationProblemData> solutionProblemData = discriminantSolutions.Select(sol => sol.ProblemData).ToList(); 59 Dataset dataSet = solutionProblemData[0].Dataset; 60 IEnumerable<int> rows = Enumerable.Range(0, dataSet.Rows); 61 foreach (var solution in discriminantSolutions) { 62 estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 63 estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 50 classValues = discriminantSolutions.First().Model.ClassValues.ToArray(); 51 var modelThresholds = discriminantSolutions.Select(x => x.Model.Thresholds.ToArray()); 52 threshold = new double[modelThresholds.First().GetLength(0)]; 53 for (int i = 0; i < modelThresholds.First().GetLength(0); i++) { 54 threshold[i] = GetMedian(modelThresholds.Select(x => x[i]).ToList()); 64 55 } 65 66 List<double> median = new List<double>();67 List<double> targetValues = dataSet.GetDoubleValues(solutionProblemData[0].TargetVariable).ToList();68 IList<double> curTrainingpoints = new List<double>();69 int removed = 0;70 int count = targetValues.Count;71 for (int point = 0; point < count; point++) {72 curTrainingpoints.Clear();73 for (int solutionPos = 0; solutionPos < solutionProblemData.Count; solutionPos++) {74 if (PointInTraining(solutionProblemData[solutionPos], point)) {75 curTrainingpoints.Add(estimatedValues[solutionPos][point]);76 }77 }78 if (curTrainingpoints.Count > 0)79 median.Add(GetMedian(curTrainingpoints.OrderBy(p => p).ToList()));80 else {81 //remove not used points82 targetValues.RemoveAt(point - removed);83 removed++;84 }85 }86 AccuracyMaximizationThresholdCalculator.CalculateThresholds(solutionProblemData[0], median, targetValues, out classValues, out threshold);87 56 return Enumerable.Repeat<double>(1, discriminantSolutions.Count()); 88 57 } 89 58 90 protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<I ClassificationSolution, double> estimatedValues) {59 protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IDiscriminantFunctionClassificationSolution, double> estimatedValues) { 91 60 IList<double> values = estimatedValues.Select(x => x.Value).ToList(); 92 61 if (values.Count <= 0) … … 122 91 else { 123 92 double distance = threshold[1] - classValues[0]; 124 return (1 / distance) * ( median - classValues[0]);93 return (1 / distance) * (threshold[1] - median); 125 94 } 126 95 } else if (estimatedClassValue.Equals(classValues[1])) { … … 131 100 else { 132 101 double distance = classValues[1] - threshold[1]; 133 return (1 / distance) * ( classValues[1] - median);102 return (1 / distance) * (median - threshold[1]); 134 103 } 135 104 } else
Note: See TracChangeset
for help on using the changeset viewer.