Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/16/12 09:14:38 (12 years ago)
Author:
sforsten
Message:

#1776:

  • bugfix the method GetEstimatedValues of DiscriminantClassificationWeightCalculator returns real values and not class values
  • changed arguments of method DiscriminantAggregateEstimatedClassValues of DiscriminantClassificationWeightCalculator
  • added two calculators
Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators
Files:
2 added
2 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs

    r7562 r7729  
    5959
    6060      IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler);
    61       IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);
     61      IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler);
    6262
    6363      return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values })
     
    6565    }
    6666
    67     protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
     67    protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IDiscriminantFunctionClassificationSolution, double> estimatedValues) {
    6868      return AggregateEstimatedClassValues(estimatedClassValues);
    6969    }
    7070
    71     protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
     71    protected IEnumerable<IDictionary<IDiscriminantFunctionClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) {
    7272      var estimatedValuesEnumerators = (from solution in solutions
    73                                         select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() })
     73                                        select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedValues(dataset, rows).GetEnumerator() })
    7474                                        .ToList();
    7575
     
    7979                      where handler(enumerator.Solution.ProblemData, rowEnumerator.Current)
    8080                      select enumerator)
    81                       .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current);
     81                      .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current);
    8282      }
    8383    }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs

    r7562 r7729  
    4747    protected double[] classValues;
    4848
    49     /// <summary>
    50     ///
    51     /// </summary>
    52     /// <param name="discriminantSolutions"></param>
    53     /// <returns>median instead of weights, because it doesn't use any weights</returns>
    5449    protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {
    55       List<List<double>> estimatedValues = new List<List<double>>();
    56       List<List<double>> estimatedClassValues = new List<List<double>>();
    57 
    58       List<IClassificationProblemData> solutionProblemData = discriminantSolutions.Select(sol => sol.ProblemData).ToList();
    59       Dataset dataSet = solutionProblemData[0].Dataset;
    60       IEnumerable<int> rows = Enumerable.Range(0, dataSet.Rows);
    61       foreach (var solution in discriminantSolutions) {
    62         estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
    63         estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
     50      classValues = discriminantSolutions.First().Model.ClassValues.ToArray();
     51      var modelThresholds = discriminantSolutions.Select(x => x.Model.Thresholds.ToArray());
     52      threshold = new double[modelThresholds.First().GetLength(0)];
     53      for (int i = 0; i < modelThresholds.First().GetLength(0); i++) {
     54        threshold[i] = GetMedian(modelThresholds.Select(x => x[i]).ToList());
    6455      }
    65 
    66       List<double> median = new List<double>();
    67       List<double> targetValues = dataSet.GetDoubleValues(solutionProblemData[0].TargetVariable).ToList();
    68       IList<double> curTrainingpoints = new List<double>();
    69       int removed = 0;
    70       int count = targetValues.Count;
    71       for (int point = 0; point < count; point++) {
    72         curTrainingpoints.Clear();
    73         for (int solutionPos = 0; solutionPos < solutionProblemData.Count; solutionPos++) {
    74           if (PointInTraining(solutionProblemData[solutionPos], point)) {
    75             curTrainingpoints.Add(estimatedValues[solutionPos][point]);
    76           }
    77         }
    78         if (curTrainingpoints.Count > 0)
    79           median.Add(GetMedian(curTrainingpoints.OrderBy(p => p).ToList()));
    80         else {
    81           //remove not used points
    82           targetValues.RemoveAt(point - removed);
    83           removed++;
    84         }
    85       }
    86       AccuracyMaximizationThresholdCalculator.CalculateThresholds(solutionProblemData[0], median, targetValues, out classValues, out threshold);
    8756      return Enumerable.Repeat<double>(1, discriminantSolutions.Count());
    8857    }
    8958
    90     protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
     59    protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IDiscriminantFunctionClassificationSolution, double> estimatedValues) {
    9160      IList<double> values = estimatedValues.Select(x => x.Value).ToList();
    9261      if (values.Count <= 0)
     
    12291        else {
    12392          double distance = threshold[1] - classValues[0];
    124           return (1 / distance) * (median - classValues[0]);
     93          return (1 / distance) * (threshold[1] - median);
    12594        }
    12695      } else if (estimatedClassValue.Equals(classValues[1])) {
     
    131100        else {
    132101          double distance = classValues[1] - threshold[1];
    133           return (1 / distance) * (classValues[1] - median);
     102          return (1 / distance) * (median - threshold[1]);
    134103        }
    135104      } else
Note: See TracChangeset for help on using the changeset viewer.