Changeset 13930
- Timestamp: 06/21/16 09:00:35 (8 years ago)
- Location: branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4
- Files: 3 edited
Legend:
- Unmodified
- Added
- Removed
branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/ElasticNetLinearRegression.cs
r13929 r13930 7 7 using HeuristicLab.Core; 8 8 using HeuristicLab.Data; 9 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 9 10 using HeuristicLab.Optimization; 10 11 using HeuristicLab.Parameters; 11 12 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 12 13 using HeuristicLab.Problems.DataAnalysis; 14 using HeuristicLab.Problems.DataAnalysis.Symbolic; 15 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression; 13 16 14 17 namespace HeuristicLab.LibGlmNet { … … 18 21 public sealed class ElasticNetLinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> { 19 22 private const string PenalityParameterName = "Penality"; 23 private const string LambdaParameterName = "Lambda"; 20 24 #region parameters 21 25 public IFixedValueParameter<DoubleValue> PenalityParameter { 22 26 get { return (IFixedValueParameter<DoubleValue>)Parameters[PenalityParameterName]; } 27 } 28 public IValueParameter<DoubleValue> LambdaParameter { 29 get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; } 23 30 } 24 31 #endregion … … 28 35 set { PenalityParameter.Value.Value = value; } 29 36 } 37 public DoubleValue Lambda { 38 get { return LambdaParameter.Value; } 39 set { LambdaParameter.Value = value; } 40 } 30 41 #endregion 31 42 … … 38 49 Problem = new RegressionProblem(); 39 50 Parameters.Add(new FixedValueParameter<DoubleValue>(PenalityParameterName, "Penalty factor for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5))); 51 Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. 
lambda == null => calculate the whole path of all lambdas")); 40 52 } 41 53 … … 48 60 49 61 protected override void Run() { 62 if (Lambda == null) { 63 CreateSolutionPath(); 64 } else { 65 CreateSolution(Lambda.Value); 66 } 67 } 68 69 private void CreateSolution(double lambda) { 70 double trainRsq; 71 double testRsq; 72 var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, lambda, out trainRsq, out testRsq); 73 Results.Add(new Result("R² (train)", new DoubleValue(trainRsq))); 74 Results.Add(new Result("R² (test)", new DoubleValue(testRsq))); 75 76 // copied from LR => TODO: reuse code (but skip coefficients = 0.0) 77 ISymbolicExpressionTree tree = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode()); 78 ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode(); 79 tree.Root.AddSubtree(startNode); 80 ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode(); 81 startNode.AddSubtree(addition); 82 83 int col = 0; 84 foreach (string column in Problem.ProblemData.AllowedInputVariables) { 85 if (!coeff[col].IsAlmost(0.0)) { 86 VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode(); 87 vNode.VariableName = column; 88 vNode.Weight = coeff[col]; 89 addition.AddSubtree(vNode); 90 } 91 col++; 92 } 93 94 if (!coeff[coeff.Length - 1].IsAlmost(0.0)) { 95 ConstantTreeNode cNode = (ConstantTreeNode)new Constant().CreateTreeNode(); 96 cNode.Value = coeff[coeff.Length - 1]; 97 addition.AddSubtree(cNode); 98 } 99 100 SymbolicRegressionSolution solution = new SymbolicRegressionSolution(new SymbolicRegressionModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter()), (IRegressionProblemData)Problem.ProblemData.Clone()); 101 solution.Model.Name = "Elastic-net Linear Regression Model"; 102 solution.Name = "Elastic-net Linear Regression Solution"; 103 104 Results.Add(new Result(solution.Name, solution.Description, solution)); 105 } 106 107 
private void CreateSolutionPath() { 50 108 double[] lambda; 51 109 double[] trainRsq; … … 72 130 rsqTable.Rows.Add(new DataRow("R² (train)", "Path of R² values over different lambda values", trainRsq)); 73 131 rsqTable.Rows.Add(new DataRow("R² (test)", "Path of R² values over different lambda values", testRsq)); 132 rsqTable.Rows.Add(new DataRow("Lambda", "The lambda values along the path", lambda)); 133 rsqTable.Rows["Lambda"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points; 134 rsqTable.Rows["Lambda"].VisualProperties.SecondYAxis = true; 135 rsqTable.VisualProperties.SecondYAxisMinimumFixedValue = 1E-7; 136 rsqTable.VisualProperties.SecondYAxisLogScale = true; 74 137 Results.Add(new Result(rsqTable.Name, rsqTable.Description, rsqTable)); 75 138 } 76 139 77 140 public static double[] CreateElasticNetLinearRegressionSolution(IRegressionProblemData problemData, double penalty, double lambda, 78 out double trainRsq, 141 out double trainRsq, out double testRsq, 79 142 double coeffLowerBound = double.NegativeInfinity, double coeffUpperBound = double.PositiveInfinity) { 80 143 double[] trainRsqs; … … 192 255 modval(intercept[solIdx], selectedCa, ia, selectedNin, numTestObs, testX, out fn); 193 256 OnlineCalculatorError error; 194 var r 257 var r = OnlinePearsonsRCalculator.Calculate(testY, fn, out error); 195 258 if (error != OnlineCalculatorError.None) r = 0; 196 259 testRsq[solIdx] = r * r; … … 204 267 } 205 268 206 private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY, out int numTrainObs, 269 private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY, out int numTrainObs, 207 270 out double[,] testX, out double[] testY, out int numTestObs, out int numVars) { 208 271 numVars = problemData.AllowedInputVariables.Count(); … … 229 292 // test 230 293 rIdx = 0; 231 foreach (var row in problemData.TestIndices) {294 foreach (var row in 
problemData.TestIndices) { 232 295 int cIdx = 0; 233 foreach (var var in problemData.AllowedInputVariables) {296 foreach (var var in problemData.AllowedInputVariables) { 234 297 testX[cIdx, rIdx] = ds.GetDoubleValue(var, row); 235 298 cIdx++; -
branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/HeuristicLab.Algorithms.DataAnalysis.Glmnet.csproj
r13929 r13930 75 75 <Private>False</Private> 76 76 </Reference> 77 <Reference Include="HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL"> 78 <SpecificVersion>False</SpecificVersion> 79 <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4.dll</HintPath> 80 </Reference> 77 81 <Reference Include="HeuristicLab.Optimization-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL"> 78 82 <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Optimization-3.3.dll</HintPath> … … 99 103 <SpecificVersion>False</SpecificVersion> 100 104 <Private>False</Private> 105 </Reference> 106 <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL"> 107 <SpecificVersion>False</SpecificVersion> 108 <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic-3.4.dll</HintPath> 109 </Reference> 110 <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL"> 111 <SpecificVersion>False</SpecificVersion> 112 <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.dll</HintPath> 101 113 </Reference> 102 114 <Reference Include="HeuristicLab.Problems.Instances-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL"> -
branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/Plugin.cs.frame
r13928 r13930 37 37 [PluginDependency("HeuristicLab.Core", "3.3")] 38 38 [PluginDependency("HeuristicLab.Data", "3.3")] 39 [PluginDependency("HeuristicLab.Encodings.SymbolicExpressionTreeEncoding", "3.4")] 39 40 [PluginDependency("HeuristicLab.Optimization", "3.3")] 40 41 [PluginDependency("HeuristicLab.Parameters", "3.3")] 41 42 [PluginDependency("HeuristicLab.Persistence", "3.3")] 42 43 [PluginDependency("HeuristicLab.Problems.DataAnalysis", "3.4")] 44 [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic", "3.4")] 45 [PluginDependency("HeuristicLab.Problems.DataAnalysis.Symbolic.Regression", "3.4")] 43 46 [PluginDependency("HeuristicLab.Problems.Instances", "3.3")] 44 47 public class HeuristicLabAlgorithmsDataAnalysisGlmnetPlugin : PluginBase {
Note: See TracChangeset for help on using the changeset viewer.