Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/05/16 18:44:51 (8 years ago)
Author:
gkronber
Message:

#2650: added support for categorical variables to LDA and MNL (TODO: OneR )

Location:
branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/AlglibUtil.cs

    r14237 r14240  
    118118      return matrix;
    119119    }
     120
     121    public static IEnumerable<KeyValuePair<string, IEnumerable<string>>> GetFactorVariableValues(IDataset ds, IEnumerable<string> factorVariables, IEnumerable<int> rows) {
     122      return from factor in factorVariables
     123             let distinctValues = ds.GetStringValues(factor, rows).Distinct().ToArray()
     124             // 1 distinct value => skip (constant)
     125             // 2 distinct values => only take one of the two values
     126             // >=3 distinct values => create a binary value for each value
     127             let reducedValues = distinctValues.Length <= 2
     128               ? distinctValues.Take(distinctValues.Length - 1)
     129               : distinctValues
     130             select new KeyValuePair<string, IEnumerable<string>>(factor, reducedValues);
     131    }
    120132  }
    121133}
  • branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs

    r14185 r14240  
    3636  /// Linear discriminant analysis classification algorithm.
    3737  /// </summary>
    38   [Item("Linear Discriminant Analysis", "Linear discriminant analysis classification algorithm (wrapper for ALGLIB).")]
     38  [Item("Linear Discriminant Analysis (LDA)", "Linear discriminant analysis classification algorithm (wrapper for ALGLIB).")]
    3939  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 100)]
    4040  [StorableClass]
     
    7070      IEnumerable<int> rows = problemData.TrainingIndices;
    7171      int nClasses = problemData.ClassNames.Count();
    72       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
     72      var doubleVariableNames = allowedInputVariables.Where(dataset.VariableHasType<double>).ToArray();
     73      var factorVariableNames = allowedInputVariables.Where(dataset.VariableHasType<string>).ToArray();
     74      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariableNames.Concat(new string[] { targetVariable }), rows);
     75
     76      var factorVariables = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);
     77      double[,] factorMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariables, rows);
     78
     79      inputMatrix = factorMatrix.VertCat(inputMatrix);
     80
    7381      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    7482        throw new NotSupportedException("Linear discriminant analysis does not support NaN or infinity values in the input dataset.");
     
    8290      int info;
    8391      double[] w;
    84       alglib.fisherlda(inputMatrix, inputMatrix.GetLength(0), allowedInputVariables.Count(), nClasses, out info, out w);
     92      alglib.fisherlda(inputMatrix, inputMatrix.GetLength(0), inputMatrix.GetLength(1) - 1, nClasses, out info, out w);
    8593      if (info < 1) throw new ArgumentException("Error in calculation of linear discriminant analysis solution");
    8694
     
    92100
    93101      int col = 0;
    94       foreach (string column in allowedInputVariables) {
     102      foreach (var kvp in factorVariables) {
     103        var varName = kvp.Key;
     104        foreach (var cat in kvp.Value) {
     105          FactorVariableTreeNode vNode =
     106            (FactorVariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.FactorVariable().CreateTreeNode();
     107          vNode.VariableName = varName;
     108          vNode.VariableValue = cat;
     109          vNode.Weight = w[col];
     110          addition.AddSubtree(vNode);
     111          col++;
     112        }
     113      }
     114      foreach (string column in doubleVariableNames) {
    95115        VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
    96116        vNode.VariableName = column;
  • branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearRegression.cs

    r14237 r14240  
    7575      var doubleVariables = allowedInputVariables.Where(dataset.VariableHasType<double>);
    7676      var factorVariableNames = allowedInputVariables.Where(dataset.VariableHasType<string>);
    77       var factorVariables = from factor in factorVariableNames
    78                             let distinctValues = dataset.GetStringValues(factor, rows).Distinct().ToArray()
    79                             // 1 distinct value => skip (constant)
    80                             // 2 distinct values => only take one of the two values
    81                             // >=3 distinct values => create a binary value for each value
    82                             let reducedValues = distinctValues.Length <= 2
    83                               ? distinctValues.Take(distinctValues.Length - 1)
    84                               : distinctValues
    85                             select new KeyValuePair<string, IEnumerable<string>>(factor, reducedValues);
     77      var factorVariables = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);
    8678      double[,] binaryMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariables, rows);
    8779      double[,] doubleVarMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariables.Concat(new string[] { targetVariable }), rows);
  • branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/MultinomialLogitClassification.cs

    r14185 r14240  
    6868      var dataset = problemData.Dataset;
    6969      string targetVariable = problemData.TargetVariable;
    70       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
     70      var doubleVariableNames = problemData.AllowedInputVariables.Where(dataset.VariableHasType<double>);
     71      var factorVariableNames = problemData.AllowedInputVariables.Where(dataset.VariableHasType<string>);
    7172      IEnumerable<int> rows = problemData.TrainingIndices;
    72       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
     73      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariableNames.Concat(new string[] { targetVariable }), rows);
     74
     75      var factorVariableValues = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);
     76      var factorMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariableValues, rows);
     77      inputMatrix = factorMatrix.VertCat(inputMatrix);
     78
    7379      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    7480        throw new NotSupportedException("Multinomial logit classification does not support NaN or infinity values in the input dataset.");
     
    95101      relClassError = alglib.mnlrelclserror(lm, inputMatrix, nRows);
    96102
    97       MultinomialLogitClassificationSolution solution = new MultinomialLogitClassificationSolution(new MultinomialLogitModel(lm, targetVariable, allowedInputVariables, classValues), (IClassificationProblemData)problemData.Clone());
     103      MultinomialLogitClassificationSolution solution = new MultinomialLogitClassificationSolution(new MultinomialLogitModel(lm, targetVariable, doubleVariableNames, factorVariableValues, classValues), (IClassificationProblemData)problemData.Clone());
    98104      return solution;
    99105    }
  • branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/MultinomialLogitModel.cs

    r14185 r14240  
    5656    [Storable]
    5757    private double[] classValues;
     58    [Storable]
     59    private List<KeyValuePair<string, IEnumerable<string>>> factorVariables;
     60
    5861    [StorableConstructor]
    5962    private MultinomialLogitModel(bool deserializing)
     
    6871      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    6972      classValues = (double[])original.classValues.Clone();
     73      this.factorVariables = original.factorVariables.Select(kvp => new KeyValuePair<string, IEnumerable<string>>(kvp.Key, new List<string>(kvp.Value))).ToList();
    7074    }
    71     public MultinomialLogitModel(alglib.logitmodel logitModel, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues)
     75    public MultinomialLogitModel(alglib.logitmodel logitModel, string targetVariable, IEnumerable<string> doubleInputVariables, IEnumerable<KeyValuePair<string, IEnumerable<string>>> factorVariables, double[] classValues)
    7276      : base(targetVariable) {
    7377      this.name = ItemName;
    7478      this.description = ItemDescription;
    7579      this.logitModel = logitModel;
    76       this.allowedInputVariables = allowedInputVariables.ToArray();
     80      this.allowedInputVariables = doubleInputVariables.ToArray();
     81      this.factorVariables = factorVariables.Select(kvp => new KeyValuePair<string, IEnumerable<string>>(kvp.Key, new List<string>(kvp.Value))).ToList();
    7782      this.classValues = (double[])classValues.Clone();
     83    }
     84
     85    [StorableHook(HookType.AfterDeserialization)]
     86    private void AfterDeserialization() {
     87      // BackwardsCompatibility3.3
     88      #region Backwards compatible code, remove with 3.4
     89      factorVariables = new List<KeyValuePair<string, IEnumerable<string>>>();
     90      #endregion
    7891    }
    7992
     
    8396
    8497    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
     98
    8599      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     100      double[,] factorData = AlglibUtil.PrepareInputMatrix(dataset, factorVariables, rows);
     101
     102      inputData = factorData.VertCat(inputData);
    86103
    87104      int n = inputData.GetLength(0);
Note: See TracChangeset for help on using the changeset viewer.