Changeset 15131 for stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective
- Timestamp:
- 07/06/17 10:19:37 (8 years ago)
- Location:
- stable
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
stable
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression
-
stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4
-
stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs
r14962 r15131 176 176 177 177 178 public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) { 179 180 List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>(); 181 List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>(); 182 List<string> variableNames = new List<string>(); 183 List<int> lags = new List<int>(); 178 179 public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, 180 ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, 181 int maxIterations, bool updateVariableWeights = true, 182 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, 183 bool updateConstantsInTree = true) { 184 185 // numeric constants in the tree become variables for constant opt 186 // variables in the tree become parameters (fixed values) for constant opt 187 // for each parameter (variable in the original tree) we store the 188 // variable name, variable value (for factor vars) and lag as a DataForVariable object. 189 // A dictionary is used to find parameters 190 var variables = new List<AutoDiff.Variable>(); 191 var parameters = new Dictionary<DataForVariable, AutoDiff.Variable>(); 184 192 185 193 AutoDiff.Term func; 186 if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out func))194 if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, updateVariableWeights, out func)) 187 195 throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree."); 188 if (variableNames.Count == 0) return 0.0; 189 190 AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray()); 191 192 List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; 196 if (parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0 197 198 var parameterEntries = parameters.ToArray(); // order of entries must be the same for x 199 AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray()); 200 201 List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants 193 202 if (updateVariableWeights) 194 203 terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList(); 195 204 else 196 terminalNodes = new List<SymbolicExpressionTreeTerminalNode>(tree.Root.IterateNodesPrefix().OfType<ConstantTreeNode>()); 205 terminalNodes = new List<SymbolicExpressionTreeTerminalNode> 206 (tree.Root.IterateNodesPrefix() 207 .OfType<SymbolicExpressionTreeTerminalNode>() 208 .Where(node => node is ConstantTreeNode || node is FactorVariableTreeNode)); 197 209 198 210 //extract inital constants … … 205 217 ConstantTreeNode constantTreeNode = node as ConstantTreeNode; 206 218 VariableTreeNode variableTreeNode = node as VariableTreeNode; 219 BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode; 220 FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode; 207 221 if (constantTreeNode != null) 208 222 c[i++] = constantTreeNode.Value; 209 223 else if (updateVariableWeights && variableTreeNode != null) 210 224 c[i++] = variableTreeNode.Weight; 225 else if (updateVariableWeights && binFactorVarTreeNode != null) 226 c[i++] = binFactorVarTreeNode.Weight; 227 else if (factorVarTreeNode != null) { 228 // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants 229 foreach (var w in factorVarTreeNode.Weights) c[i++] = w; 230 } 211 231 } 212 232 } … … 216 236 alglib.lsfitstate state; 217 237 alglib.lsfitreport rep; 218 int info;238 int retVal; 219 239 220 240 IDataset ds = problemData.Dataset; 221 double[,] x = new double[rows.Count(), variableNames.Count];241 double[,] x = new double[rows.Count(), parameters.Count]; 222 242 int row = 0; 223 243 foreach (var r in rows) { 224 for (int col = 0; col < variableNames.Count; col++) { 225 int lag = lags[col]; 226 x[row, col] = ds.GetDoubleValue(variableNames[col], r + lag); 244 int col = 0; 245 foreach (var kvp in parameterEntries) { 246 var info = kvp.Key; 247 int lag = info.lag; 248 if (ds.VariableHasType<double>(info.variableName)) { 249 x[row, col] = ds.GetDoubleValue(info.variableName, r + lag); 250 } else if (ds.VariableHasType<string>(info.variableName)) { 251 x[row, col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0; 252 } else throw new InvalidProgramException("found a variable of unknown type"); 253 col++; 227 254 } 228 255 row++; … … 241 268 //alglib.lsfitsetgradientcheck(state, 0.001); 242 269 alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null); 243 alglib.lsfitresults(state, out info, out c, out rep);270 alglib.lsfitresults(state, out retVal, out c, out rep); 244 271 } catch (ArithmeticException) { 245 272 return originalQuality; … … 248 275 } 249 276 250 // info== -7 => constant optimization failed due to wrong gradient251 if ( info!= -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);277 //retVal == -7 => constant optimization failed due to wrong gradient 278 if (retVal != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights); 252 279 var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling); 253 280 … … 265 292 ConstantTreeNode constantTreeNode = node as ConstantTreeNode; 266 293 VariableTreeNode variableTreeNode = node as VariableTreeNode; 294 BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode; 295 FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode; 267 296 if (constantTreeNode != null) 268 297 constantTreeNode.Value = constants[i++]; 269 298 else if (updateVariableWeights && variableTreeNode != null) 270 299 variableTreeNode.Weight = constants[i++]; 300 else if (updateVariableWeights && binFactorVarTreeNode != null) 301 binFactorVarTreeNode.Weight = constants[i++]; 302 else if (factorVarTreeNode != null) { 303 for (int j = 0; j < factorVarTreeNode.Weights.Length; j++) 304 factorVarTreeNode.Weights[j] = constants[i++]; 305 } 271 306 } 272 307 } … … 286 321 } 287 322 288 private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, List<int> lags, bool updateVariableWeights, out AutoDiff.Term term) { 323 private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, 324 List<AutoDiff.Variable> variables, Dictionary<DataForVariable, AutoDiff.Variable> parameters, 325 bool updateVariableWeights, out AutoDiff.Term term) { 289 326 if (node.Symbol is Constant) { 290 327 var var = new AutoDiff.Variable(); … … 293 330 return true; 294 331 } 295 if (node.Symbol is Variable ) {296 var varNode = node as VariableTreeNode ;297 var par = new AutoDiff.Variable();298 parameters.Add(par);299 var iableNames.Add(varNode.VariableName);300 lags.Add(0);332 if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) { 333 var varNode = node as VariableTreeNodeBase; 334 var factorVarNode = node as BinaryFactorVariableTreeNode; 335 // factor variable values are only 0 or 1 and set in x accordingly 336 var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty; 337 var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue); 301 338 302 339 if (updateVariableWeights) { … … 309 346 return true; 310 347 } 348 if (node.Symbol is FactorVariable) { 349 var factorVarNode = node as FactorVariableTreeNode; 350 var products = new List<Term>(); 351 foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) { 352 var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue); 353 354 var wVar = new AutoDiff.Variable(); 355 variables.Add(wVar); 356 357 products.Add(AutoDiff.TermBuilder.Product(wVar, par)); 358 } 359 term = AutoDiff.TermBuilder.Sum(products); 360 return true; 361 } 311 362 if (node.Symbol is LaggedVariable) { 312 363 var varNode = node as LaggedVariableTreeNode; 313 var par = new AutoDiff.Variable(); 314 parameters.Add(par); 315 variableNames.Add(varNode.VariableName); 316 lags.Add(varNode.Lag); 364 var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag); 317 365 318 366 if (updateVariableWeights) { … … 329 377 foreach (var subTree in node.Subtrees) { 330 378 AutoDiff.Term t; 331 if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags,updateVariableWeights, out t)) {379 if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) { 332 380 term = null; 333 381 return false; … … 342 390 for (int i = 0; i < node.SubtreeCount; i++) { 343 391 AutoDiff.Term t; 344 if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {392 if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, updateVariableWeights, out t)) { 345 393 term = null; 346 394 return false; … … 357 405 foreach (var subTree in node.Subtrees) { 358 406 AutoDiff.Term t; 359 if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags,updateVariableWeights, out t)) {407 if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) { 360 408 term = null; 361 409 return false; … … 372 420 foreach (var subTree in node.Subtrees) { 373 421 AutoDiff.Term t; 374 if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags,updateVariableWeights, out t)) {422 if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) { 375 423 term = null; 376 424 return false; … … 384 432 if (node.Symbol is Logarithm) { 385 433 AutoDiff.Term t; 386 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {434 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 387 435 term = null; 388 436 return false; … … 394 442 if (node.Symbol is Exponential) { 395 443 AutoDiff.Term t; 396 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {444 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 397 445 term = null; 398 446 return false; … … 404 452 if (node.Symbol is Square) { 405 453 AutoDiff.Term t; 406 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {454 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 407 455 term = null; 408 456 return false; … … 414 462 if (node.Symbol is SquareRoot) { 415 463 AutoDiff.Term t; 416 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {464 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 417 465 term = null; 418 466 return false; … … 424 472 if (node.Symbol is Sine) { 425 473 AutoDiff.Term t; 426 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {474 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 427 475 term = null; 428 476 return false; … … 434 482 if (node.Symbol is Cosine) { 435 483 AutoDiff.Term t; 436 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {484 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 437 485 term = null; 438 486 return false; … … 444 492 if (node.Symbol is Tangent) { 445 493 AutoDiff.Term t; 446 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {494 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 447 495 term = null; 448 496 return false; … … 454 502 if (node.Symbol is Erf) { 455 503 AutoDiff.Term t; 456 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {504 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 457 505 term = null; 458 506 return false; … … 464 512 if (node.Symbol is Norm) { 465 513 AutoDiff.Term t; 466 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out t)) {514 if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) { 467 515 term = null; 468 516 return false; … … 478 526 variables.Add(alpha); 479 527 AutoDiff.Term branchTerm; 480 if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags,updateVariableWeights, out branchTerm)) {528 if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out branchTerm)) { 481 529 term = branchTerm * alpha + beta; 482 530 return true; … … 488 536 term = null; 489 537 return false; 538 } 539 540 // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination 541 // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available 542 private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters, 543 string varName, string varValue = "", int lag = 0) { 544 var data = new DataForVariable(varName, varValue, lag); 545 546 AutoDiff.Variable par = null; 547 if (!parameters.TryGetValue(data, out par)) { 548 // not found -> create new parameter and entries in names and values lists 549 par = new AutoDiff.Variable(); 550 parameters.Add(data, par); 551 } 552 return par; 490 553 } 491 554 … … 495 558 where 496 559 !(n.Symbol is Variable) && 560 !(n.Symbol is BinaryFactorVariable) && 561 !(n.Symbol is FactorVariable) && 497 562 !(n.Symbol is LaggedVariable) && 498 563 !(n.Symbol is Constant) && … … 515 580 return !containsUnknownSymbol; 516 581 } 582 583 584 #region helper class 585 private class DataForVariable { 586 public readonly string variableName; 587 public readonly string variableValue; // for factor vars 588 public readonly int lag; 589 590 public DataForVariable(string varName, string varValue, int lag) { 591 this.variableName = varName; 592 this.variableValue = varValue; 593 this.lag = lag; 594 } 595 596 public override bool Equals(object obj) { 597 var other = obj as DataForVariable; 598 if (other == null) return false; 599 return other.variableName.Equals(this.variableName) && 600 other.variableValue.Equals(this.variableValue) && 601 other.lag == this.lag; 602 } 603 604 public override int GetHashCode() { 605 return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag; 606 } 607 } 608 #endregion 517 609 } 518 610 }
Note: See TracChangeset
for help on using the changeset viewer.