Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/04/17 17:52:44 (8 years ago)
Author:
gkronber
Message:

#2650: merged the factors branch into trunk

Location:
trunk/sources
Files:
16 edited
2 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources

  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis

  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneR.cs

    r14523 r14826  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    6566
    6667    public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6) {
     68      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     69      var model1 = FindBestDoubleVariableModel(problemData, minBucketSize);
     70      var model2 = FindBestFactorModel(problemData);
     71
     72      if (model1 == null && model2 == null) throw new InvalidProgramException("Could not create OneR solution");
     73      else if (model1 == null) return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     74      else if (model2 == null) return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     75      else {
     76        var model1EstimatedValues = model1.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     77        var model1NumCorrect = classValues.Zip(model1EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     78
     79        var model2EstimatedValues = model2.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     80        var model2NumCorrect = classValues.Zip(model2EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     81
     82        if (model1NumCorrect > model2NumCorrect) {
     83          return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
     84        } else {
     85          return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
     86        }
     87      }
     88    }
     89
     90    private static OneRClassificationModel FindBestDoubleVariableModel(IClassificationProblemData problemData, int minBucketSize = 6) {
    6791      var bestClassified = 0;
    6892      List<Split> bestSplits = null;
     
    7195      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    7296
    73       foreach (var variable in problemData.AllowedInputVariables) {
     97      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<double>);
     98
     99      if (!allowedInputVariables.Any()) return null;
     100
     101      foreach (var variable in allowedInputVariables) {
    74102        var inputValues = problemData.Dataset.GetDoubleValues(variable, problemData.TrainingIndices);
    75103        var samples = inputValues.Zip(classValues, (i, v) => new Sample(i, v)).OrderBy(s => s.inputValue);
    76104
    77         var missingValuesDistribution = samples.Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue).ToDictionary(s => s.Key, s => s.Count()).MaxItems(s => s.Value).FirstOrDefault();
     105        var missingValuesDistribution = samples
     106          .Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue)
     107          .ToDictionary(s => s.Key, s => s.Count())
     108          .MaxItems(s => s.Value)
     109          .FirstOrDefault();
    78110
    79111        //calculate class distributions for all distinct inputValues
     
    120152          while (sample.inputValue >= splits[splitIndex].thresholdValue)
    121153            splitIndex++;
    122           correctClassified += sample.classValue == splits[splitIndex].classValue ? 1 : 0;
     154          correctClassified += sample.classValue.IsAlmost(splits[splitIndex].classValue) ? 1 : 0;
    123155        }
    124156        correctClassified += missingValuesDistribution.Value;
     
    134166      //remove neighboring splits with the same class value
    135167      for (int i = 0; i < bestSplits.Count - 1; i++) {
    136         if (bestSplits[i].classValue == bestSplits[i + 1].classValue) {
     168        if (bestSplits[i].classValue.IsAlmost(bestSplits[i + 1].classValue)) {
    137169          bestSplits.Remove(bestSplits[i]);
    138170          i--;
     
    140172      }
    141173
    142       var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable, bestSplits.Select(s => s.thresholdValue).ToArray(), bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
    143       var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    144 
    145       return solution;
     174      var model = new OneRClassificationModel(problemData.TargetVariable, bestVariable,
     175        bestSplits.Select(s => s.thresholdValue).ToArray(),
     176        bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
     177
     178      return model;
     179    }
     180    private static OneFactorClassificationModel FindBestFactorModel(IClassificationProblemData problemData) {
     181      var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     182      var defaultClass = FindMostFrequentClassValue(classValues);
     183      // only select string variables
     184      var allowedInputVariables = problemData.AllowedInputVariables.Where(problemData.Dataset.VariableHasType<string>);
     185
     186      if (!allowedInputVariables.Any()) return null;
     187
     188      OneFactorClassificationModel bestModel = null;
     189      var bestModelNumCorrect = 0;
     190
     191      foreach (var variable in allowedInputVariables) {
     192        var variableValues = problemData.Dataset.GetStringValues(variable, problemData.TrainingIndices);
     193        var groupedClassValues = variableValues
     194          .Zip(classValues, (v, c) => new KeyValuePair<string, double>(v, c))
     195          .GroupBy(kvp => kvp.Key)
     196          .ToDictionary(g => g.Key, g => FindMostFrequentClassValue(g.Select(kvp => kvp.Value)));
     197
     198        var model = new OneFactorClassificationModel(problemData.TargetVariable, variable,
     199          groupedClassValues.Select(kvp => kvp.Key).ToArray(), groupedClassValues.Select(kvp => kvp.Value).ToArray(), defaultClass);
     200
     201        var modelEstimatedValues = model.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
     202        var modelNumCorrect = classValues.Zip(modelEstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);
     203        if (modelNumCorrect > bestModelNumCorrect) {
     204          bestModelNumCorrect = modelNumCorrect;
     205          bestModel = model;
     206        }
     207      }
     208
     209      return bestModel;
     210    }
     211
     212    private static double FindMostFrequentClassValue(IEnumerable<double> classValues) {
     213      return classValues.GroupBy(c => c).OrderByDescending(g => g.Count()).Select(g => g.Key).First();
    146214    }
    147215
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneRClassificationModel.cs

    r14185 r14826  
    3131  [StorableClass]
    3232  [Item("OneR Classification Model", "A model that uses intervals for one variable to determine the class.")]
    33   public class OneRClassificationModel : ClassificationModel {
     33  public sealed class OneRClassificationModel : ClassificationModel {
    3434    public override IEnumerable<string> VariablesUsedForPrediction {
    3535      get { return new[] { Variable }; }
     
    3737
    3838    [Storable]
    39     protected string variable;
     39    private string variable;
    4040    public string Variable {
    4141      get { return variable; }
     
    4343
    4444    [Storable]
    45     protected double[] splits;
     45    private double[] splits;
    4646    public double[] Splits {
    4747      get { return splits; }
     
    4949
    5050    [Storable]
    51     protected double[] classes;
     51    private double[] classes;
    5252    public double[] Classes {
    5353      get { return classes; }
     
    5555
    5656    [Storable]
    57     protected double missingValuesClass;
     57    private double missingValuesClass;
    5858    public double MissingValuesClass {
    5959      get { return missingValuesClass; }
     
    6161
    6262    [StorableConstructor]
    63     protected OneRClassificationModel(bool deserializing) : base(deserializing) { }
    64     protected OneRClassificationModel(OneRClassificationModel original, Cloner cloner)
     63    private OneRClassificationModel(bool deserializing) : base(deserializing) { }
     64    private OneRClassificationModel(OneRClassificationModel original, Cloner cloner)
    6565      : base(original, cloner) {
    6666      this.variable = (string)original.variable;
    6767      this.splits = (double[])original.splits.Clone();
    6868      this.classes = (double[])original.classes.Clone();
     69      this.missingValuesClass = original.missingValuesClass;
    6970    }
    7071    public override IDeepCloneable Clone(Cloner cloner) { return new OneRClassificationModel(this, cloner); }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/BaselineClassifiers/OneRClassificationSolution.cs

    r14185 r14826  
    2828  [StorableClass]
    2929  [Item(Name = "OneR Classification Solution", Description = "Represents a OneR classification solution which uses only a single feature with potentially multiple thresholds for class prediction.")]
    30   public class OneRClassificationSolution : ClassificationSolution {
     30  public sealed class OneRClassificationSolution : ClassificationSolution {
    3131    public new OneRClassificationModel Model {
    3232      get { return (OneRClassificationModel)base.Model; }
     
    3535
    3636    [StorableConstructor]
    37     protected OneRClassificationSolution(bool deserializing) : base(deserializing) { }
    38     protected OneRClassificationSolution(OneRClassificationSolution original, Cloner cloner) : base(original, cloner) { }
     37    private OneRClassificationSolution(bool deserializing) : base(deserializing) { }
     38    private OneRClassificationSolution(OneRClassificationSolution original, Cloner cloner) : base(original, cloner) { }
    3939    public OneRClassificationSolution(OneRClassificationModel model, IClassificationProblemData problemData)
    4040      : base(model, problemData) {
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassificationModelCreator.cs

    r14185 r14826  
    6767        HyperparameterGradientsParameter.ActualValue = new RealVector(model.HyperparameterGradients);
    6868        return base.Apply();
    69       } catch (ArgumentException) { } catch (alglib.alglibexception) { }
     69      } catch (ArgumentException) {
     70      } catch (alglib.alglibexception) {
     71      }
    7072      NegativeLogLikelihoodParameter.ActualValue = new DoubleValue(1E300);
    7173      HyperparameterGradientsParameter.ActualValue = new RealVector(Hyperparameter.Count());
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r14185 r14826  
    148148    // for custom stepping & termination
    149149    public static IGbmState CreateGbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxSize = 3, double r = 0.66, double m = 0.5, double nu = 0.01) {
     150      // check input variables. Only double variables are allowed.
     151      var invalidInputs =
     152        problemData.AllowedInputVariables.Where(name => !problemData.Dataset.VariableHasType<double>(name));
     153      if (invalidInputs.Any())
     154        throw new NotSupportedException("Gradient tree boosting only supports real-valued variables. Unsupported inputs: " + string.Join(", ", invalidInputs));
     155
    150156      return new GbmState(problemData, lossFunction, randSeed, maxSize, r, m, nu);
    151157    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r14400 r14826  
    122122  </ItemGroup>
    123123  <ItemGroup>
     124    <Compile Include="BaselineClassifiers\OneFactorClassificationModel.cs" />
     125    <Compile Include="BaselineClassifiers\OneFactorClassificationSolution.cs" />
    124126    <Compile Include="BaselineClassifiers\OneR.cs" />
    125127    <Compile Include="BaselineClassifiers\OneRClassificationModel.cs" />
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/AlglibUtil.cs

    r14400 r14826  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    2728  public static class AlglibUtil {
    2829    public static double[,] PrepareInputMatrix(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows) {
    29       List<string> variablesList = variables.ToList();
     30      // check input variables. Only double variables are allowed.
     31      var invalidInputs =
     32        variables.Where(name => !dataset.VariableHasType<double>(name));
     33      if (invalidInputs.Any())
     34        throw new NotSupportedException("Unsupported inputs: " + string.Join(", ", invalidInputs));
     35
    3036      List<int> rowsList = rows.ToList();
    31 
    32       double[,] matrix = new double[rowsList.Count, variablesList.Count];
     37      double[,] matrix = new double[rowsList.Count, variables.Count()];
    3338
    3439      int col = 0;
     
    4550      return matrix;
    4651    }
     52
    4753    public static double[,] PrepareAndScaleInputMatrix(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, Scaling scaling) {
     54      // check input variables. Only double variables are allowed.
     55      var invalidInputs =
     56        variables.Where(name => !dataset.VariableHasType<double>(name));
     57      if (invalidInputs.Any())
     58        throw new NotSupportedException("Unsupported inputs: " + string.Join(", ", invalidInputs));
     59
    4860      List<string> variablesList = variables.ToList();
    4961      List<int> rowsList = rows.ToList();
     
    6476      return matrix;
    6577    }
     78
     79    /// <summary>
     80    /// Prepares a binary data matrix from a number of factors and specified factor values
     81    /// </summary>
     82    /// <param name="dataset">A dataset that contains the variable values</param>
     83    /// <param name="factorVariables">An enumerable of categorical variables (factors). For each variable an enumerable of values must be specified.</param>
     84    /// <param name="rows">An enumerable of row indices for the dataset</param>
     85    /// <returns></returns>
     86    /// <remarks>Factor variables (categorical variables) are split up into multiple binary variables one for each specified value.</remarks>
     87    public static double[,] PrepareInputMatrix(
     88      IDataset dataset,
     89      IEnumerable<KeyValuePair<string, IEnumerable<string>>> factorVariables,
     90      IEnumerable<int> rows) {
     91      // check input variables. Only string variables are allowed.
     92      var invalidInputs =
     93        factorVariables.Select(kvp => kvp.Key).Where(name => !dataset.VariableHasType<string>(name));
     94      if (invalidInputs.Any())
     95        throw new NotSupportedException("Unsupported inputs: " + string.Join(", ", invalidInputs));
     96
     97      int numBinaryColumns = factorVariables.Sum(kvp => kvp.Value.Count());
     98
     99      List<int> rowsList = rows.ToList();
     100      double[,] matrix = new double[rowsList.Count, numBinaryColumns];
     101
     102      int col = 0;
     103      foreach (var kvp in factorVariables) {
     104        var varName = kvp.Key;
     105        var cats = kvp.Value;
     106        if (!cats.Any()) continue;
     107        foreach (var cat in cats) {
     108          var values = dataset.GetStringValues(varName, rows);
     109          int row = 0;
     110          foreach (var value in values) {
     111            matrix[row, col] = value == cat ? 1 : 0;
     112            row++;
     113          }
     114          col++;
     115        }
     116      }
     117      return matrix;
     118    }
     119
     120    public static IEnumerable<KeyValuePair<string, IEnumerable<string>>> GetFactorVariableValues(IDataset ds, IEnumerable<string> factorVariables, IEnumerable<int> rows) {
     121      return from factor in factorVariables
     122             let distinctValues = ds.GetStringValues(factor, rows).Distinct().ToArray()
     123             // 1 distinct value => skip (constant)
     124             // 2 distinct values => only take one of the two values
     125             // >=3 distinct values => create a binary value for each value
     126             let reducedValues = distinctValues.Length <= 2
     127               ? distinctValues.Take(distinctValues.Length - 1)
     128               : distinctValues
     129             select new KeyValuePair<string, IEnumerable<string>>(factor, reducedValues);
     130    }
    66131  }
    67132}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs

    r14685 r14826  
    3737  /// Linear discriminant analysis classification algorithm.
    3838  /// </summary>
    39   [Item("Linear Discriminant Analysis", "Linear discriminant analysis classification algorithm (wrapper for ALGLIB).")]
     39  [Item("Linear Discriminant Analysis (LDA)", "Linear discriminant analysis classification algorithm (wrapper for ALGLIB).")]
    4040  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 100)]
    4141  [StorableClass]
     
    7171      IEnumerable<int> rows = problemData.TrainingIndices;
    7272      int nClasses = problemData.ClassNames.Count();
    73       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
     73      var doubleVariableNames = allowedInputVariables.Where(dataset.VariableHasType<double>).ToArray();
     74      var factorVariableNames = allowedInputVariables.Where(dataset.VariableHasType<string>).ToArray();
     75      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariableNames.Concat(new string[] { targetVariable }), rows);
     76
     77      var factorVariables = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);
     78      double[,] factorMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariables, rows);
     79
     80      inputMatrix = factorMatrix.HorzCat(inputMatrix);
     81
    7482      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    7583        throw new NotSupportedException("Linear discriminant analysis does not support NaN or infinity values in the input dataset.");
     
    8391      int info;
    8492      double[] w;
    85       alglib.fisherlda(inputMatrix, inputMatrix.GetLength(0), allowedInputVariables.Count(), nClasses, out info, out w);
     93      alglib.fisherlda(inputMatrix, inputMatrix.GetLength(0), inputMatrix.GetLength(1) - 1, nClasses, out info, out w);
    8694      if (info < 1) throw new ArgumentException("Error in calculation of linear discriminant analysis solution");
    8795
     
    93101
    94102      int col = 0;
    95       foreach (string column in allowedInputVariables) {
     103      foreach (var kvp in factorVariables) {
     104        var varName = kvp.Key;
     105        foreach (var cat in kvp.Value) {
     106          BinaryFactorVariableTreeNode vNode =
     107            (BinaryFactorVariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.BinaryFactorVariable().CreateTreeNode();
     108          vNode.VariableName = varName;
     109          vNode.VariableValue = cat;
     110          vNode.Weight = w[col];
     111          addition.AddSubtree(vNode);
     112          col++;
     113        }
     114      }
     115      foreach (string column in doubleVariableNames) {
    96116        VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
    97117        vNode.VariableName = column;
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearRegression.cs

    r14685 r14826  
    7474      IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    7575      IEnumerable<int> rows = problemData.TrainingIndices;
    76       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
     76      var doubleVariables = allowedInputVariables.Where(dataset.VariableHasType<double>);
     77      var factorVariableNames = allowedInputVariables.Where(dataset.VariableHasType<string>);
     78      var factorVariables = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);
     79      double[,] binaryMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariables, rows);
     80      double[,] doubleVarMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariables.Concat(new string[] { targetVariable }), rows);
     81      var inputMatrix = binaryMatrix.HorzCat(doubleVarMatrix);
     82
    7783      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    7884        throw new NotSupportedException("Linear regression does not support NaN or infinity values in the input dataset.");
     
    99105
    100106      int col = 0;
    101       foreach (string column in allowedInputVariables) {
     107      foreach (var kvp in factorVariables) {
     108        var varName = kvp.Key;
     109        foreach (var cat in kvp.Value) {
     110          BinaryFactorVariableTreeNode vNode =
     111            (BinaryFactorVariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.BinaryFactorVariable().CreateTreeNode();
     112          vNode.VariableName = varName;
     113          vNode.VariableValue = cat;
     114          vNode.Weight = coefficients[col];
     115          addition.AddSubtree(vNode);
     116          col++;
     117        }
     118      }
     119      foreach (string column in doubleVariables) {
    102120        VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
    103121        vNode.VariableName = column;
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/MultinomialLogitClassification.cs

    r14523 r14826  
    6969      var dataset = problemData.Dataset;
    7070      string targetVariable = problemData.TargetVariable;
    71       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
     71      var doubleVariableNames = problemData.AllowedInputVariables.Where(dataset.VariableHasType<double>);
     72      var factorVariableNames = problemData.AllowedInputVariables.Where(dataset.VariableHasType<string>);
    7273      IEnumerable<int> rows = problemData.TrainingIndices;
    73       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
     74      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariableNames.Concat(new string[] { targetVariable }), rows);
     75
     76      var factorVariableValues = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);
     77      var factorMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariableValues, rows);
     78      inputMatrix = factorMatrix.HorzCat(inputMatrix);
     79
    7480      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    7581        throw new NotSupportedException("Multinomial logit classification does not support NaN or infinity values in the input dataset.");
     
    96102      relClassError = alglib.mnlrelclserror(lm, inputMatrix, nRows);
    97103
    98       MultinomialLogitClassificationSolution solution = new MultinomialLogitClassificationSolution(new MultinomialLogitModel(lm, targetVariable, allowedInputVariables, classValues), (IClassificationProblemData)problemData.Clone());
     104      MultinomialLogitClassificationSolution solution = new MultinomialLogitClassificationSolution(new MultinomialLogitModel(lm, targetVariable, doubleVariableNames, factorVariableValues, classValues), (IClassificationProblemData)problemData.Clone());
    99105      return solution;
    100106    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/MultinomialLogitClassificationSolution.cs

    r14185 r14826  
    4343      : base(original, cloner) {
    4444    }
    45     public MultinomialLogitClassificationSolution( MultinomialLogitModel logitModel,IClassificationProblemData problemData)
     45    public MultinomialLogitClassificationSolution(MultinomialLogitModel logitModel, IClassificationProblemData problemData)
    4646      : base(logitModel, problemData) {
    4747    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/MultinomialLogitModel.cs

    r14400 r14826  
    5656    [Storable]
    5757    private double[] classValues;
     58    [Storable]
     59    private List<KeyValuePair<string, IEnumerable<string>>> factorVariables;
     60
    5861    [StorableConstructor]
    5962    private MultinomialLogitModel(bool deserializing)
     
    6871      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    6972      classValues = (double[])original.classValues.Clone();
     73      this.factorVariables = original.factorVariables.Select(kvp => new KeyValuePair<string, IEnumerable<string>>(kvp.Key, new List<string>(kvp.Value))).ToList();
    7074    }
    71     public MultinomialLogitModel(alglib.logitmodel logitModel, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues)
     75    public MultinomialLogitModel(alglib.logitmodel logitModel, string targetVariable, IEnumerable<string> doubleInputVariables, IEnumerable<KeyValuePair<string, IEnumerable<string>>> factorVariables, double[] classValues)
    7276      : base(targetVariable) {
    7377      this.name = ItemName;
    7478      this.description = ItemDescription;
    7579      this.logitModel = logitModel;
    76       this.allowedInputVariables = allowedInputVariables.ToArray();
     80      this.allowedInputVariables = doubleInputVariables.ToArray();
     81      this.factorVariables = factorVariables.Select(kvp => new KeyValuePair<string, IEnumerable<string>>(kvp.Key, new List<string>(kvp.Value))).ToList();
    7782      this.classValues = (double[])classValues.Clone();
     83    }
     84
     85    [StorableHook(HookType.AfterDeserialization)]
     86    private void AfterDeserialization() {
     87      // BackwardsCompatibility3.3
     88      #region Backwards compatible code, remove with 3.4
     89      factorVariables = new List<KeyValuePair<string, IEnumerable<string>>>();
     90      #endregion
    7891    }
    7992
     
    8396
    8497    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
     98
    8599      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     100      double[,] factorData = AlglibUtil.PrepareInputMatrix(dataset, factorVariables, rows);
     101
     102      inputData = factorData.HorzCat(inputData);
    86103
    87104      int n = inputData.GetLength(0);
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs

    r14400 r14826  
    144144      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    145145        throw new NotSupportedException(
    146           "Nearest neighbour classification does not support NaN or infinity values in the input dataset.");
     146          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");
    147147
    148148      this.kdTree = new alglib.nearestneighbor.kdtree();
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs

    r14523 r14826  
    2121
    2222using System;
     23using System.Collections.Generic;
    2324using System.Linq;
    2425using System.Threading;
     
    208209      var parser = new InfixExpressionParser();
    209210      var tree = parser.Parse(modelStructure);
     211      // parser handles double and string variables equally by creating a VariableTreeNode
     212      // post-process to replace VariableTreeNodes by FactorVariableTreeNodes for all string variables
     213      var factorSymbol = new FactorVariable();
     214      factorSymbol.VariableNames =
     215        problemData.AllowedInputVariables.Where(name => problemData.Dataset.VariableHasType<string>(name));
     216      factorSymbol.AllVariableNames = factorSymbol.VariableNames;
     217      factorSymbol.VariableValues =
     218        factorSymbol.VariableNames.Select(name =>
     219        new KeyValuePair<string, Dictionary<string, int>>(name,
     220        problemData.Dataset.GetReadOnlyStringValues(name).Distinct()
     221        .Select((n, i) => Tuple.Create(n, i))
     222        .ToDictionary(tup => tup.Item1, tup => tup.Item2)));
     223
     224      foreach (var parent in tree.IterateNodesPrefix().ToArray()) {
     225        for (int i = 0; i < parent.SubtreeCount; i++) {
     226          var varChild = parent.GetSubtree(i) as VariableTreeNode;
     227          var factorVarChild = parent.GetSubtree(i) as FactorVariableTreeNode;
     228          if (varChild != null && factorSymbol.VariableNames.Contains(varChild.VariableName)) {
     229            parent.RemoveSubtree(i);
     230            var factorTreeNode = (FactorVariableTreeNode)factorSymbol.CreateTreeNode();
     231            factorTreeNode.VariableName = varChild.VariableName;
     232            factorTreeNode.Weights =
     233              factorTreeNode.Symbol.GetVariableValues(factorTreeNode.VariableName).Select(_ => 1.0).ToArray();
     234            // weight = 1.0 for each value
     235            parent.InsertSubtree(i, factorTreeNode);
     236          } else if (factorVarChild != null && factorSymbol.VariableNames.Contains(factorVarChild.VariableName)) {
     237            if (factorSymbol.GetVariableValues(factorVarChild.VariableName).Count() != factorVarChild.Weights.Length)
     238              throw new ArgumentException(
     239                string.Format("Factor variable {0} needs exactly {1} weights",
     240                factorVarChild.VariableName,
     241                factorSymbol.GetVariableValues(factorVarChild.VariableName).Count()));
     242            parent.RemoveSubtree(i);
     243            var factorTreeNode = (FactorVariableTreeNode)factorSymbol.CreateTreeNode();
     244            factorTreeNode.VariableName = factorVarChild.VariableName;
     245            factorTreeNode.Weights = factorVarChild.Weights;
     246            parent.InsertSubtree(i, factorTreeNode);
     247          }
     248        }
     249      }
    210250
    211251      if (!SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree)) throw new ArgumentException("The optimizer does not support the specified model structure.");
Note: See TracChangeset for help on using the changeset viewer.