Changeset 14395
- Timestamp: 11/15/16 21:53:42 (8 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
branches/HeuristicLab.Algorithms.DataAnalysis.Glmnet/3.4/ElasticNetLinearRegression.cs
r14377 r14395 40 40 public sealed class ElasticNetLinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> { 41 41 private const string PenalityParameterName = "Penality"; 42 private const string L ogLambdaParameterName = "Log10(Lambda)";42 private const string LambdaParameterName = "Lambda"; 43 43 #region parameters 44 44 public IFixedValueParameter<DoubleValue> PenalityParameter { 45 45 get { return (IFixedValueParameter<DoubleValue>)Parameters[PenalityParameterName]; } 46 46 } 47 public IValueParameter<DoubleValue> L ogLambdaParameter {48 get { return (IValueParameter<DoubleValue>)Parameters[L ogLambdaParameterName]; }47 public IValueParameter<DoubleValue> LambdaParameter { 48 get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; } 49 49 } 50 50 #endregion … … 54 54 set { PenalityParameter.Value.Value = value; } 55 55 } 56 public DoubleValue L ogLambda {57 get { return L ogLambdaParameter.Value; }58 set { L ogLambdaParameter.Value = value; }56 public DoubleValue Lambda { 57 get { return LambdaParameter.Value; } 58 set { LambdaParameter.Value = value; } 59 59 } 60 60 #endregion … … 69 69 Problem = new RegressionProblem(); 70 70 Parameters.Add(new FixedValueParameter<DoubleValue>(PenalityParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5))); 71 Parameters.Add(new OptionalValueParameter<DoubleValue>(L ogLambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));71 Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. 
lambda == null => calculate the whole path of all lambdas")); 72 72 } 73 73 … … 80 80 81 81 protected override void Run() { 82 if (L ogLambda == null) {82 if (Lambda == null) { 83 83 CreateSolutionPath(); 84 84 } else { 85 CreateSolution(L ogLambda.Value);86 } 87 } 88 89 private void CreateSolution(double l ogLambda) {85 CreateSolution(Lambda.Value); 86 } 87 } 88 89 private void CreateSolution(double lambda) { 90 90 double trainNMSE; 91 91 double testNMSE; 92 var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, Math.Pow(10, logLambda), out trainNMSE, out testNMSE);92 var coeff = CreateElasticNetLinearRegressionSolution(Problem.ProblemData, Penality, lambda, out trainNMSE, out testNMSE); 93 93 Results.Add(new Result("NMSE (train)", new DoubleValue(trainNMSE))); 94 94 Results.Add(new Result("NMSE (test)", new DoubleValue(testNMSE))); 95 95 96 // copied from LR => TODO: reuse code (but skip coefficients = 0.0) 97 ISymbolicExpressionTree tree = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode()); 98 ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode(); 99 tree.Root.AddSubtree(startNode); 100 ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode(); 101 startNode.AddSubtree(addition); 102 103 int col = 0; 104 foreach (string column in Problem.ProblemData.AllowedInputVariables) { 105 if (!coeff[col].IsAlmost(0.0)) { 106 VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode(); 107 vNode.VariableName = column; 108 vNode.Weight = coeff[col]; 109 addition.AddSubtree(vNode); 110 } 111 col++; 112 } 113 114 if (!coeff[coeff.Length - 1].IsAlmost(0.0)) { 115 ConstantTreeNode cNode = (ConstantTreeNode)new Constant().CreateTreeNode(); 116 cNode.Value = coeff[coeff.Length - 1]; 117 addition.AddSubtree(cNode); 118 } 96 var allVariables = Problem.ProblemData.AllowedInputVariables.ToArray(); 97 98 var remainingVars = Enumerable.Range(0, 
allVariables.Length) 99 .Where(idx => !coeff[idx].IsAlmost(0.0)).Select(idx => allVariables[idx]) 100 .ToArray(); 101 var remainingCoeff = Enumerable.Range(0, allVariables.Length) 102 .Select(idx => coeff[idx]) 103 .Where(c => !c.IsAlmost(0.0)) 104 .ToArray(); 105 106 var tree = LinearModelToTreeConverter.CreateTree(remainingVars, remainingCoeff, coeff.Last()); 107 119 108 120 109 SymbolicRegressionSolution solution = new SymbolicRegressionSolution( … … 142 131 143 132 coeffTable.VisualProperties.XAxisLogScale = true; 144 coeffTable.VisualProperties.XAxisTitle = "L og10(Lambda)";133 coeffTable.VisualProperties.XAxisTitle = "Lambda"; 145 134 coeffTable.VisualProperties.YAxisTitle = "Coefficients"; 146 135 coeffTable.VisualProperties.SecondYAxisTitle = "Number of variables"; … … 184 173 errorTable.VisualProperties.YAxisMaximumFixedValue = 1.0; 185 174 errorTable.VisualProperties.XAxisLogScale = true; 186 errorTable.VisualProperties.XAxisTitle = "L og10(Lambda)";175 errorTable.VisualProperties.XAxisTitle = "Lambda"; 187 176 errorTable.VisualProperties.YAxisTitle = "Normalized mean of squared errors (NMSE)"; 188 errorTable.VisualProperties.SecondYAxisTitle = "Number of variables";177 errorTable.VisualProperties.SecondYAxisTitle = "Number of variables"; 189 178 errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (train)", "Path of NMSE values over different lambda values", lambda.Zip(trainNMSE, (l, v) => Tuple.Create(l, v)))); 190 179 errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (test)", "Path of NMSE values over different lambda values", lambda.Zip(testNMSE, (l, v) => Tuple.Create(l, v)))); … … 198 187 errorTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points; 199 188 errorTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true; 200 189 201 190 Results.Add(new Result(errorTable.Name, errorTable.Description, errorTable)); 202 191 } … … 270 259 double[,] trainX; 271 260 double[,] testX; 272 
273 261 double[] trainY; 274 262 double[] testY; 275 int numTrainObs, numTestObs; 276 int numVars; 277 PrepareData(problemData, out trainX, out trainY, out numTrainObs, out testX, out testY, out numTestObs, out numVars); 263 264 PrepareData(problemData, out trainX, out trainY, out testX, out testY); 265 var numTrainObs = trainX.GetLength(1); 266 var numTestObs = testX.GetLength(1); 267 var numVars = trainX.GetLength(0); 278 268 279 269 int ka = 1; // => covariance updating algorithm … … 334 324 } 335 325 336 private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY, out int numTrainObs, 337 out double[,] testX, out double[] testY, out int numTestObs, out int numVars) { 338 numVars = problemData.AllowedInputVariables.Count(); 339 numTrainObs = problemData.TrainingIndices.Count(); 340 numTestObs = problemData.TestIndices.Count(); 341 342 trainX = new double[numVars, numTrainObs]; 343 trainY = new double[numTrainObs]; 344 testX = new double[numVars, numTestObs]; 345 testY = new double[numTestObs]; 326 private static void PrepareData(IRegressionProblemData problemData, out double[,] trainX, out double[] trainY, 327 out double[,] testX, out double[] testY) { 328 346 329 var ds = problemData.Dataset; 347 var targetVar = problemData.TargetVariable; 348 // train 349 int rIdx = 0; 350 foreach (var row in problemData.TrainingIndices) { 351 int cIdx = 0; 352 foreach (var var in problemData.AllowedInputVariables) { 353 trainX[cIdx, rIdx] = ds.GetDoubleValue(var, row); 354 cIdx++; 355 } 356 trainY[rIdx] = ds.GetDoubleValue(targetVar, row); 357 rIdx++; 358 } 359 // test 360 rIdx = 0; 361 foreach (var row in problemData.TestIndices) { 362 int cIdx = 0; 363 foreach (var var in problemData.AllowedInputVariables) { 364 testX[cIdx, rIdx] = ds.GetDoubleValue(var, row); 365 cIdx++; 366 } 367 testY[rIdx] = ds.GetDoubleValue(targetVar, row); 368 rIdx++; 369 } 370 } 371 330 trainX = ds.ToArray(problemData.AllowedInputVariables, 
problemData.TrainingIndices); 331 trainX = trainX.Transpose(); 332 trainY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, 333 problemData.TrainingIndices) 334 .ToArray(); 335 testX = ds.ToArray(problemData.AllowedInputVariables, problemData.TestIndices); 336 testX = testX.Transpose(); 337 testY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, 338 problemData.TestIndices) 339 .ToArray(); 340 } 372 341 } 373 342 }
Note: See TracChangeset for help on using the changeset viewer.