Changeset 15023


Ignore:
Timestamp:
06/03/17 19:19:18 (4 months ago)
Author:
gkronber
Message:

#745: added support for factor variables to elastic net regression

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/ElasticNetLinearRegression.cs

    r14846 r15023  
    2121
    2222using System;
     23using System.Collections.Generic;
    2324using System.Linq;
    2425using System.Threading;
     
    9495      Results.Add(new Result("NMSE (test)", new DoubleValue(testNMSE)));
    9596
     97      var ds = Problem.ProblemData.Dataset;
    9698      var allVariables = Problem.ProblemData.AllowedInputVariables.ToArray();
    97 
    98       var remainingVars = Enumerable.Range(0, allVariables.Length)
    99         .Where(idx => !coeff[idx].IsAlmost(0.0)).Select(idx => allVariables[idx])
    100         .ToArray();
    101       var remainingCoeff = Enumerable.Range(0, allVariables.Length)
    102         .Select(idx => coeff[idx])
    103         .Where(c => !c.IsAlmost(0.0))
    104         .ToArray();
    105 
    106       var tree = LinearModelToTreeConverter.CreateTree(remainingVars, remainingCoeff, coeff.Last());
     99      var doubleVariables = allVariables.Where(ds.VariableHasType<double>);
     100      var factorVariableNames = allVariables.Where(ds.VariableHasType<string>);
     101      var factorVariablesAndValues = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)
     102
     103      List<KeyValuePair<string, IEnumerable<string>>> remainingFactorVariablesAndValues = new List<KeyValuePair<string, IEnumerable<string>>>();
     104      List<double> factorCoeff = new List<double>();
     105      List<string> remainingDoubleVariables = new List<string>();
     106      List<double> doubleVarCoeff = new List<double>();
     107
     108      {
     109        int i = 0;
     110        // find factor variables & value combinations with non-zero coeff
     111        foreach (var factorVarAndValues in factorVariablesAndValues) {
     112          var l = new List<string>();
     113          foreach (var factorValue in factorVarAndValues.Value) {
     114            if (!coeff[i].IsAlmost(0.0)) {
     115              l.Add(factorValue);
     116              factorCoeff.Add(coeff[i]);
     117            }
     118            i++;
     119          }
     120          if (l.Any()) remainingFactorVariablesAndValues.Add(new KeyValuePair<string, IEnumerable<string>>(factorVarAndValues.Key, l));
     121        }
     122        // find double variables with non-zero coeff
     123        foreach (var doubleVar in doubleVariables) {
     124          if (!coeff[i].IsAlmost(0.0)) {
     125            remainingDoubleVariables.Add(doubleVar);
     126            doubleVarCoeff.Add(coeff[i]);
     127          }
     128          i++;
     129        }
     130      }
     131      var tree = LinearModelToTreeConverter.CreateTree(
     132        remainingFactorVariablesAndValues, factorCoeff.ToArray(),
     133        remainingDoubleVariables.ToArray(), doubleVarCoeff.ToArray(),
     134        coeff.Last());
    107135
    108136
     
    140168      var allowedVars = Problem.ProblemData.AllowedInputVariables.ToArray();
    141169      var numNonZeroCoeffs = new int[nLambdas];
    142       for (int i = 0; i < nCoeff; i++) {
    143         var coeffId = allowedVars[i];
    144         double sigma = Problem.ProblemData.Dataset.GetDoubleValues(coeffId).StandardDeviation();
    145         var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
    146         dataRows[i] = new IndexedDataRow<double>(coeffId, coeffId, path);
    147       }
    148       // add to coeffTable by total weight (larger area under the curve => more important);
    149       foreach (var r in dataRows.OrderByDescending(r => r.Values.Select(t => t.Item2).Sum(x => Math.Abs(x)))) {
    150         coeffTable.Rows.Add(r);
     170
     171      var ds = Problem.ProblemData.Dataset;
     172      var doubleVariables = allowedVars.Where(ds.VariableHasType<double>);
     173      var factorVariableNames = allowedVars.Where(ds.VariableHasType<string>);
     174      var factorVariablesAndValues = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)
     175      {
     176        int i = 0;
     177        foreach (var factorVariableAndValues in factorVariablesAndValues) {
     178          foreach (var factorValue in factorVariableAndValues.Value) {
     179            double sigma = ds.GetStringValues(factorVariableAndValues.Key)
     180              .Select(s => s == factorValue ? 1.0 : 0.0)
     181              .StandardDeviation(); // calc std dev of binary indicator
     182            var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
     183            dataRows[i] = new IndexedDataRow<double>(factorVariableAndValues.Key + "=" + factorValue, factorVariableAndValues.Key + "=" + factorValue, path);
     184            i++;
     185          }
     186        }
     187
     188        foreach (var doubleVariable in doubleVariables) {
     189          double sigma = ds.GetDoubleValues(doubleVariable).StandardDeviation();
     190          var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
     191          dataRows[i] = new IndexedDataRow<double>(doubleVariable, doubleVariable, path);
     192          i++;
     193        }
     194        // add to coeffTable by total weight (larger area under the curve => more important);
     195        foreach (var r in dataRows.OrderByDescending(r => r.Values.Select(t => t.Item2).Sum(x => Math.Abs(x)))) {
     196          coeffTable.Rows.Add(r);
     197        }
    151198      }
    152199
     
    330377    private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY,
    331378      out double[,] testX, out double[] testY) {
    332 
    333379      var ds = problemData.Dataset;
    334       trainX = ds.ToArray(problemData.AllowedInputVariables, problemData.TrainingIndices);
    335       trainX = trainX.Transpose();
    336       trainY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable,
    337         problemData.TrainingIndices)
    338         .ToArray();
    339       testX = ds.ToArray(problemData.AllowedInputVariables, problemData.TestIndices);
    340       testX = testX.Transpose();
    341       testY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable,
    342         problemData.TestIndices)
    343         .ToArray();
     380      var targetVariable = problemData.TargetVariable;
     381      var allowedInputs = problemData.AllowedInputVariables;
     382      trainX = PrepareInputData(ds, allowedInputs, problemData.TrainingIndices);
     383      trainY = ds.GetDoubleValues(targetVariable, problemData.TrainingIndices).ToArray();
     384
     385      testX = PrepareInputData(ds, allowedInputs, problemData.TestIndices);
     386      testY = ds.GetDoubleValues(targetVariable, problemData.TestIndices).ToArray();
     387    }
     388
     389    private static double[,] PrepareInputData(IDataset ds, IEnumerable<string> allowedInputs, IEnumerable<int> rows) {
     390      var doubleVariables = allowedInputs.Where(ds.VariableHasType<double>);
     391      var factorVariableNames = allowedInputs.Where(ds.VariableHasType<string>);
     392      var factorVariables = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)
     393      double[,] binaryMatrix = ds.ToArray(factorVariables, rows);
     394      double[,] doubleVarMatrix = ds.ToArray(doubleVariables, rows);
     395      var x = binaryMatrix.HorzCat(doubleVarMatrix);
     396      return x.Transpose();
    344397    }
    345398  }
Note: See TracChangeset for help on using the changeset viewer.