Changeset 16507 for branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization
- Timestamp:
- 01/06/19 18:03:15 (6 years ago)
- Location:
- branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization
- Files:
-
- 1 added
- 3 edited
- 1 copied
- 1 moved
Legend:
- Unmodified
- Added
- Removed
-
branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization/AutoDiffConverter.cs
r16501 r16507 27 27 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 28 28 29 namespace HeuristicLab.Problems.DataAnalysis.Symbolic { 30 public class TreeToAutoDiffTermConverter { 31 public delegate double ParametricFunction(double[] vars, double[] @params); 32 33 public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params); 34 35 #region helper class 36 public class DataForVariable { 37 public readonly string variableName; 38 public readonly string variableValue; // for factor vars 39 public readonly int lag; 40 41 public DataForVariable(string varName, string varValue, int lag) { 42 this.variableName = varName; 43 this.variableValue = varValue; 44 this.lag = lag; 45 } 46 47 public override bool Equals(object obj) { 48 var other = obj as DataForVariable; 49 if (other == null) return false; 50 return other.variableName.Equals(this.variableName) && 51 other.variableValue.Equals(this.variableValue) && 52 other.lag == this.lag; 53 } 54 55 public override int GetHashCode() { 56 return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag; 57 } 58 } 59 #endregion 60 61 #region derivations of functions 62 // create function factory for arctangent 63 private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory( 64 eval: Math.Atan, 65 diff: x => 1 / (1 + x * x)); 66 67 private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory( 68 eval: Math.Sin, 69 diff: Math.Cos); 70 71 private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory( 72 eval: Math.Cos, 73 diff: x => -Math.Sin(x)); 74 75 private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory( 76 eval: Math.Tan, 77 diff: x => 1 + Math.Tan(x) * Math.Tan(x)); 78 79 private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory( 80 eval: alglib.errorfunction, 81 diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI)); 82 83 private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory( 84 eval: alglib.normaldistribution, 85 diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI)); 86 87 private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory( 88 eval: Math.Abs, 89 diff: x => Math.Sign(x) 90 ); 91 92 #endregion 93 94 public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms, 95 out List<DataForVariable> parameters, out double[] initialConstants, 96 out ParametricFunction func, 97 out ParametricFunctionGradient func_grad) { 98 29 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization{ 30 public class AutoDiffConverter { 31 32 /// <summary> 33 /// Converts a symbolic expression tree into a parametetric AutoDiff term. 34 /// </summary> 35 /// <param name="tree">The tree the should be converted.</param> 36 /// <param name="addLinearScalingTerms">A flag that determines whether linear scaling terms should be added to the parametric term.</param> 37 /// <param name="numericNodes">The nodes that contain numeric coefficents that should be added as variables in the term.</param> 38 /// <param name="variableData">The variable information that is used to create parameters in the term.</param> 39 /// <param name="autoDiffTerm">The resulting parametric AutoDiff term.</param> 40 /// <returns>A flag to see if the conversion has succeeded.</returns> 41 public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool addLinearScalingTerms, 42 IEnumerable<ISymbolicExpressionTreeNode> numericNodes, IEnumerable<VariableData> variableData, 43 out IParametricCompiledTerm autoDiffTerm) { 99 44 // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree 100 var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);45 var transformator = new AutoDiffConverter(numericNodes, variableData); 101 46 AutoDiff.Term term; 102 try { 103 term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)); 104 105 if (addLinearScalingTerms) { 106 // scaling variables α, β are given at the beginning of the parameter vector 107 var alpha = new AutoDiff.Variable(); 108 var beta = new AutoDiff.Variable(); 109 transformator.variables.Insert(0, alpha); 110 transformator.variables.Insert(0, beta); 111 112 term = term * alpha + beta; 113 } 114 115 var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values 116 var compiledTerm = term.Compile(transformator.variables.ToArray(), 117 parameterEntries.Select(kvp => kvp.Value).ToArray()); 118 119 parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key)); 120 initialConstants = transformator.initialConstants.ToArray(); 121 func = (vars, @params) => compiledTerm.Evaluate(vars, @params); 122 func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params); 123 return true; 124 } catch (ConversionException) { 125 parameters = null; 126 initialConstants = null; 127 func = null; 128 func_grad = null; 129 } 130 return false; 131 } 132 133 public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool addLinearScalingTerms, IEnumerable<DataForVariable> variables, 134 out IParametricCompiledTerm autoDiffTerm, out double[] initialConstants) { 135 // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree 136 //TODO change ctor 137 var transformator = new TreeToAutoDiffTermConverter(true); 138 var parameters = new AutoDiff.Variable[variables.Count()]; 139 140 int i = 0; 141 foreach(var variable in variables) { 142 var autoDiffVar = new AutoDiff.Variable(); 143 transformator.parameters.Add(variable, autoDiffVar); 144 parameters[i] = autoDiffVar; 145 i++; 146 } 147 148 AutoDiff.Term term; 47 149 48 try { 150 49 term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)); … … 158 57 transformator.variables.Add(alpha); 159 58 transformator.variables.Add(beta); 160 161 transformator.initialConstants.Add(1.0); 162 transformator.initialConstants.Add(0.0); 163 } 164 165 var compiledTerm = term.Compile(transformator.variables.ToArray(), parameters); 59 } 60 var compiledTerm = term.Compile(transformator.variables.ToArray(), transformator.parameters.Values.ToArray()); 166 61 autoDiffTerm = compiledTerm; 167 initialConstants = transformator.initialConstants.ToArray();168 169 62 return true; 170 63 } catch (ConversionException) { 171 64 autoDiffTerm = null; 172 initialConstants = null;173 65 } 174 66 return false; … … 176 68 177 69 // state for recursive transformation of trees 178 private readonly List<double> initialConstants;179 private readonly Dictionary< DataForVariable, AutoDiff.Variable> parameters;70 private readonly HashSet<ISymbolicExpressionTreeNode> nodesForOptimization; 71 private readonly Dictionary<VariableData, AutoDiff.Variable> parameters; 180 72 private readonly List<AutoDiff.Variable> variables; 181 private readonly bool makeVariableWeightsVariable; 182 183 private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) { 184 this.makeVariableWeightsVariable = makeVariableWeightsVariable; 185 this.initialConstants = new List<double>(); 186 this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>(); 73 74 private AutoDiffConverter(IEnumerable<ISymbolicExpressionTreeNode> nodesForOptimization, IEnumerable<VariableData> variableData) { 75 this.nodesForOptimization = new HashSet<ISymbolicExpressionTreeNode>(nodesForOptimization); 76 this.parameters = variableData.ToDictionary(k => k, v => new AutoDiff.Variable()); 187 77 this.variables = new List<AutoDiff.Variable>(); 188 78 } … … 190 80 private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) { 191 81 if (node.Symbol is Constant) { 192 initialConstants.Add(((ConstantTreeNode)node).Value); 193 var var = new AutoDiff.Variable(); 194 variables.Add(var); 195 return var; 82 var constantNode = node as ConstantTreeNode; 83 var value = constantNode.Value; 84 if (nodesForOptimization.Contains(node)) { 85 AutoDiff.Variable var = new AutoDiff.Variable(); 86 variables.Add(var); 87 return var; 88 } else { 89 return value; 90 } 196 91 } 197 92 if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) { … … 200 95 // factor variable values are only 0 or 1 and set in x accordingly 201 96 var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty; 202 var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue); 203 204 if (makeVariableWeightsVariable) { 205 initialConstants.Add(varNode.Weight); 206 var w = new AutoDiff.Variable(); 207 variables.Add(w); 208 return AutoDiff.TermBuilder.Product(w, par); 97 var data = new VariableData(varNode.VariableName, varValue, 0); 98 var par = parameters[data]; 99 var value = varNode.Weight; 100 101 if (nodesForOptimization.Contains(node)) { 102 AutoDiff.Variable var = new AutoDiff.Variable(); 103 variables.Add(var); 104 return AutoDiff.TermBuilder.Product(var, par); 209 105 } else { 210 return varNode.Weight * par;106 return AutoDiff.TermBuilder.Product(value, par); 211 107 } 212 108 } … … 215 111 var products = new List<Term>(); 216 112 foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) { 217 var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue); 218 219 initialConstants.Add(factorVarNode.GetValue(variableValue)); 220 var wVar = new AutoDiff.Variable(); 221 variables.Add(wVar); 222 223 products.Add(AutoDiff.TermBuilder.Product(wVar, par)); 113 var data = new VariableData(factorVarNode.VariableName, variableValue, 0); 114 var par = parameters[data]; 115 var value = factorVarNode.GetValue(variableValue); 116 117 if (nodesForOptimization.Contains(node)) { 118 var wVar = new AutoDiff.Variable(); 119 variables.Add(wVar); 120 121 products.Add(AutoDiff.TermBuilder.Product(wVar, par)); 122 } else { 123 products.Add(AutoDiff.TermBuilder.Product(value, par)); 124 } 224 125 } 225 126 return AutoDiff.TermBuilder.Sum(products); … … 227 128 if (node.Symbol is LaggedVariable) { 228 129 var varNode = node as LaggedVariableTreeNode; 229 var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag); 230 231 if (makeVariableWeightsVariable) { 232 initialConstants.Add(varNode.Weight); 233 var w = new AutoDiff.Variable(); 234 variables.Add(w); 235 return AutoDiff.TermBuilder.Product(w, par); 130 var data = new VariableData(varNode.VariableName, string.Empty, varNode.Lag); 131 var par = parameters[data]; 132 var value = varNode.Weight; 133 134 if (nodesForOptimization.Contains(node)) { 135 AutoDiff.Variable var = new AutoDiff.Variable(); 136 variables.Add(var); 137 return AutoDiff.TermBuilder.Product(var, par); 236 138 } else { 237 return varNode.Weight * par; 238 } 139 return AutoDiff.TermBuilder.Product(value, par); 140 } 141 239 142 } 240 143 if (node.Symbol is Addition) { … … 330 233 } 331 234 332 333 // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination 334 // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available 335 private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters, 336 string varName, string varValue = "", int lag = 0) { 337 var data = new DataForVariable(varName, varValue, lag); 338 339 AutoDiff.Variable par = null; 340 if (!parameters.TryGetValue(data, out par)) { 341 // not found -> create new parameter and entries in names and values lists 342 par = new AutoDiff.Variable(); 343 parameters.Add(data, par); 344 } 345 return par; 346 } 235 #region derivations of functions 236 // create function factory for arctangent 237 private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory( 238 eval: Math.Atan, 239 diff: x => 1 / (1 + x * x)); 240 241 private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory( 242 eval: Math.Sin, 243 diff: Math.Cos); 244 245 private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory( 246 eval: Math.Cos, 247 diff: x => -Math.Sin(x)); 248 249 private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory( 250 eval: Math.Tan, 251 diff: x => 1 + Math.Tan(x) * Math.Tan(x)); 252 253 private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory( 254 eval: alglib.errorfunction, 255 diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI)); 256 257 private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory( 258 eval: alglib.normaldistribution, 259 diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI)); 260 261 private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory( 262 eval: Math.Abs, 263 diff: x => Math.Sign(x) 264 ); 265 266 #endregion 267 347 268 348 269 public static bool IsCompatible(ISymbolicExpressionTree tree) { … … 379 300 [Serializable] 380 301 public class ConversionException : Exception { 381 382 public ConversionException() { 383 } 384 385 public ConversionException(string message) : base(message) { 386 } 387 388 public ConversionException(string message, Exception inner) : base(message, inner) { 389 } 390 302 public ConversionException() { } 303 public ConversionException(string message) : base(message) { } 304 public ConversionException(string message, Exception inner) : base(message, inner) { } 391 305 protected ConversionException( 392 306 SerializationInfo info, -
branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization/IConstantsOptimizer.cs
r16500 r16507 23 23 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 24 24 25 namespace HeuristicLab.Problems.DataAnalysis.Symbolic {25 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization { 26 26 public interface IConstantsOptimizer { 27 27 bool ApplyLinearScaling { get; set; } -
branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization/LMConstantsOptimizer.cs
r16500 r16507 25 25 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 26 26 27 namespace HeuristicLab.Problems.DataAnalysis.Symbolic {27 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization { 28 28 public class LMConstantsOptimizer { 29 private bool ApplyLinearScaling { get; set; }30 private int MaximumIterations { get; set; }31 29 32 p ublic LMConstantsOptimizer() {30 private LMConstantsOptimizer() { } 33 31 32 /// <summary> 33 /// Method to determine whether the numeric constants of the tree can be optimized. This depends primarily on the symbols occuring in the tree. 34 /// </summary> 35 /// <param name="tree">The tree that should be analyzed</param> 36 /// <returns>A flag indicating whether the numeric constants of the tree can be optimized</returns> 37 public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) { 38 return AutoDiffConverter.IsCompatible(tree); 34 39 } 40 35 41 /// <summary> 36 /// 42 /// Optimizes the numeric constants in a symbolic expression tree in place. 37 43 /// </summary> 38 /// <param name="tree">The tree for which the constants are optimized.</param> 39 /// <param name="nodes">The nodes which should be adapted. The nodes must be the same (reference) as the ones used in the tree and the double values specify the inita</param> 40 /// <param name="x">The input data.</param> 41 /// <param name="y">The targer date.</param> 42 /// <param name="applyLinearScaling">Flag that determines whether linear scaling nodes should be added during the optimization.</param> 43 /// <returns> Fit of the symbolic expression tree in terms of ... </returns> 44 public static double OptimizeConstants(ISymbolicExpressionTree tree,IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations = 10) { 45 //if (tree == null) throw new ArgumentNullException("tree"); 46 //if (nodes == null) throw new ArgumentNullException("nodes"); 47 //if (initialConstants == null) throw new ArgumentNullException("intitialConstants"); 48 //if (problemData == null) throw new ArgumentNullException("problemData"); 44 /// <param name="tree">The tree for which the constants should be optimized</param> 45 /// <param name="dataset">The dataset containing the data.</param> 46 /// <param name="targetVariable">The target variable name.</param> 47 /// <param name="rows">The rows for which the data should be extracted.</param> 48 /// <param name="applyLinearScaling">A flag to determine whether linear scaling should be applied during the optimization</param> 49 /// <param name="maxIterations">The maximum number of iterations of the Levenberg-Marquard algorithm.</param> 50 /// <returns></returns> 51 public static double OptimizeConstants(ISymbolicExpressionTree tree, 52 IDataset dataset, string targetVariable, IEnumerable<int> rows, 53 bool applyLinearScaling, int maxIterations = 10) { 54 if (tree == null) throw new ArgumentNullException("tree"); 55 if (dataset == null) throw new ArgumentNullException("dataset"); 56 if (!dataset.ContainsVariable(targetVariable)) throw new ArgumentException("The dataset does not contain the provided target variable."); 49 57 50 //if (!nodes.Any()) return 0; 51 52 var ds = problemData.Dataset; 53 var variables = ConstantsOptimization.Util.GenerateVariables(ds); 54 var laggedVariables = ConstantsOptimization.Util.ExtractLaggedVariables(tree); 55 var allVariables = variables.Union(laggedVariables); 56 57 double[,] x = ConstantsOptimization.Util.ExtractData(ds, allVariables, rows); 58 double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray(); 59 60 double[] initialConstants; 58 var allVariables = Util.ExtractVariables(tree); 59 var numericNodes = Util.ExtractNumericNodes(tree); 61 60 62 61 AutoDiff.IParametricCompiledTerm term; 63 if (! TreeToAutoDiffTermConverter.TryConvertToAutoDiff(tree, applyLinearScaling, allVariables, out term, out initialConstants))64 throw new NotSupportedException("Could not optimize constants of symbolic expression treedue to not supported symbols used in the tree.");62 if (!AutoDiffConverter.TryConvertToAutoDiff(tree, applyLinearScaling, numericNodes, allVariables, out term)) 63 throw new NotSupportedException("Could not convert symbolic expression tree to an AutoDiff term due to not supported symbols used in the tree."); 65 64 66 var constants = (double[])initialConstants.Clone(); 65 //Variables of the symbolic expression tree correspond to parameters in the term 66 //Hence if no parameters are present no variables occur in the tree and the R² = 0 67 if (term.Parameters.Count == 0) return 0.0; 68 69 var initialConstants = Util.ExtractConstants(numericNodes, applyLinearScaling); 70 double[] constants; 71 double[,] x = Util.ExtractData(dataset, allVariables, rows); 72 double[] y = dataset.GetDoubleValues(targetVariable, rows).ToArray(); 73 74 var result = OptimizeConstants(term, initialConstants, x, y, maxIterations, out constants); 75 if (result != 0.0 && constants.Length != 0) 76 Util.UpdateConstants(numericNodes, constants); 77 78 return result; 79 } 80 81 /// <summary> 82 /// Optimizes the numeric coefficents of an AutoDiff Term using the Levenberg-Marquard algorithm. 83 /// </summary> 84 /// <param name="term">The AutoDiff term for which the numeric coefficients should be optimized.</param> 85 /// <param name="initialConstants">The starting values for the numeric coefficients.</param> 86 /// <param name="x">The input data for the optimization.</param> 87 /// <param name="y">The target values for the optimization.</param> 88 /// <param name="maxIterations">The maximum number of iterations of the Levenberg-Marquard</param> 89 /// <param name="constants">The opitmized constants.</param> 90 /// <param name="LM_IterationCallback">An optional callback for detailed analysis that is called in each algorithm iteration.</param> 91 /// <returns>The R² of the term evaluated on the input data x and the target data y using the optimized constants</returns> 92 public static double OptimizeConstants(AutoDiff.IParametricCompiledTerm term, double[] initialConstants, double[,] x, double[] y, 93 int maxIterations, out double[] constants, Action<double[], double, object> LM_IterationCallback = null) { 94 95 if (term.Parameters.Count == 0) { 96 constants = new double[0]; 97 return 0.0; 98 } 99 100 var optimizedConstants = (double[])initialConstants.Clone(); 101 int numberOfRows = x.GetLength(0); 102 int numberOfColumns = x.GetLength(1); 103 int numberOfConstants = optimizedConstants.Length; 104 67 105 alglib.lsfitstate state; 68 106 alglib.lsfitreport rep; 107 alglib.ndimensional_rep xrep = (p, f, obj) => LM_IterationCallback(p, f, obj); 69 108 int retVal; 70 109 71 int numberOfRows = x.GetLength(0);72 int numberOfColumns = x.GetLength(1);73 int numberOfConstants = constants.Length;74 75 110 try { 76 alglib.lsfitcreatefg(x, y, constants, numberOfRows, numberOfColumns, numberOfConstants, cheapfg: false, state: out state);111 alglib.lsfitcreatefg(x, y, optimizedConstants, numberOfRows, numberOfColumns, numberOfConstants, cheapfg: false, state: out state); 77 112 alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations); 78 //alglib.lsfitsetgradientcheck(state, 0.001);79 alglib.lsfitfit(state, Evaluate, EvaluateGradient, null, term);80 alglib.lsfitresults(state, out retVal, out constants, out rep);113 alglib.lsfitsetxrep(state, LM_IterationCallback != null); 114 alglib.lsfitfit(state, Evaluate, EvaluateGradient, xrep, term); 115 alglib.lsfitresults(state, out retVal, out optimizedConstants, out rep); 81 116 } catch (ArithmeticException) { 117 constants = new double[0]; 82 118 return double.NaN; 83 119 } catch (alglib.alglibexception) { 120 constants = new double[0]; 84 121 return double.NaN; 85 122 } 86 123 87 ConstantsOptimization.Util.UpdateConstants(tree, constants); 88 124 constants = optimizedConstants; 89 125 return rep.r2; 90 126 } 127 91 128 92 129 private static void Evaluate(double[] c, double[] x, ref double fx, object o) { … … 101 138 Array.Copy(result.Item1, grad, grad.Length); 102 139 } 103 104 public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {105 return TreeToAutoDiffTermConverter.IsCompatible(tree);106 }107 140 } 108 141 } -
branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization/Util.cs
r16500 r16507 25 25 using HeuristicLab.Common; 26 26 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 27 using static HeuristicLab.Problems.DataAnalysis.Symbolic.TreeToAutoDiffTermConverter;28 27 29 28 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization { 30 29 public static class Util { 31 32 public static double[,] ExtractData(IDataset dataset, IEnumerable<DataForVariable> variables, IEnumerable<int> rows) { 33 var x = new double[rows.Count(), variables.Count()]; 34 35 int row = 0; 36 foreach (var r in rows) { 37 int col = 0; 38 foreach (var variable in variables) { 39 if (dataset.VariableHasType<double>(variable.variableName)) { 40 x[row, col] = dataset.GetDoubleValue(variable.variableName, r + variable.lag); 41 } else if (dataset.VariableHasType<string>(variable.variableName)) { 42 x[row, col] = dataset.GetStringValue(variable.variableName, r) == variable.variableValue ? 1 : 0; 43 } else throw new InvalidProgramException("found a variable of unknown type"); 44 col++; 45 } 46 row++; 47 } 48 return x; 49 } 50 51 public static List<DataForVariable> GenerateVariables(IDataset dataset) { 52 var variables = new List<DataForVariable>(); 53 foreach (var doubleVariable in dataset.DoubleVariables) { 54 var data = new DataForVariable(doubleVariable, string.Empty, 0); 55 variables.Add(data); 56 } 57 58 foreach (var stringVariable in dataset.StringVariables) { 59 foreach (var stringValue in dataset.GetStringValues(stringVariable).Distinct()) { 60 var data = new DataForVariable(stringVariable, stringValue, 0); 30 /// <summary> 31 /// Extracts all variable information in a symbolic expression tree. The variable information is necessary to convert a tree in an AutoDiff term. 32 /// </summary> 33 /// <param name="tree">The tree referencing the variables.</param> 34 /// <returns>The data for variables occuring in the tree.</returns> 35 public static List<VariableData> ExtractVariables(ISymbolicExpressionTree tree) { 36 if (tree == null) throw new ArgumentNullException("tree"); 37 38 var variables = new HashSet<VariableData>(); 39 foreach (var node in tree.IterateNodesPrefix().OfType<IVariableTreeNode>()) { 40 string variableName = node.VariableName; 41 int lag = 0; 42 var laggedNode = node as ILaggedTreeNode; 43 if (laggedNode != null) lag = laggedNode.Lag; 44 45 46 var factorNode = node as FactorVariableTreeNode; 47 if (factorNode != null) { 48 foreach (var factorValue in factorNode.Symbol.GetVariableValues(variableName)) { 49 var data = new VariableData(variableName, factorValue, lag); 50 variables.Add(data); 51 } 52 } else { 53 var data = new VariableData(variableName, string.Empty, lag); 61 54 variables.Add(data); 62 55 } 63 56 } 64 return variables;65 }66 67 public static List<DataForVariable> ExtractLaggedVariables(ISymbolicExpressionTree tree) {68 var variables = new HashSet<DataForVariable>();69 foreach (var laggedNode in tree.IterateNodesPrefix().OfType<ILaggedTreeNode>()) {70 var laggedVariableTreeNode = laggedNode as LaggedVariableTreeNode;71 if (laggedVariableTreeNode != null) {72 var data = new DataForVariable(laggedVariableTreeNode.VariableName, string.Empty, laggedVariableTreeNode.Lag);73 if (!variables.Contains(data)) variables.Add(data);74 }75 }76 57 return variables.ToList(); 77 58 } 78 79 public static double[] ExtractConstants(ISymbolicExpressionTree tree) { 59 /// <summary> 60 /// Extract the necessary date for constants optimization with AutoDiff 61 /// </summary> 62 /// <param name="dataset">The dataset holding the data.</param> 63 /// <param name="variables">The variables for which the data from the dataset should be extracted.</param> 64 /// <param name="rows">The rows for which the data should be extracted.</param> 65 /// <returns>A two-dimensiona double array containing the input data.</returns> 66 public static double[,] ExtractData(IDataset dataset, IEnumerable<VariableData> variables, IEnumerable<int> rows) { 67 if (dataset == null) throw new ArgumentNullException("dataset"); 68 if (variables == null) throw new ArgumentNullException("variables"); 69 if (rows == null) throw new ArgumentNullException("rows"); 70 71 var x = new double[rows.Count(), variables.Count()]; 72 73 int col = 0; 74 foreach (var variable in variables) { 75 if (dataset.VariableHasType<double>(variable.variableName)) { 76 IEnumerable<double> values; 77 if (variable.lag == 0) 78 values = dataset.GetDoubleValues(variable.variableName, rows); 79 else 80 values = dataset.GetDoubleValues(variable.variableName, rows.Select(r => r + variable.lag)); 81 82 int row = 0; 83 foreach (var value in values) { 84 x[row, col] = value; 85 row++; 86 } 87 } else if (dataset.VariableHasType<string>(variable.variableName)) { 88 var values = dataset.GetStringValues(variable.variableName, rows); 89 90 int row = 0; 91 foreach (var value in values) { 92 x[row, col] = value == variable.variableValue ? 1 : 0; ; 93 row++; 94 } 95 } else throw new NotSupportedException("found a variable of unknown type"); 96 col++; 97 } 98 99 return x; 100 } 101 102 /// <summary> 103 /// Extracts all numeric nodes from a symbolic expression tree that can be optimized by the constants optimization 104 /// </summary> 105 /// <param name="tree">The tree from which the numeric nodes should be extracted.</param> 106 /// <returns>A list containing all nodes with numeric coefficients.</returns> 107 public static List<ISymbolicExpressionTreeNode> ExtractNumericNodes(ISymbolicExpressionTree tree) { 108 if (tree == null) throw new ArgumentNullException("tree"); 109 110 var nodes = new List<ISymbolicExpressionTreeNode>(); 111 foreach (var node in tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) { 112 ConstantTreeNode constantTreeNode = node as ConstantTreeNode; 113 VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase; 114 FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode; 115 if (constantTreeNode != null) nodes.Add(constantTreeNode); 116 else if (variableTreeNodeBase != null) nodes.Add(variableTreeNodeBase); 117 else if (factorVarTreeNode != null) nodes.Add(variableTreeNodeBase); 118 else throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName())); 119 } 120 return nodes; 121 } 122 123 /// <summary> 124 /// Extracts all numeric constants from a symbolic expression tree. 125 /// </summary> 126 /// <param name="tree">The tree from which the numeric constants should be extracted.</param> 127 /// <param name="addLinearScalingConstants">Flag to determine whether constants for linear scaling have to be added at the end. 128 /// α *f(x) + β, α = 1.0, β = 0.0 </param> 129 /// <returns> An array containing the numeric constants.</returns> 130 public static double[] ExtractConstants(ISymbolicExpressionTree tree, bool addLinearScalingConstants) { 131 if (tree == null) throw new ArgumentNullException("tree"); 132 return ExtractConstants(tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>(), addLinearScalingConstants); 133 } 134 135 /// <summary> 136 /// Extracts all numeric constants from a list of nodes. 137 /// </summary> 138 /// <param name="nodes">The list of nodes for which the numeric constants should be extracted.</param> 139 /// <param name="addLinearScalingConstants">Flag to determine whether constants for linear scaling have to be added at the end. 140 /// α *f(x) + β, α = 1.0, β = 0.0 </param> 141 /// <returns> An array containing the numeric constants.</returns> 142 public static double[] ExtractConstants(IEnumerable<ISymbolicExpressionTreeNode> nodes, bool addLinearScalingConstants) { 143 if (nodes == null) throw new ArgumentNullException("nodes"); 144 80 145 var constants = new List<double>(); 81 foreach (var node in tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {146 foreach (var node in nodes) { 82 147 ConstantTreeNode constantTreeNode = node as ConstantTreeNode; 83 148 VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase; … … 90 155 for (int j = 0; j < factorVarTreeNode.Weights.Length; j++) 91 156 constants.Add(factorVarTreeNode.Weights[j]); 92 } else throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName())); 93 } 157 } else throw new NotSupportedException(string.Format("Nodes of type {0} are not supported.", node.GetType().GetPrettyName())); 158 } 159 constants.Add(1.0); 160 constants.Add(0.0); 94 161 return constants.ToArray(); 95 162 } 96 163 97 public static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) { 164 /// <summary> 165 /// Sets the numeric constants of the nodes to the provided values. 166 /// </summary> 167 /// <param name="nodes">The nodes whose constants should be updated.</param> 168 /// <param name="constants">The numeric constants which should be set. </param> 169 public static void UpdateConstants(IEnumerable<ISymbolicExpressionTreeNode> nodes, double[] constants) { 170 if (nodes == null) throw new ArgumentNullException("nodes"); 171 if (constants == null) throw new ArgumentNullException("constants"); 172 98 173 int i = 0; 99 foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {174 foreach (var node in nodes) { 100 175 ConstantTreeNode constantTreeNode = node as ConstantTreeNode; 101 176 VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase; … … 111 186 } 112 187 } 188 189 /// <summary> 190 /// Sets all numeric constants of the symbolic expression tree to the provided values. 191 /// </summary> 192 /// <param name="tree">The tree for which the numeric constants should be updated.</param> 193 /// <param name="constants">The numeric constants which should be set.</param> 194 public static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) { 195 if (tree == null) throw new ArgumentNullException("tree"); 196 if (constants == null) throw new ArgumentNullException("constants"); 197 UpdateConstants(tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>(), constants); 198 } 113 199 } 114 200 }
Note: See TracChangeset
for help on using the changeset viewer.