Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp: 11/15/16 21:53:42 (7 years ago)
Author: gkronber
Message: #745: code simplification using functionality from refactored trunk and fixed lambda parameter

File: 1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/ElasticNetLinearRegression.cs

    r14377 r14395  
    4040  public sealed class ElasticNetLinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    4141    private const string PenalityParameterName = "Penality";
    42     private const string LogLambdaParameterName = "Log10(Lambda)";
     42    private const string LambdaParameterName = "Lambda";
    4343    #region parameters
    4444    public IFixedValueParameter<DoubleValue> PenalityParameter {
    4545      get { return (IFixedValueParameter<DoubleValue>)Parameters[PenalityParameterName]; }
    4646    }
    47     public IValueParameter<DoubleValue> LogLambdaParameter {
    48       get { return (IValueParameter<DoubleValue>)Parameters[LogLambdaParameterName]; }
     47    public IValueParameter<DoubleValue> LambdaParameter {
     48      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    4949    }
    5050    #endregion
     
    5454      set { PenalityParameter.Value.Value = value; }
    5555    }
    56     public DoubleValue LogLambda {
    57       get { return LogLambdaParameter.Value; }
    58       set { LogLambdaParameter.Value = value; }
     56    public DoubleValue Lambda {
     57      get { return LambdaParameter.Value; }
     58      set { LambdaParameter.Value = value; }
    5959    }
    6060    #endregion
     
    6969      Problem = new RegressionProblem();
    7070      Parameters.Add(new FixedValueParameter<DoubleValue>(PenalityParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5)));
    71       Parameters.Add(new OptionalValueParameter<DoubleValue>(LogLambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
     71      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
    7272    }
    7373
     
    8080
    8181    protected override void Run() {
    82       if (LogLambda == null) {
     82      if (Lambda == null) {
    8383        CreateSolutionPath();
    8484      } else {
    85         CreateSolution(LogLambda.Value);
    86       }
    87     }
    88 
    89     private void CreateSolution(double logLambda) {
     85        CreateSolution(Lambda.Value);
     86      }
     87    }
     88
     89    private void CreateSolution(double lambda) {
    9090      double trainNMSE;
    9191      double testNMSE;
    92       var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, Math.Pow(10, logLambda), out trainNMSE, out testNMSE);
     92      var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, lambda, out trainNMSE, out testNMSE);
    9393      Results.Add(new Result("NMSE (train)", new DoubleValue(trainNMSE)));
    9494      Results.Add(new Result("NMSE (test)", new DoubleValue(testNMSE)));
    9595
    96       // copied from LR => TODO: reuse code (but skip coefficients = 0.0)
    97       ISymbolicExpressionTree tree = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode());
    98       ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode();
    99       tree.Root.AddSubtree(startNode);
    100       ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode();
    101       startNode.AddSubtree(addition);
    102 
    103       int col = 0;
    104       foreach (string column in Problem.ProblemData.AllowedInputVariables) {
    105         if (!coeff[col].IsAlmost(0.0)) {
    106           VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
    107           vNode.VariableName = column;
    108           vNode.Weight = coeff[col];
    109           addition.AddSubtree(vNode);
    110         }
    111         col++;
    112       }
    113 
    114       if (!coeff[coeff.Length - 1].IsAlmost(0.0)) {
    115         ConstantTreeNode cNode = (ConstantTreeNode)new Constant().CreateTreeNode();
    116         cNode.Value = coeff[coeff.Length - 1];
    117         addition.AddSubtree(cNode);
    118       }
     96      var allVariables = Problem.ProblemData.AllowedInputVariables.ToArray();
     97
     98      var remainingVars = Enumerable.Range(0, allVariables.Length)
     99        .Where(idx => !coeff[idx].IsAlmost(0.0)).Select(idx => allVariables[idx])
     100        .ToArray();
     101      var remainingCoeff = Enumerable.Range(0, allVariables.Length)
     102        .Select(idx => coeff[idx])
     103        .Where(c => !c.IsAlmost(0.0))
     104        .ToArray();
     105
     106      var tree = LinearModelToTreeConverter.CreateTree(remainingVars, remainingCoeff, coeff.Last());
     107
    119108
    120109      SymbolicRegressionSolution solution = new SymbolicRegressionSolution(
     
    142131
    143132      coeffTable.VisualProperties.XAxisLogScale = true;
    144       coeffTable.VisualProperties.XAxisTitle = "Log10(Lambda)";
     133      coeffTable.VisualProperties.XAxisTitle = "Lambda";
    145134      coeffTable.VisualProperties.YAxisTitle = "Coefficients";
    146135      coeffTable.VisualProperties.SecondYAxisTitle = "Number of variables";
     
    184173      errorTable.VisualProperties.YAxisMaximumFixedValue = 1.0;
    185174      errorTable.VisualProperties.XAxisLogScale = true;
    186       errorTable.VisualProperties.XAxisTitle = "Log10(Lambda)";
     175      errorTable.VisualProperties.XAxisTitle = "Lambda";
    187176      errorTable.VisualProperties.YAxisTitle = "Normalized mean of squared errors (NMSE)";
    188       errorTable.VisualProperties.SecondYAxisTitle= "Number of variables";
     177      errorTable.VisualProperties.SecondYAxisTitle = "Number of variables";
    189178      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (train)", "Path of NMSE values over different lambda values", lambda.Zip(trainNMSE, (l, v) => Tuple.Create(l, v))));
    190179      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (test)", "Path of NMSE values over different lambda values", lambda.Zip(testNMSE, (l, v) => Tuple.Create(l, v))));
     
    198187      errorTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
    199188      errorTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;
    200                                                                                        
     189
    201190      Results.Add(new Result(errorTable.Name, errorTable.Description, errorTable));
    202191    }
     
    270259      double[,] trainX;
    271260      double[,] testX;
    272 
    273261      double[] trainY;
    274262      double[] testY;
    275       int numTrainObs, numTestObs;
    276       int numVars;
    277       PrepareData(problemData, out trainX, out trainY, out numTrainObs, out testX, out testY, out numTestObs, out numVars);
     263
     264      PrepareData(problemData, out trainX, out trainY, out testX, out testY);
     265      var numTrainObs = trainX.GetLength(1);
     266      var numTestObs = testX.GetLength(1);
     267      var numVars = trainX.GetLength(0);
    278268
    279269      int ka = 1; // => covariance updating algorithm
     
    334324    }
    335325
    336     private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY, out int numTrainObs,
    337       out double[,] testX, out double[] testY, out int numTestObs, out int numVars) {
    338       numVars = problemData.AllowedInputVariables.Count();
    339       numTrainObs = problemData.TrainingIndices.Count();
    340       numTestObs = problemData.TestIndices.Count();
    341 
    342       trainX = new double[numVars, numTrainObs];
    343       trainY = new double[numTrainObs];
    344       testX = new double[numVars, numTestObs];
    345       testY = new double[numTestObs];
     326    private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY,
     327      out double[,] testX, out double[] testY) {
     328
    346329      var ds = problemData.Dataset;
    347       var targetVar = problemData.TargetVariable;
    348       // train
    349       int rIdx = 0;
    350       foreach (var row in problemData.TrainingIndices) {
    351         int cIdx = 0;
    352         foreach (var var in problemData.AllowedInputVariables) {
    353           trainX[cIdx, rIdx] = ds.GetDoubleValue(var, row);
    354           cIdx++;
    355         }
    356         trainY[rIdx] = ds.GetDoubleValue(targetVar, row);
    357         rIdx++;
    358       }
    359       // test
    360       rIdx = 0;
    361       foreach (var row in problemData.TestIndices) {
    362         int cIdx = 0;
    363         foreach (var var in problemData.AllowedInputVariables) {
    364           testX[cIdx, rIdx] = ds.GetDoubleValue(var, row);
    365           cIdx++;
    366         }
    367         testY[rIdx] = ds.GetDoubleValue(targetVar, row);
    368         rIdx++;
    369       }
    370     }
    371 
     330      trainX = ds.ToArray(problemData.AllowedInputVariables, problemData.TrainingIndices);
     331      trainX = trainX.Transpose();
     332      trainY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable,
     333        problemData.TrainingIndices)
     334        .ToArray();
     335      testX = ds.ToArray(problemData.AllowedInputVariables, problemData.TestIndices);
     336      testX = testX.Transpose();
     337      testY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable,
     338        problemData.TestIndices)
     339        .ToArray();
     340    }
    372341  }
    373342}
Note: See TracChangeset for help on using the changeset viewer.