Ignore:
Timestamp:
04/10/17 15:48:20 (3 years ago)
Author:
gkronber
Message:

#2700 merged changesets from trunk to branch

Location:
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis

  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/MultinomialLogitClassification.cs

    r14185 r14836  
    2323using System.Collections.Generic;
    2424using System.Linq;
     25using System.Threading;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
     
    5758
    5859    #region logit classification
    59     protected override void Run() {
     60    protected override void Run(CancellationToken cancellationToken) {
    6061      double rmsError, relClassError;
    6162      var solution = CreateLogitClassificationSolution(Problem.ProblemData, out rmsError, out relClassError);
     
    6869      var dataset = problemData.Dataset;
    6970      string targetVariable = problemData.TargetVariable;
    70       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
     71      var doubleVariableNames = problemData.AllowedInputVariables.Where(dataset.VariableHasType<double>);
     72      var factorVariableNames = problemData.AllowedInputVariables.Where(dataset.VariableHasType<string>);
    7173      IEnumerable<int> rows = problemData.TrainingIndices;
    72       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
     74      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariableNames.Concat(new string[] { targetVariable }), rows);
     75
     76      var factorVariableValues = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);
     77      var factorMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariableValues, rows);
     78      inputMatrix = factorMatrix.HorzCat(inputMatrix);
     79
    7380      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    7481        throw new NotSupportedException("Multinomial logit classification does not support NaN or infinity values in the input dataset.");
     
    95102      relClassError = alglib.mnlrelclserror(lm, inputMatrix, nRows);
    96103
    97       MultinomialLogitClassificationSolution solution = new MultinomialLogitClassificationSolution(new MultinomialLogitModel(lm, targetVariable, allowedInputVariables, classValues), (IClassificationProblemData)problemData.Clone());
     104      MultinomialLogitClassificationSolution solution = new MultinomialLogitClassificationSolution(new MultinomialLogitModel(lm, targetVariable, doubleVariableNames, factorVariableValues, classValues), (IClassificationProblemData)problemData.Clone());
    98105      return solution;
    99106    }
Note: See TracChangeset for help on using the changeset viewer.