Free cookie consent management tool by TermsFeed Policy Generator

Changeset 6239


Ignore:
Timestamp:
05/20/11 16:10:07 (12 years ago)
Author:
gkronber
Message:

#1450: implemented support for ensemble solutions for classification.

Location:
trunk/sources
Files:
1 added
8 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs

    r6184 r6239  
    376376        results.Add(result.Name, result.Value);
    377377      }
     378      foreach (IResult result in ExtractAndAggregateClassificationSolutions(resultCollections)) {
     379        results.Add(result.Name, result.Value);
     380      }
    378381      results.Add("Execution Time", new TimeSpanValue(this.ExecutionTime));
    379382      results.Add("CrossValidation Folds", new RunCollection(runs));
     
    406409    }
    407410
     411    private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
     412      Dictionary<string, List<IClassificationSolution>> resultSolutions = new Dictionary<string, List<IClassificationSolution>>();
     413      foreach (var result in resultCollections) {
     414        var classificationSolution = result.Value as IClassificationSolution;
     415        if (classificationSolution != null) {
     416          if (resultSolutions.ContainsKey(result.Key)) {
     417            resultSolutions[result.Key].Add(classificationSolution);
     418          } else {
     419            resultSolutions.Add(result.Key, new List<IClassificationSolution>() { classificationSolution });
     420          }
     421        }
     422      }
     423      List<IResult> aggregatedResults = new List<IResult>();
     424      foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) {
     425        var problemDataClone = (IClassificationProblemData)Problem.ProblemData.Clone();
     426        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
     427        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
     428        var ensembleSolution = new ClassificationEnsembleSolution(solutions.Value.Select(x => x.Model), problemDataClone,
     429          solutions.Value.Select(x => x.ProblemData.TrainingPartition),
     430          solutions.Value.Select(x => x.ProblemData.TestPartition));
     431
     432        aggregatedResults.Add(new Result(solutions.Key, ensembleSolution));
     433      }
     434      return aggregatedResults;
     435    }
     436
    408437    private static IEnumerable<IResult> ExtractAndAggregateResults<T>(IEnumerable<KeyValuePair<string, IItem>> results)
    409438  where T : class, IItem, new() {
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionConfusionMatrixView.cs

    r5975 r6239  
    103103        IEnumerable<int> rows;
    104104
     105        double[] predictedValues;
    105106        if (cmbSamples.SelectedItem.ToString() == TrainingSamples) {
    106107          rows = Content.ProblemData.TrainingIndizes;
     108          predictedValues = Content.EstimatedTrainingClassValues.ToArray();
    107109        } else if (cmbSamples.SelectedItem.ToString() == TestSamples) {
    108110          rows = Content.ProblemData.TestIndizes;
     111          predictedValues = Content.EstimatedTestClassValues.ToArray();         
    109112        } else throw new InvalidOperationException();
     113
     114        double[] targetValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable, rows).ToArray();
    110115
    111116        Dictionary<double, int> classValueIndexMapping = new Dictionary<double, int>();
     
    115120          index++;
    116121        }
    117 
    118         double[] targetValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable, rows).ToArray();
    119         double[] predictedValues = Content.GetEstimatedClassValues(rows).ToArray();
    120122
    121123        for (int i = 0; i < targetValues.Length; i++) {
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionEstimatedClassValuesView.cs

    r5975 r6239  
    3232  [Content(typeof(IClassificationSolution))]
    3333  public partial class ClassificationSolutionEstimatedClassValuesView : ItemView, IClassificationSolutionEvaluationView {
    34     private const string TARGETVARIABLE_SERIES_NAME = "TargetVariable";
    35     private const string ESTIMATEDVALUES_SERIES_NAME = "EstimatedClassValues";
     34    private const string TARGETVARIABLE_SERIES_NAME = "Target Variable";
     35    private const string ESTIMATEDVALUES_TRAINING_SERIES_NAME = "Estimated Class Values (training)";
     36    private const string ESTIMATEDVALUES_TEST_SERIES_NAME = "Estimated Class Values (test)";
    3637
    3738    public new IClassificationSolution Content {
     
    8586        DoubleMatrix matrix = null;
    8687        if (Content != null) {
    87           double[,] values = new double[Content.ProblemData.Dataset.Rows, 2];
     88          double[,] values = new double[Content.ProblemData.Dataset.Rows, 3];
     89          // fill with NaN
     90          for (int row = 0; row < Content.ProblemData.Dataset.Rows; row++)
     91            for (int column = 0; column < 3; column++)
     92              values[row, column] = double.NaN;
    8893
    8994          double[] target = Content.ProblemData.Dataset.GetVariableValues(Content.ProblemData.TargetVariable);
    90           double[] estimated = Content.EstimatedClassValues.ToArray();
    9195          for (int row = 0; row < target.Length; row++) {
    9296            values[row, 0] = target[row];
    93             values[row, 1] = estimated[row];
     97          }
     98          var estimatedTraining = Content.EstimatedTrainingClassValues.GetEnumerator();
     99          estimatedTraining.MoveNext();
     100          foreach (var trainingRow in Content.ProblemData.TrainingIndizes) {
     101            values[trainingRow, 1] = estimatedTraining.Current;
     102            estimatedTraining.MoveNext();
     103          }
     104          var estimatedTest = Content.EstimatedTestClassValues.GetEnumerator();
     105          estimatedTest.MoveNext();
     106          foreach (var testRow in Content.ProblemData.TestIndizes) {
     107            values[testRow, 2] = estimatedTest.Current;
     108            estimatedTest.MoveNext();
    94109          }
    95110
    96111          matrix = new DoubleMatrix(values);
    97           matrix.ColumnNames = new string[] { TARGETVARIABLE_SERIES_NAME, ESTIMATEDVALUES_SERIES_NAME };
     112          matrix.ColumnNames = new string[] { TARGETVARIABLE_SERIES_NAME, ESTIMATEDVALUES_TRAINING_SERIES_NAME, ESTIMATEDVALUES_TEST_SERIES_NAME };
    98113        }
    99114        matrixView.Content = matrix;
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLab.Problems.DataAnalysis-3.4.csproj

    r6238 r6239  
    109109  <ItemGroup>
    110110    <Compile Include="DoubleLimit.cs" />
     111    <Compile Include="Implementation\Classification\ClassificationEnsembleModel.cs">
     112      <SubType>Code</SubType>
     113    </Compile>
     114    <Compile Include="Implementation\Classification\ClassificationEnsembleSolution.cs" />
    111115    <Compile Include="Implementation\Classification\ClassificationProblemData.cs" />
    112116    <Compile Include="Implementation\Classification\ClassificationProblem.cs" />
    113117    <Compile Include="Implementation\Classification\ClassificationSolution.cs" />
     118    <Compile Include="Implementation\Classification\ClassificationEnsembleProblemData.cs" />
    114119    <Compile Include="Implementation\Clustering\ClusteringProblem.cs" />
    115120    <Compile Include="Implementation\Clustering\ClusteringProblemData.cs" />
     
    120125    </Compile>
    121126    <Compile Include="Implementation\Regression\RegressionEnsembleSolution.cs" />
     127    <Compile Include="Interfaces\Classification\IClassificationEnsembleModel.cs">
     128      <SubType>Code</SubType>
     129    </Compile>
     130    <Compile Include="Interfaces\Classification\IClassificationEnsembleSolution.cs">
     131      <SubType>Code</SubType>
     132    </Compile>
    122133    <Compile Include="Interfaces\Classification\IDiscriminantFunctionThresholdCalculator.cs" />
    123134    <Compile Include="Interfaces\Regression\IRegressionEnsembleModel.cs">
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs

    r5809 r6239  
    3939      get { return new List<IClassificationModel>(models); }
    4040    }
     41
    4142    [StorableConstructor]
    4243    protected ClassificationEnsembleModel(bool deserializing) : base(deserializing) { }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r6184 r6239  
    2525using HeuristicLab.Core;
    2626using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     27using HeuristicLab.Data;
     28using System;
    2729
    2830namespace HeuristicLab.Problems.DataAnalysis {
     
    3335  [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")]
    3436  // [Creatable("Data Analysis")]
    35   public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution {
     37  public class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution {
     38
     39    public new IClassificationEnsembleModel Model {
     40      set { base.Model = value; }
     41      get { return (IClassificationEnsembleModel)base.Model; }
     42    }
    3643
    3744    [Storable]
    38     private List<IClassificationModel> models;
    39     public IEnumerable<IClassificationModel> Models {
    40       get { return new List<IClassificationModel>(models); }
    41     }
     45    private Dictionary<IClassificationModel, IntRange> trainingPartitions;
     46    [Storable]
     47    private Dictionary<IClassificationModel, IntRange> testPartitions;
     48
     49
    4250    [StorableConstructor]
    4351    protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { }
    4452    protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    4553      : base(original, cloner) {
    46       this.models = original.Models.Select(m => cloner.Clone(m)).ToList();
     54      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     55      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     56      foreach (var model in Model.Models) {
     57        trainingPartitions[model] = (IntRange)ProblemData.TrainingPartition.Clone();
     58        testPartitions[model] = (IntRange)ProblemData.TestPartition.Clone();
     59      }
     60      RecalculateResults();
    4761    }
    48     public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models)
    49       : base() {
     62    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
     63      : base(new ClassificationEnsembleModel(models), new ClassificationEnsembleProblemData(problemData)) {
    5064      this.name = ItemName;
    5165      this.description = ItemDescription;
    52       this.models = new List<IClassificationModel>(models);
     66      trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     67      testPartitions = new Dictionary<IClassificationModel, IntRange>();
     68      foreach (var model in models) {
     69        trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone();
     70        testPartitions[model] = (IntRange)problemData.TestPartition.Clone();
     71      }
     72      RecalculateResults();
     73    }
     74
     75    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
     76      : base(new ClassificationEnsembleModel(models), new ClassificationEnsembleProblemData(problemData)) {
     77      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
     78      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
     79      var modelEnumerator = models.GetEnumerator();
     80      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
     81      var testPartitionEnumerator = testPartitions.GetEnumerator();
     82      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
     83        this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone();
     84        this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone();
     85      }
     86      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
     87        throw new ArgumentException();
     88      }
     89      RecalculateResults();
    5390    }
    5491
     
    5794    }
    5895
    59     #region IClassificationEnsembleModel Members
     96    public override IEnumerable<double> EstimatedTrainingClassValues {
     97      get {
     98        var rows = ProblemData.TrainingIndizes;
     99        var estimatedValuesEnumerators = (from model in Model.Models
     100                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     101                                         .ToList();
     102        var rowsEnumerator = rows.GetEnumerator();
     103        // aggregate to make sure that MoveNext is called for all enumerators
     104        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     105          int currentRow = rowsEnumerator.Current;
     106
     107          var selectedEnumerators = from pair in estimatedValuesEnumerators
     108                                    where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) ||
     109                                         (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End)
     110                                    select pair.EstimatedValuesEnumerator;
     111          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     112        }
     113      }
     114    }
     115
     116    public override IEnumerable<double> EstimatedTestClassValues {
     117      get {
     118        var rows = ProblemData.TestIndizes;
     119        var estimatedValuesEnumerators = (from model in Model.Models
     120                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() })
     121                                         .ToList();
     122        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     123        // aggregate to make sure that MoveNext is called for all enumerators
     124        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
     125          int currentRow = rowsEnumerator.Current;
     126
     127          var selectedEnumerators = from pair in estimatedValuesEnumerators
     128                                    where testPartitions == null || !testPartitions.ContainsKey(pair.Model) ||
     129                                      (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End)
     130                                    select pair.EstimatedValuesEnumerator;
     131
     132          yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
     133        }
     134      }
     135    }
     136
     137    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     138      return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows)
     139             select AggregateEstimatedClassValues(xs);
     140    }
    60141
    61142    public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) {
    62       var estimatedValuesEnumerators = (from model in models
     143      var estimatedValuesEnumerators = (from model in Model.Models
    63144                                        select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())
    64145                                       .ToList();
     
    70151    }
    71152
    72     #endregion
    73 
    74     #region IClassificationModel Members
    75 
    76     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    77       foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) {
    78         // return the class which is most often occuring
    79         yield return
    80           estimatedValuesVector
    81           .GroupBy(x => x)
    82           .OrderBy(g => -g.Count())
    83           .Select(g => g.Key)
    84           .First();
    85       }
     153    private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) {
     154      return estimatedClassValues
     155      .GroupBy(x => x)
     156      .OrderBy(g => -g.Count())
     157      .Select(g => g.Key)
     158      .First();
    86159    }
    87 
    88     #endregion
    89160  }
    90161}
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6238 r6239  
    4949    protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner)
    5050      : base(original, cloner) {
     51      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     52      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     53      foreach (var model in Model.Models) {
     54        trainingPartitions[model] = (IntRange)ProblemData.TrainingPartition.Clone();
     55        testPartitions[model] = (IntRange)ProblemData.TestPartition.Clone();
     56      }
    5157    }
     58
    5259    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    5360      : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) {
     
    141148    private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    142149      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
    143     }
    144 
    145     //[Storable]
    146     //private string name;
    147     //public string Name {
    148     //  get {
    149     //    return name;
    150     //  }
    151     //  set {
    152     //    if (value != null && value != name) {
    153     //      var cancelEventArgs = new CancelEventArgs<string>(value);
    154     //      OnNameChanging(cancelEventArgs);
    155     //      if (cancelEventArgs.Cancel == false) {
    156     //        name = value;
    157     //        OnNamedChanged(EventArgs.Empty);
    158     //      }
    159     //    }
    160     //  }
    161     //}
    162 
    163     //public bool CanChangeName {
    164     //  get { return true; }
    165     //}
    166 
    167     //[Storable]
    168     //private string description;
    169     //public string Description {
    170     //  get {
    171     //    return description;
    172     //  }
    173     //  set {
    174     //    if (value != null && value != description) {
    175     //      description = value;
    176     //      OnDescriptionChanged(EventArgs.Empty);
    177     //    }
    178     //  }
    179     //}
    180 
    181     //public bool CanChangeDescription {
    182     //  get { return true; }
    183     //}
    184 
    185     //#region events
    186     //public event EventHandler<CancelEventArgs<string>> NameChanging;
    187     //private void OnNameChanging(CancelEventArgs<string> cancelEventArgs) {
    188     //  var listener = NameChanging;
    189     //  if (listener != null) listener(this, cancelEventArgs);
    190     //}
    191 
    192     //public event EventHandler NameChanged;
    193     //private void OnNamedChanged(EventArgs e) {
    194     //  var listener = NameChanged;
    195     //  if (listener != null) listener(this, e);
    196     //}
    197 
    198     //public event EventHandler DescriptionChanged;
    199     //private void OnDescriptionChanged(EventArgs e) {
    200     //  var listener = DescriptionChanged;
    201     //  if (listener != null) listener(this, e);
    202     //}
    203     // #endregion
     150    }   
    204151  }
    205152}
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolution.cs

    r6184 r6239  
    2323namespace HeuristicLab.Problems.DataAnalysis {
    2424  public interface IClassificationEnsembleSolution : IClassificationSolution {
    25     IEnumerable<IClassificationModel> Models { get; }
     25    new IClassificationEnsembleModel Model { get; }
    2626    IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows);
    2727  }
Note: See TracChangeset for help on using the changeset viewer.