#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization {
  /// <summary>
  /// Helper methods for constants optimization: extraction of variables, input data,
  /// numeric nodes and numeric constants from symbolic expression trees, and
  /// writing constants back into the nodes.
  /// </summary>
  public static class Util {
    /// <summary>
    /// Extracts all variable information in a symbolic expression tree.
    /// The variable information is necessary to convert a tree into an AutoDiff term.
    /// </summary>
    /// <param name="tree">The tree referencing the variables.</param>
    /// <returns>
    /// The distinct variable data occurring in the tree. Factor variable nodes yield
    /// one entry per factor value; all other variable nodes yield a single entry with
    /// an empty variable value.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    public static List<VariableData> ExtractVariables(ISymbolicExpressionTree tree) {
      if (tree == null) throw new ArgumentNullException("tree");

      // A set is used so that a variable referenced by several nodes is reported only once.
      var variables = new HashSet<VariableData>();
      foreach (var node in tree.IterateNodesPrefix().OfType<IVariableTreeNode>()) {
        string variableName = node.VariableName;

        // Nodes without lag information are treated as lag 0.
        int lag = 0;
        var laggedNode = node as ILaggedTreeNode;
        if (laggedNode != null) lag = laggedNode.Lag;

        var factorNode = node as FactorVariableTreeNode;
        if (factorNode != null) {
          foreach (var factorValue in factorNode.Symbol.GetVariableValues(variableName)) {
            variables.Add(new VariableData(variableName, factorValue, lag));
          }
        } else {
          variables.Add(new VariableData(variableName, string.Empty, lag));
        }
      }
      return variables.ToList();
    }

    /// <summary>
    /// Extracts the data necessary for constants optimization with AutoDiff.
    /// </summary>
    /// <param name="dataset">The dataset holding the data.</param>
    /// <param name="variables">The variables for which the data from the dataset should be extracted.</param>
    /// <param name="rows">The rows for which the data should be extracted.</param>
    /// <returns>
    /// A two-dimensional double array containing the input data, with one row per entry
    /// of <paramref name="rows"/> and one column per entry of <paramref name="variables"/>.
    /// Double variables are copied (row indices shifted by the variable's lag); string
    /// variables are encoded as 1 where the dataset value equals the variable value and 0 otherwise.
    /// </returns>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    /// <exception cref="NotSupportedException">A variable is neither of type double nor of type string.</exception>
    public static double[,] ExtractData(IDataset dataset, IEnumerable<VariableData> variables, IEnumerable<int> rows) {
      if (dataset == null) throw new ArgumentNullException("dataset");
      if (variables == null) throw new ArgumentNullException("variables");
      if (rows == null) throw new ArgumentNullException("rows");

      // Materialize both sequences once: they are needed for sizing the matrix and
      // again for filling it, and lazily evaluated inputs must not be enumerated repeatedly.
      var variableList = variables.ToList();
      var rowList = rows.ToList();

      var x = new double[rowList.Count, variableList.Count];

      int col = 0;
      foreach (var variable in variableList) {
        if (dataset.VariableHasType<double>(variable.variableName)) {
          IEnumerable<double> values;
          if (variable.lag == 0)
            values = dataset.GetDoubleValues(variable.variableName, rowList);
          else
            values = dataset.GetDoubleValues(variable.variableName, rowList.Select(r => r + variable.lag));

          int row = 0;
          foreach (var value in values) {
            x[row, col] = value;
            row++;
          }
        } else if (dataset.VariableHasType<string>(variable.variableName)) {
          var values = dataset.GetStringValues(variable.variableName, rowList);

          // Indicator encoding: 1 if the row holds this factor value, 0 otherwise.
          int row = 0;
          foreach (var value in values) {
            x[row, col] = value == variable.variableValue ? 1 : 0;
            row++;
          }
        } else {
          throw new NotSupportedException("found a variable of unknown type");
        }
        col++;
      }

      return x;
    }

    /// <summary>
    /// Extracts all numeric nodes from a symbolic expression tree that can be optimized by the constants optimization.
    /// </summary>
    /// <param name="tree">The tree from which the numeric nodes should be extracted.</param>
    /// <returns>A list containing all nodes with numeric coefficients, in prefix order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    /// <exception cref="NotSupportedException">The tree contains a terminal node of an unsupported type.</exception>
    public static List<ISymbolicExpressionTreeNode> ExtractNumericNodes(ISymbolicExpressionTree tree) {
      if (tree == null) throw new ArgumentNullException("tree");

      var nodes = new List<ISymbolicExpressionTreeNode>();
      foreach (var node in tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;

        if (constantTreeNode != null) nodes.Add(constantTreeNode);
        else if (variableTreeNodeBase != null) nodes.Add(variableTreeNodeBase);
        // Add the factor node itself; variableTreeNodeBase is null in this branch,
        // so adding it would insert a null entry into the result.
        else if (factorVarTreeNode != null) nodes.Add(factorVarTreeNode);
        else throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
      }
      return nodes;
    }

    /// <summary>
    /// Extracts all numeric constants from a symbolic expression tree.
    /// </summary>
    /// <param name="tree">The tree from which the numeric constants should be extracted.</param>
    /// <param name="addLinearScalingConstants">Flag to determine whether constants for linear scaling have to be added at the end.</param>
    /// <remarks>α * f(x) + β, α = 1.0, β = 0.0</remarks>
    /// <returns>An array containing the numeric constants.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    public static double[] ExtractConstants(ISymbolicExpressionTree tree, bool addLinearScalingConstants) {
      if (tree == null) throw new ArgumentNullException("tree");
      return ExtractConstants(tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>(), addLinearScalingConstants);
    }

    /// <summary>
    /// Extracts all numeric constants from a list of nodes.
    /// </summary>
    /// <param name="nodes">The list of nodes for which the numeric constants should be extracted.</param>
    /// <param name="addLinearScalingConstants">Flag to determine whether constants for linear scaling have to be added at the end.</param>
    /// <remarks>α * f(x) + β, α = 1.0, β = 0.0</remarks>
    /// <returns>
    /// An array containing the numeric constants: the value of each constant node, the weight
    /// of each variable node and all weights of each factor variable node, in node order,
    /// optionally followed by 1.0 and 0.0 for linear scaling.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="nodes"/> is null.</exception>
    /// <exception cref="NotSupportedException">A node of an unsupported type is encountered.</exception>
    public static double[] ExtractConstants(IEnumerable<ISymbolicExpressionTreeNode> nodes, bool addLinearScalingConstants) {
      if (nodes == null) throw new ArgumentNullException("nodes");

      var constants = new List<double>();
      foreach (var node in nodes) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;

        if (constantTreeNode != null) {
          constants.Add(constantTreeNode.Value);
        } else if (variableTreeNodeBase != null) {
          constants.Add(variableTreeNodeBase.Weight);
        } else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            constants.Add(factorVarTreeNode.Weights[j]);
        } else {
          throw new NotSupportedException(string.Format("Nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
        }
      }

      if (addLinearScalingConstants) {
        constants.Add(1.0); // α
        constants.Add(0.0); // β
      }
      return constants.ToArray();
    }

    /// <summary>
    /// Sets the numeric constants of the nodes to the provided values.
    /// </summary>
    /// <param name="nodes">The nodes whose constants should be updated.</param>
    /// <param name="constants">
    /// The numeric constants which should be set. Values are consumed in node order: one per
    /// constant or variable node and one per weight of a factor variable node. Entries beyond
    /// those consumed by the nodes are not read.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="nodes"/> or <paramref name="constants"/> is null.</exception>
    /// <exception cref="NotSupportedException">A node of an unsupported type is encountered.</exception>
    public static void UpdateConstants(IEnumerable<ISymbolicExpressionTreeNode> nodes, double[] constants) {
      if (nodes == null) throw new ArgumentNullException("nodes");
      if (constants == null) throw new ArgumentNullException("constants");

      // Index of the next unconsumed entry in constants.
      int i = 0;
      foreach (var node in nodes) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;

        if (constantTreeNode != null) {
          constantTreeNode.Value = constants[i++];
        } else if (variableTreeNodeBase != null) {
          variableTreeNodeBase.Weight = constants[i++];
        } else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = constants[i++];
        } else {
          throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
        }
      }
    }

    /// <summary>
    /// Sets all numeric constants of the symbolic expression tree to the provided values.
    /// </summary>
    /// <param name="tree">The tree for which the numeric constants should be updated.</param>
    /// <param name="constants">The numeric constants which should be set.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> or <paramref name="constants"/> is null.</exception>
    public static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      if (tree == null) throw new ArgumentNullException("tree");
      if (constants == null) throw new ArgumentNullException("constants");
      UpdateConstants(tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>(), constants);
    }
  }
}