Free cookie consent management tool by TermsFeed Policy Generator

Changeset 7504


Ignore:
Timestamp:
02/21/12 16:44:20 (13 years ago)
Author:
sforsten
Message:

#1776:

  • improvements in the usage of the WeightCalculators
  • small changes in all class which inherit from the WeightCalculator class
Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4
Files:
8 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLab.Problems.DataAnalysis-3.4.csproj

    r7491 r7504  
    129129    <Compile Include="Implementation\Classification\DiscriminantFunctionClassificationSolutionBase.cs" />
    130130    <Compile Include="Implementation\Classification\WeightCalculators\AccuracyWeightCalculator.cs" />
     131    <Compile Include="Implementation\Classification\WeightCalculators\ContinuousPointCertaintyWeightCalculator.cs" />
    131132    <Compile Include="Implementation\Classification\WeightCalculators\NeighbourhoodWeightCalculator.cs" />
    132133    <Compile Include="Implementation\Classification\WeightCalculators\PointCertaintyWeightCalculator.cs" />
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r7464 r7504  
    6868        if (value != null) {
    6969          weightCalculator = value;
    70           weights = weights = weightCalculator.CalculateWeights(classificationSolutions);
     70          weights = weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    7171          if (!ProblemData.IsEmpty)
    7272            RecalculateResults();
     
    316316      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    317317      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    318       weights = weightCalculator.CalculateWeights(classificationSolutions);
     318      weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    319319    }
    320320
     
    324324      trainingPartitions.Remove(solution.Model);
    325325      testPartitions.Remove(solution.Model);
    326       weights = weightCalculator.CalculateWeights(classificationSolutions);
     326      weights = weightCalculator.CalculateNormalizedWeights(classificationSolutions);
    327327    }
    328328  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AccuracyWeightCalculator.cs

    r7464 r7504  
    4848    }
    4949
    50     public override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
    51       double sum = classificationSolutions.Select(s => s.TestAccuracy).Sum();
    52       List<double> weights = new List<double>();
    53       foreach (var item in classificationSolutions) {
    54         weights.Add(item.TestAccuracy / sum);
    55       }
    56       return weights;
     50    protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     51      return classificationSolutions.Select(s => s.TrainingAccuracy);
    5752    }
    5853  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs

    r7459 r7504  
    4848    }
    4949
    50     public override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     50    protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
    5151      return Enumerable.Repeat<double>(1, classificationSolutions.Count);
    5252    }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs

    r7491 r7504  
    4949    }
    5050
    51     public override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     51    protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
    5252      if (classificationSolutions.Count <= 0)
    5353        return new List<double>();
     
    5656        return Enumerable.Repeat<double>(1, classificationSolutions.Count);
    5757
    58       ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>();
    59       List<List<double>> estimatedTestValEnumerators = new List<List<double>>();
    60       List<List<double>> estimatedTestClassValEnumerators = new List<List<double>>();
     58      List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>();
     59      List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>();
     60      IDiscriminantFunctionClassificationSolution discriminantSolution;
    6161      foreach (var solution in classificationSolutions) {
    62         discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution);
    63         estimatedTestValEnumerators.Add(discriminantSolutions.Last().EstimatedTestValues.ToList());
    64         estimatedTestClassValEnumerators.Add(discriminantSolutions.Last().EstimatedTestClassValues.ToList());
     62        discriminantSolution = (IDiscriminantFunctionClassificationSolution)solution;
     63        estimatedTrainingValEnumerators.Add(discriminantSolution.EstimatedTrainingValues.ToList());
     64        estimatedTrainingClassValEnumerators.Add(discriminantSolution.EstimatedTrainingClassValues.ToList());
    6565      }
    6666
     
    6868
    6969      IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;
    70       IList<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList();
    71       List<double> testVal = GetValues(targetValues, problemData.TestIndizes).ToList();
     70      List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList();
     71      List<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes).ToList();
    7272
    7373      double pointAvg, help;
    7474      int count;
    75       for (int point = 0; point < estimatedTestClassValEnumerators.First().Count; point++) {
     75      for (int point = 0; point < estimatedTrainingClassValEnumerators.First().Count; point++) {
    7676        pointAvg = 0.0;
    7777        count = 0;
    78         for (int solution = 0; solution < estimatedTestClassValEnumerators.Count; solution++) {
    79           if (estimatedTestClassValEnumerators[solution][point].Equals(testVal[point])) {
    80             pointAvg += estimatedTestValEnumerators[solution][point];
     78        for (int solution = 0; solution < estimatedTrainingClassValEnumerators.Count; solution++) {
     79          if (estimatedTrainingClassValEnumerators[solution][point].Equals(targetValues[point])) {
     80            pointAvg += estimatedTrainingValEnumerators[solution][point];
    8181            count++;
    8282          }
    8383        }
    8484        pointAvg /= (double)count;
    85         for (int solution = 0; solution < estimatedTestClassValEnumerators.Count; solution++) {
    86           if (estimatedTestClassValEnumerators[solution][point].Equals(testVal[point])) {
     85        for (int solution = 0; solution < estimatedTrainingClassValEnumerators.Count; solution++) {
     86          if (estimatedTrainingClassValEnumerators[solution][point].Equals(targetValues[point])) {
    8787            weights[solution] += 0.5;
    88             help = Math.Abs(estimatedTestValEnumerators[solution][point] - 0.5);
     88            help = Math.Abs(estimatedTrainingValEnumerators[solution][point] - 0.5);
    8989            weights[solution] += help < 0.5 ? 0.5 - help : 0.0;
    9090          }
    9191        }
    9292      }
    93       return weights.Select(x => x / weights.Sum());
     93      return weights;
    9494    }
    9595
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/PointCertaintyWeightCalculator.cs

    r7491 r7504  
    4949    }
    5050
    51     public override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     51    protected override IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
    5252      if (classificationSolutions.Count <= 0)
    5353        return new List<double>();
     
    6262
    6363      List<double> weights = new List<double>();
    64 
    6564      IClassificationProblemData problemData = classificationSolutions.ElementAt(0).ProblemData;
    66       IList<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList();
    67       IEnumerator<double> testVal;
    68 
     65      IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes);
     66      IEnumerator<double> trainingValues;
    6967      double avg = problemData.ClassValues.Average();
    7068
    7169      foreach (var solution in discriminantSolutions) {
    72         IEnumerator<double> estimatedTest = solution.EstimatedTestValues.GetEnumerator();
    73         IEnumerator<double> estimatedTestClass = solution.EstimatedTestClassValues.GetEnumerator();
     70        IEnumerator<double> estimatedTrainingVal = solution.EstimatedTrainingValues.GetEnumerator();
     71        IEnumerator<double> estimatedTrainingClassVal = solution.EstimatedTrainingClassValues.GetEnumerator();
    7472
    75         testVal = GetValues(targetValues, problemData.TestIndizes).GetEnumerator();
     73        trainingValues = targetValues.GetEnumerator();
    7674        double curWeight = 0.0;
    77         while (estimatedTest.MoveNext() && estimatedTestClass.MoveNext() && testVal.MoveNext()) {
    78           if (estimatedTestClass.Current.Equals(testVal.Current)) {
     75        while (estimatedTrainingVal.MoveNext() && estimatedTrainingClassVal.MoveNext() && trainingValues.MoveNext()) {
     76          if (estimatedTrainingClassVal.Current.Equals(trainingValues.Current)) {
    7977            curWeight += 0.5;
    80             double distanceToPoint = Math.Abs(estimatedTest.Current - avg);
    81             double distanceToClass = Math.Abs(testVal.Current - avg);
    82             if (testVal.Current > avg && estimatedTest.Current > avg
    83              || testVal.Current < avg && estimatedTest.Current < avg)
     78            double distanceToPoint = Math.Abs(estimatedTrainingVal.Current - avg);
     79            double distanceToClass = Math.Abs(trainingValues.Current - avg);
     80            if (trainingValues.Current > avg && estimatedTrainingVal.Current > avg
     81             || trainingValues.Current < avg && estimatedTrainingVal.Current < avg)
    8482              curWeight += distanceToPoint < distanceToClass ? (0.5 / distanceToClass) * distanceToPoint : 0.5;
    8583          }
     
    8785        weights.Add(curWeight);
    8886      }
    89       return weights.Select(x => x / weights.Sum());
     87      return weights;
    9088    }
    9189
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/WeightCalculator.cs

    r7459 r7504  
    2121
    2222using System.Collections.Generic;
     23using System.Linq;
    2324using HeuristicLab.Common;
    2425using HeuristicLab.Core;
     
    4344    }
    4445
    45     abstract public IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions);
     46    /// <summary>
     47    /// calls CalculateWeights and removes negative weights
     48    /// </summary>
     49    /// <param name="classificationSolutions"></param>
     50    /// <returns>weights which are equal or bigger than zero</returns>
     51    public IEnumerable<double> CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions) {
     52      List<double> weights = new List<double>();
     53      foreach (var weight in CalculateWeights(classificationSolutions)) {
     54        weights.Add(weight >= 0 ? weight : 0);
     55      }
     56      return weights.Select(x => x / weights.Sum());
     57    }
     58
     59    protected abstract IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions);
    4660  }
    4761}
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs

    r7464 r7504  
    2525namespace HeuristicLab.Problems.DataAnalysis.Interfaces.Classification {
    2626  public interface IClassificationEnsembleSolutionWeightCalculator : INamedItem {
    27     IEnumerable<double> CalculateWeights(ItemCollection<IClassificationSolution> classificationSolutions);
     27    IEnumerable<double> CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions);
    2828  }
    2929}
Note: See TracChangeset for help on using the changeset viewer.