Changeset 18197 for branches/3136_Structural_GP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/ParameterOptimization.cs
- Timestamp:
- 01/14/22 12:06:18 (2 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3136_Structural_GP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/ParameterOptimization.cs
r18192 r18197 27 27 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression { 28 28 public static class ParameterOptimization { 29 public static double OptimizeTreeParameters(IRegressionProblemData problemData, ISymbolicExpressionTree tree, 30 int maxIterations = 10, bool updateParametersInTree = true, bool updateVariableWeights = true,29 public static double OptimizeTreeParameters(IRegressionProblemData problemData, ISymbolicExpressionTree tree, int maxIterations = 10, 30 bool updateParametersInTree = true, bool updateVariableWeights = true, IEnumerable<ISymbolicExpressionTreeNode> excludeNodes = null, 31 31 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, 32 32 IEnumerable<int> rows = null, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter = null, … … 35 35 if (rows == null) rows = problemData.TrainingIndices; 36 36 if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeBatchInterpreter(); 37 if (excludeNodes == null) excludeNodes = Enumerable.Empty<ISymbolicExpressionTreeNode>(); 37 38 38 39 // Numeric parameters in the tree become variables for parameter optimization. … … 46 47 TreeToAutoDiffTermConverter.ParametricFunction func; 47 48 TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad; 48 if (!TreeToAutoDiffTermConverter.TryConvertToAutoDiff(tree, updateVariableWeights, addLinearScalingTerms: false, out parameters, out initialParameters, out func, out func_grad)) 49 if (!TreeToAutoDiffTermConverter.TryConvertToAutoDiff(tree, 50 updateVariableWeights, addLinearScalingTerms: false, excludeNodes, 51 out parameters, out initialParameters, out func, out func_grad)) 49 52 throw new NotSupportedException("Could not optimize parameters of symbolic expression tree due to not supported symbols used in the tree."); 50 53 var parameterEntries = parameters.ToArray(); // order of entries must be the same for x … … 112 115 // request was submitted. 113 116 if (rep.terminationtype > 0) { 114 UpdateParameters(tree, c, updateVariableWeights );117 UpdateParameters(tree, c, updateVariableWeights, excludeNodes); 115 118 } 116 119 var quality = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate( … … 119 122 lowerEstimationLimit, upperEstimationLimit); 120 123 121 if (!updateParametersInTree) UpdateParameters(tree, initialParameters, updateVariableWeights );124 if (!updateParametersInTree) UpdateParameters(tree, initialParameters, updateVariableWeights, excludeNodes); 122 125 123 126 if (originalQuality < quality || double.IsNaN(quality)) { 124 UpdateParameters(tree, initialParameters, updateVariableWeights );127 UpdateParameters(tree, initialParameters, updateVariableWeights, excludeNodes); 125 128 return originalQuality; 126 129 } … … 128 131 } 129 132 130 private static void UpdateParameters(ISymbolicExpressionTree tree, double[] parameters, bool updateVariableWeights) { 133 private static void UpdateParameters(ISymbolicExpressionTree tree, double[] parameters, 134 bool updateVariableWeights, IEnumerable<ISymbolicExpressionTreeNode> excludedNodes) { 131 135 int i = 0; 132 foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>() ) {136 foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().Except(excludedNodes)) { 133 137 NumberTreeNode numberTreeNode = node as NumberTreeNode; 134 138 VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase; … … 140 144 } else if (updateVariableWeights && variableTreeNodeBase != null) 141 145 variableTreeNodeBase.Weight = parameters[i++]; 142 else if ( factorVarTreeNode != null) {146 else if (updateVariableWeights && factorVarTreeNode != null) { 143 147 for (int j = 0; j < factorVarTreeNode.Weights.Length; j++) 144 148 factorVarTreeNode.Weights[j] = parameters[i++];
Note: See TracChangeset
for help on using the changeset viewer.