Changeset 13930


Ignore:
Timestamp:
06/21/16 09:00:35 (5 years ago)
Author:
gkronber
Message:

#745: added parameter lambda to support calculation of a solution for a specific lambda value (instead of the full path)

Location:
branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/ElasticNetLinearRegression.cs

    r13929 r13930  
    77using HeuristicLab.Core;
    88using HeuristicLab.Data;
     9using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    910using HeuristicLab.Optimization;
    1011using HeuristicLab.Parameters;
    1112using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    1213using HeuristicLab.Problems.DataAnalysis;
     14using HeuristicLab.Problems.DataAnalysis.Symbolic;
     15using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    1316
    1417namespace HeuristicLab.LibGlmNet {
     
    1821  public sealed class ElasticNetLinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    1922    private const string PenalityParameterName = "Penality";
     23    private const string LambdaParameterName = "Lambda";
    2024    #region parameters
    2125    public IFixedValueParameter<DoubleValue> PenalityParameter {
    2226      get { return (IFixedValueParameter<DoubleValue>)Parameters[PenalityParameterName]; }
     27    }
     28    public IValueParameter<DoubleValue> LambdaParameter {
     29      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    2330    }
    2431    #endregion
     
    2835      set { PenalityParameter.Value.Value = value; }
    2936    }
     37    public DoubleValue Lambda {
     38      get { return LambdaParameter.Value; }
     39      set { LambdaParameter.Value = value; }
     40    }
    3041    #endregion
    3142
     
    3849      Problem = new RegressionProblem();
    3950      Parameters.Add(new FixedValueParameter<DoubleValue>(PenalityParameterName, "Penalty factor for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5)));
     51      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
    4052    }
    4153
     
    4860
    /// <summary>
    /// Entry point of the algorithm: computes a single solution when a lambda
    /// value is set, otherwise computes the whole path over all lambdas.
    /// </summary>
    protected override void Run() {
      var selectedLambda = Lambda;

      // A concrete lambda was supplied => fit exactly one model and stop.
      if (selectedLambda != null) {
        CreateSolution(selectedLambda.Value);
        return;
      }

      // No lambda given => fall back to the full regularization path.
      CreateSolutionPath();
    }
     68
     69    private void CreateSolution(double lambda) {
     70      double trainRsq;
     71      double testRsq;
     72      var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, lambda, out trainRsq, out testRsq);
     73      Results.Add(new Result("R² (train)", new DoubleValue(trainRsq)));
     74      Results.Add(new Result("R² (test)", new DoubleValue(testRsq)));
     75
     76      // copied from LR => TODO: reuse code (but skip coefficients = 0.0)
     77      ISymbolicExpressionTree tree = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode());
     78      ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode();
     79      tree.Root.AddSubtree(startNode);
     80      ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode();
     81      startNode.AddSubtree(addition);
     82
     83      int col = 0;
     84      foreach (string column in Problem.ProblemData.AllowedInputVariables) {
     85        if (!coeff[col].IsAlmost(0.0)) {
     86          VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
     87          vNode.VariableName = column;
     88          vNode.Weight = coeff[col];
     89          addition.AddSubtree(vNode);
     90        }
     91        col++;
     92      }
     93
     94      if (!coeff[coeff.Length - 1].IsAlmost(0.0)) {
     95        ConstantTreeNode cNode = (ConstantTreeNode)new Constant().CreateTreeNode();
     96        cNode.Value = coeff[coeff.Length - 1];
     97        addition.AddSubtree(cNode);
     98      }
     99
     100      SymbolicRegressionSolution solution = new SymbolicRegressionSolution(new SymbolicRegressionModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter()), (IRegressionProblemData)Problem.ProblemData.Clone());
     101      solution.Model.Name = "Elastic-net Linear Regression Model";
     102      solution.Name = "Elastic-net Linear Regression Solution";
     103
     104      Results.Add(new Result(solution.Name, solution.Description, solution));
     105    }
     106
     107    private void CreateSolutionPath() {
    50108      double[] lambda;
    51109      double[] trainRsq;
     
    72130      rsqTable.Rows.Add(new DataRow("R² (train)", "Path of R² values over different lambda values", trainRsq));
    73131      rsqTable.Rows.Add(new DataRow("R² (test)", "Path of R² values over different lambda values", testRsq));
     132      rsqTable.Rows.Add(new DataRow("Lambda", "The lambda values along the path", lambda));
     133      rsqTable.Rows["Lambda"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     134      rsqTable.Rows["Lambda"].VisualProperties.SecondYAxis = true;
     135      rsqTable.VisualProperties.SecondYAxisMinimumFixedValue = 1E-7;
     136      rsqTable.VisualProperties.SecondYAxisLogScale = true;
    74137      Results.Add(new Result(rsqTable.Name, rsqTable.Description, rsqTable));
    75138    }
    76139
    77140    public static double[] CreateElasticNetLinearRegressionSolution(IRegressionProblemData problemData, double penalty, double lambda,
    78             out double trainRsq,  out double testRsq,
     141            out double trainRsq, out double testRsq,
    79142            double coeffLowerBound = double.NegativeInfinity, double coeffUpperBound = double.PositiveInfinity) {
    80143      double[] trainRsqs;
     
    192255        modval(intercept[solIdx], selectedCa, ia, selectedNin, numTestObs, testX, out fn);
    193256        OnlineCalculatorError error;
    194         var r  = OnlinePearsonsRCalculator.Calculate(testY, fn, out error);
     257        var r = OnlinePearsonsRCalculator.Calculate(testY, fn, out error);
    195258        if (error != OnlineCalculatorError.None) r = 0;
    196259        testRsq[solIdx] = r * r;
     
    204267    }
    205268
    206     private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY, out int numTrainObs, 
     269    private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY, out int numTrainObs,
    207270      out double[,] testX, out double[] testY, out int numTestObs, out int numVars) {
    208271      numVars = problemData.AllowedInputVariables.Count();
     
    229292      // test
    230293      rIdx = 0;
    231       foreach(var row in problemData.TestIndices) {
     294      foreach (var row in problemData.TestIndices) {
    232295        int cIdx = 0;
    233         foreach(var var in problemData.AllowedInputVariables) {
     296        foreach (var var in problemData.AllowedInputVariables) {
    234297          testX[cIdx, rIdx] = ds.GetDoubleValue(var, row);
    235298          cIdx++;
  • branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/HeuristicLab.Algorithms.DataAnalysis.Glmnet.csproj

    r13929 r13930  
    7575      <Private>False</Private>
    7676    </Reference>
     77    <Reference Include="HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     78      <SpecificVersion>False</SpecificVersion>
     79      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4.dll</HintPath>
     80    </Reference>
    7781    <Reference Include="HeuristicLab.Optimization-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
    7882      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Optimization-3.3.dll</HintPath>
     
    99103      <SpecificVersion>False</SpecificVersion>
    100104      <Private>False</Private>
     105    </Reference>
     106    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     107      <SpecificVersion>False</SpecificVersion>
     108      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic-3.4.dll</HintPath>
     109    </Reference>
     110    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     111      <SpecificVersion>False</SpecificVersion>
     112      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.dll</HintPath>
    101113    </Reference>
    102114    <Reference Include="HeuristicLab.Problems.Instances-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
  • branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/Plugin.cs.frame

    r13928 r13930  
    3737  [PluginDependency("HeuristicLab.Core", "3.3")]
    3838  [PluginDependency("HeuristicLab.Data", "3.3")]
     39  [PluginDependency("HeuristicLab.Encodings.SymbolicExpressionTreeEncoding", "3.4")]
    3940  [PluginDependency("HeuristicLab.Optimization", "3.3")]
    4041  [PluginDependency("HeuristicLab.Parameters", "3.3")]
    4142  [PluginDependency("HeuristicLab.Persistence", "3.3")]
    4243  [PluginDependency("HeuristicLab.Problems.DataAnalysis", "3.4")]
     44  [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic", "3.4")]
     45  [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic.Regression", "3.4")]
    4346  [PluginDependency("HeuristicLab.Problems.Instances", "3.3")]
    4447  public class HeuristicLabAlgorithmsDataAnalysisGlmnetPlugin : PluginBase {
Note: See TracChangeset for help on using the changeset viewer.