Changeset 17835


Ignore:
Timestamp:
02/12/21 16:26:36 (2 weeks ago)
Author:
bburlacu
Message:

#3102: Add ClassificationProblemData constructor that explicitly takes class names and positive class value arguments, adapt code.

Location:
trunk
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs

    r17824 r17835  
    521521      foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) {
    522522        // at least one algorithm (GBT with logistic regression loss) produces a classification solution even though the original problem is a regression problem.
    523         var targetVariable = solutions.Value.First().ProblemData.TargetVariable;
    524523        var dataset = (Dataset)Problem.ProblemData.Dataset;
    525524        if (ShuffleSamples.Value) {
     
    527526          dataset = dataset.Shuffle(random);
    528527        }
    529         var problemDataClone = new ClassificationProblemData(dataset, Problem.ProblemData.AllowedInputVariables, targetVariable);
     528        var problemData = (IClassificationProblemData)Problem.ProblemData;
     529        var problemDataClone = new ClassificationProblemData(dataset, problemData.AllowedInputVariables, problemData.TargetVariable, problemData.ClassNames, problemData.PositiveClass);
    530530        // set partitions of problem data clone correctly
    531531        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
  • trunk/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification.Views/3.4/SolutionComparisonView.cs

    r17421 r17835  
    9090
    9191      var newDs = new Dataset(variableNames, variableValues);
    92       var newProblemData = new ClassificationProblemData(newDs, variableNames.Take(variableNames.Length - 1), variableNames.Last());
    93 
    94       foreach (var classValue in problemData.ClassValues) {
    95         newProblemData.SetClassName(classValue, problemData.GetClassName(classValue));
    96       }
    97       newProblemData.PositiveClass = problemData.PositiveClass;
     92      var newProblemData = new ClassificationProblemData(newDs, variableNames.Take(variableNames.Length - 1), variableNames.Last(), problemData.ClassNames, problemData.PositiveClass);
    9893
    9994      newProblemData.TrainingPartition.Start = problemData.TrainingPartition.Start;
  • trunk/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleProblemData.cs

    r17180 r17835  
    7373
    7474    public ClassificationEnsembleProblemData() : base() { }
     75
    7576    public ClassificationEnsembleProblemData(IClassificationProblemData classificationProblemData)
    76       : base(classificationProblemData.Dataset, classificationProblemData.AllowedInputVariables, classificationProblemData.TargetVariable) {
    77       this.TrainingPartition.Start = classificationProblemData.TrainingPartition.Start;
    78       this.TrainingPartition.End = classificationProblemData.TrainingPartition.End;
    79       this.TestPartition.Start = classificationProblemData.TestPartition.Start;
    80       this.TestPartition.End = classificationProblemData.TestPartition.End;
    81       this.PositiveClass = classificationProblemData.PositiveClass;
     77      : base(classificationProblemData) {
    8278    }
    8379
     
    8581      : base(dataset, allowedInputVariables, targetVariable) {
    8682    }
     83
     84    public ClassificationEnsembleProblemData(Dataset dataset, IEnumerable<string> allowedInputVariables, string targetVariable, IEnumerable<string> classNames, string positiveClass = null)
     85      : base(dataset, allowedInputVariables, targetVariable, classNames, positiveClass) {
     86    }
    8787  }
    8888}
  • trunk/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r17180 r17835  
    260260      evaluationCache.Clear();
    261261
    262       IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset,
    263                                                                      ProblemData.AllowedInputVariables,
    264                                                                      ProblemData.TargetVariable);
    265       problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
    266       problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
    267       problemData.TestPartition.Start = ProblemData.TestPartition.Start;
    268       problemData.TestPartition.End = ProblemData.TestPartition.End;
     262      IClassificationProblemData problemData = new ClassificationProblemData(ProblemData);
    269263
    270264      foreach (var solution in ClassificationSolutions) {
  • trunk/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs

    r17180 r17835  
    311311    }
    312312
    313     public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) { }
     313    public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable, Enumerable.Empty<string>()) { }
    314314
    315315    public ClassificationProblemData(IClassificationProblemData classificationProblemData)
    316       : this(classificationProblemData.Dataset, classificationProblemData.AllowedInputVariables, classificationProblemData.TargetVariable) {
     316      : this(classificationProblemData.Dataset, classificationProblemData.AllowedInputVariables, classificationProblemData.TargetVariable, classificationProblemData.ClassNames, classificationProblemData.PositiveClass) {
     317     
    317318      TrainingPartition.Start = classificationProblemData.TrainingPartition.Start;
    318319      TrainingPartition.End = classificationProblemData.TrainingPartition.End;
    319320      TestPartition.Start = classificationProblemData.TestPartition.Start;
    320321      TestPartition.End = classificationProblemData.TestPartition.End;
    321 
    322       for (int i = 0; i < classificationProblemData.ClassNames.Count(); i++)
    323         ClassNamesParameter.Value[i, 0] = classificationProblemData.ClassNames.ElementAt(i);
    324 
    325       //mkommend: The positive class depends on the class names and as a result must only be set after the classe names parameter.
    326       PositiveClass = classificationProblemData.PositiveClass;
    327 
     322     
    328323      for (int i = 0; i < Classes; i++) {
    329324        for (int j = 0; j < Classes; j++) {
     
    334329
    335330    public ClassificationProblemData(IDataset dataset, IEnumerable<string> allowedInputVariables, string targetVariable, IEnumerable<ITransformation> transformations = null)
     331      : this(dataset, allowedInputVariables, targetVariable, Enumerable.Empty<string>(), null, transformations) { }
     332
     333    public ClassificationProblemData(IDataset dataset, IEnumerable<string> allowedInputVariables, string targetVariable,
     334      IEnumerable<string> classNames,
     335      string positiveClass = null, // can be null in which case it's set as the first class name
     336      IEnumerable<ITransformation> transformations = null)
    336337      : base(dataset, allowedInputVariables, transformations ?? Enumerable.Empty<ITransformation>()) {
    337338      var validTargetVariableValues = CheckVariablesForPossibleTargetVariables(dataset).Select(x => new StringValue(x).AsReadOnly()).ToList();
     
    339340
    340341      Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>(validTargetVariableValues), target));
    341       Parameters.Add(new FixedValueParameter<StringMatrix>(ClassNamesParameterName, ""));
     342      Parameters.Add(new FixedValueParameter<StringMatrix>(ClassNamesParameterName, "", new StringMatrix()));
    342343      Parameters.Add(new ConstrainedValueParameter<StringValue>(PositiveClassParameterName, "The positive class which is used for quality measure calculation (e.g., specifity, sensitivity,...)"));
    343344      Parameters.Add(new FixedValueParameter<DoubleMatrix>(ClassificationPenaltiesParameterName, ""));
    344345
    345346      RegisterParameterEvents();
    346       ResetTargetVariableDependentMembers();
     347      ResetTargetVariableDependentMembers(); // correctly set the values of the parameters added above
     348
     349      // set the class names
     350      if (classNames.Any()) {
     351        // better to allocate lists because we use these multiple times below
     352        var names = classNames.ToList();
     353        var values = ClassValues.ToList();
     354
     355        if (names.Count != values.Count) {
     356          throw new ArgumentException();
     357        }
     358
     359        ((IStringConvertibleMatrix)ClassNamesParameter.Value).Columns = 1;
     360        ((IStringConvertibleMatrix)ClassNamesParameter.Value).Rows = names.Count;
     361
     362        for (int i = 0; i < names.Count; ++i) {
     363          SetClassName(values[i], names[i]);
     364        }
     365      }
     366
     367      // set the positive class value
     368      if (positiveClass != null) {
     369        PositiveClass = positiveClass;
     370      }
    347371    }
    348372
Note: See TracChangeset for help on using the changeset viewer.