Changeset 13940


Ignore:
Timestamp:
06/28/16 12:53:54 (5 years ago)
Author:
gkronber
Message:

#745:

  • added scatterplot of R² values over lambda instead of line chart,
  • normalized coefficient values in coefficient path chart
  • changed parameter lambda to LogLambda
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/ElasticNetLinearRegression.cs

    r13930 r13940  
    2121  public sealed class ElasticNetLinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    2222    private const string PenalityParameterName = "Penality";
    23     private const string LambdaParameterName = "Lambda";
     23    private const string LogLambdaParameterName = "Log10(Lambda)";
    2424    #region parameters
    2525    public IFixedValueParameter<DoubleValue> PenalityParameter {
    2626      get { return (IFixedValueParameter<DoubleValue>)Parameters[PenalityParameterName]; }
    2727    }
    28     public IValueParameter<DoubleValue> LambdaParameter {
    29       get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
     28    public IValueParameter<DoubleValue> LogLambdaParameter {
     29      get { return (IValueParameter<DoubleValue>)Parameters[LogLambdaParameterName]; }
    3030    }
    3131    #endregion
     
    3535      set { PenalityParameter.Value.Value = value; }
    3636    }
    37     public DoubleValue Lambda {
    38       get { return LambdaParameter.Value; }
    39       set { LambdaParameter.Value = value; }
     37    public DoubleValue LogLambda {
     38      get { return LogLambdaParameter.Value; }
     39      set { LogLambdaParameter.Value = value; }
    4040    }
    4141    #endregion
     
    4646      : base(original, cloner) {
    4747    }
    48     public ElasticNetLinearRegression() : base() {
     48    public ElasticNetLinearRegression()
     49      : base() {
    4950      Problem = new RegressionProblem();
    5051      Parameters.Add(new FixedValueParameter<DoubleValue>(PenalityParameterName, "Penalty factor for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5)));
    51       Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
     52      Parameters.Add(new OptionalValueParameter<DoubleValue>(LogLambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
    5253    }
    5354
     
    6061
    6162    protected override void Run() {
    62       if (Lambda == null) {
     63      if (LogLambda == null) {
    6364        CreateSolutionPath();
    6465      } else {
    65         CreateSolution(Lambda.Value);
    66       }
    67     }
    68 
    69     private void CreateSolution(double lambda) {
     66        CreateSolution(LogLambda.Value);
     67      }
     68    }
     69
     70    private void CreateSolution(double logLambda) {
    7071      double trainRsq;
    7172      double testRsq;
    72       var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, lambda, out trainRsq, out testRsq);
     73      var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, Math.Pow(10, logLambda), out trainRsq, out testRsq);
    7374      Results.Add(new Result("R² (train)", new DoubleValue(trainRsq)));
    7475      Results.Add(new Result("R² (test)", new DoubleValue(testRsq)));
     
    113114      RunElasticNetLinearRegression(Problem.ProblemData, Penality, out lambda, out trainRsq, out testRsq, out coeff, out intercept);
    114115
    115       var coeffTable = new DataTable("Coefficient Paths", "The paths of coefficient values over different lambda values");
     116      var coeffTable = new DataTable("Coefficient Paths", "The paths of standarized coefficient values over different lambda values");
    116117      var nLambdas = lambda.Length;
    117118      var nCoeff = coeff.GetLength(1);
     
    120121      for (int i = 0; i < nCoeff; i++) {
    121122        var coeffId = allowedVars[i];
    122         var path = Enumerable.Range(0, nLambdas).Select(r => coeff[r, i]).ToArray();
     123        double sigma = Problem.ProblemData.Dataset.GetDoubleValues(coeffId).StandardDeviation();
     124        var path = Enumerable.Range(0, nLambdas).Select(r => coeff[r, i] * sigma).ToArray();
    123125        dataRows[i] = new DataRow(coeffId, coeffId, path);
    124126        coeffTable.Rows.Add(dataRows[i]);
     
    127129      Results.Add(new Result(coeffTable.Name, coeffTable.Description, coeffTable));
    128130
    129       var rsqTable = new DataTable("R-Squared", "Path of R² values over different lambda values");
    130       rsqTable.Rows.Add(new DataRow("R² (train)", "Path of R² values over different lambda values", trainRsq));
    131       rsqTable.Rows.Add(new DataRow("R² (test)", "Path of R² values over different lambda values", testRsq));
    132       rsqTable.Rows.Add(new DataRow("Lambda", "The lambda values along the path", lambda));
    133       rsqTable.Rows["Lambda"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
    134       rsqTable.Rows["Lambda"].VisualProperties.SecondYAxis = true;
    135       rsqTable.VisualProperties.SecondYAxisMinimumFixedValue = 1E-7;
    136       rsqTable.VisualProperties.SecondYAxisLogScale = true;
    137       Results.Add(new Result(rsqTable.Name, rsqTable.Description, rsqTable));
     131      var rsqPlot = new ScatterPlot("R-Squared", "Path of R² values over different lambda values");
     132      rsqPlot.VisualProperties.YAxisMaximumAuto = false;
     133      rsqPlot.VisualProperties.YAxisMinimumAuto = false;
     134      rsqPlot.VisualProperties.XAxisMaximumAuto = false;
     135      rsqPlot.VisualProperties.XAxisMinimumAuto = false;
     136
     137      rsqPlot.VisualProperties.YAxisMinimumFixedValue = 0;
     138      rsqPlot.VisualProperties.YAxisMaximumFixedValue = 1.0;
     139      rsqPlot.VisualProperties.XAxisTitle = "Log10(Lambda)";
     140      rsqPlot.VisualProperties.YAxisTitle = "R²";
     141      rsqPlot.Rows.Add(new ScatterPlotDataRow("R² (train)", "Path of R² values over different lambda values", lambda.Zip(trainRsq, (l, r) => new Point2D<double>(Math.Log10(l), r))));
     142      rsqPlot.Rows.Add(new ScatterPlotDataRow("R² (test)", "Path of R² values over different lambda values", lambda.Zip(testRsq, (l, r) => new Point2D<double>(Math.Log10(l), r))));
     143      if (lambda.Length > 2) {
     144        rsqPlot.VisualProperties.XAxisMinimumFixedValue = Math.Floor(Math.Log10(lambda.Last()));
     145        rsqPlot.VisualProperties.XAxisMaximumFixedValue = Math.Ceiling(Math.Log10(lambda.Skip(1).First()));
     146      }
     147      rsqPlot.Rows["R² (train)"].VisualProperties.PointSize = 5;
     148      rsqPlot.Rows["R² (test)"].VisualProperties.PointSize = 5;
     149
     150      Results.Add(new Result(rsqPlot.Name, rsqPlot.Description, rsqPlot));
    138151    }
    139152
     
    227240      int nx = numVars;
    228241      double thr = 1.0e-5; // default value as recommended in glmnet
    229       int isd = 0; //  => regression on original predictor variables
     242      int isd = 1; //  => regression on standardized predictor variables
    230243      int intr = 1;  // => do include intercept in model
    231244      int maxit = 100000; // default value as recommended in glmnet
Note: See TracChangeset for help on using the changeset viewer.