#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using static HeuristicLab.Problems.DataAnalysis.Symbolic.TreeToAutoDiffTermConverter;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization {
  /// <summary>
  /// Helpers for constants optimization: building an input matrix from a dataset and
  /// reading/writing the numeric parameters (constants and weights) of a symbolic expression tree.
  /// </summary>
  public static class Util {
    /// <summary>
    /// Builds an input matrix <c>x</c> where <c>x[i, j]</c> holds the value of the j-th entry of
    /// <paramref name="variables"/> at the i-th entry of <paramref name="rows"/>.
    /// </summary>
    /// <remarks>
    /// Double-typed variables contribute the dataset value at row <c>r + lag</c>.
    /// String-typed variables contribute an indicator: 1 if the value at row <c>r</c> equals
    /// <c>variableValue</c>, otherwise 0 (the lag is not applied to string variables).
    /// </remarks>
    /// <param name="dataset">Dataset to read values from.</param>
    /// <param name="variables">Column descriptors; each sequence is enumerated exactly once.</param>
    /// <param name="rows">Row indices; each sequence is enumerated exactly once.</param>
    /// <returns>A <c>rows × variables</c> matrix.</returns>
    /// <exception cref="ArgumentNullException">If any argument is null.</exception>
    /// <exception cref="InvalidProgramException">
    /// If at least one row is requested and a variable is neither double- nor string-typed.
    /// </exception>
    public static double[,] ExtractData(IDataset dataset, IEnumerable<DataForVariable> variables, IEnumerable<int> rows) {
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));
      if (variables == null) throw new ArgumentNullException(nameof(variables));
      if (rows == null) throw new ArgumentNullException(nameof(rows));

      // Materialize both sequences once: previously each was counted and then enumerated again
      // (variables once per row), which re-evaluates lazy LINQ queries on every pass.
      var variableList = variables.ToList();
      var rowList = rows.ToList();
      var x = new double[rowList.Count, variableList.Count];
      // With no rows there is nothing to fill and, as before, no type checks are performed.
      if (rowList.Count == 0) return x;

      // A column's storage type does not depend on the row, so resolve it once per column.
      var isDoubleColumn = new bool[variableList.Count];
      for (int col = 0; col < variableList.Count; col++) {
        var name = variableList[col].variableName;
        if (dataset.VariableHasType<double>(name)) {
          isDoubleColumn[col] = true;
        } else if (!dataset.VariableHasType<string>(name)) {
          // NOTE(review): InvalidProgramException is kept for caller compatibility even though it is
          // normally reserved for invalid IL; NotSupportedException would be more conventional.
          throw new InvalidProgramException("found a variable of unknown type");
        }
      }

      for (int row = 0; row < rowList.Count; row++) {
        int r = rowList[row];
        for (int col = 0; col < variableList.Count; col++) {
          var variable = variableList[col];
          if (isDoubleColumn[col]) {
            x[row, col] = dataset.GetDoubleValue(variable.variableName, r + variable.lag);
          } else {
            // Indicator encoding of a string (factor) variable.
            // NOTE(review): unlike the double branch, the lag is not applied here — confirm intended.
            x[row, col] = dataset.GetStringValue(variable.variableName, r) == variable.variableValue ? 1 : 0;
          }
        }
      }
      return x;
    }

    /// <summary>
    /// Lists one descriptor (lag 0, empty value) per double variable of <paramref name="dataset"/>,
    /// followed by one descriptor (lag 0) per distinct value of each string variable.
    /// </summary>
    /// <param name="dataset">Dataset whose variables are enumerated.</param>
    /// <returns>The variable descriptors in the order described above.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="dataset"/> is null.</exception>
    public static List<DataForVariable> GenerateVariables(IDataset dataset) {
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));

      var variables = new List<DataForVariable>();
      foreach (var doubleVariable in dataset.DoubleVariables) {
        variables.Add(new DataForVariable(doubleVariable, string.Empty, 0));
      }

      foreach (var stringVariable in dataset.StringVariables) {
        foreach (var stringValue in dataset.GetStringValues(stringVariable).Distinct()) {
          variables.Add(new DataForVariable(stringVariable, stringValue, 0));
        }
      }
      return variables;
    }

    /// <summary>
    /// Collects a descriptor (variable name, empty value, lag) for every lagged-variable node of
    /// <paramref name="tree"/>, without duplicates, in order of first occurrence in a prefix walk.
    /// </summary>
    /// <param name="tree">Tree to scan.</param>
    /// <returns>The distinct lagged-variable descriptors.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="tree"/> is null.</exception>
    public static List<DataForVariable> ExtractLaggedVariables(ISymbolicExpressionTree tree) {
      if (tree == null) throw new ArgumentNullException(nameof(tree));

      // The list fixes a deterministic (first-occurrence) order, which enumerating a HashSet does not
      // guarantee; the set only filters duplicates. Duplicate detection relies on DataForVariable's
      // Equals/GetHashCode, which are defined outside this file.
      var variables = new List<DataForVariable>();
      var seen = new HashSet<DataForVariable>();
      foreach (var laggedNode in tree.IterateNodesPrefix().OfType<LaggedVariableTreeNode>()) {
        var data = new DataForVariable(laggedNode.VariableName, string.Empty, laggedNode.Lag);
        if (seen.Add(data)) variables.Add(data);
      }
      return variables;
    }

    /// <summary>
    /// Reads the numeric parameters of all terminal nodes in prefix order: a constant node's value,
    /// a variable node's weight, or every weight of a factor-variable node.
    /// <see cref="UpdateConstants"/> writes values back in exactly this order.
    /// </summary>
    /// <param name="tree">Tree to read from.</param>
    /// <returns>The parameters in prefix order.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="tree"/> is null.</exception>
    /// <exception cref="NotSupportedException">If a terminal node of another type is encountered.</exception>
    public static double[] ExtractConstants(ISymbolicExpressionTree tree) {
      if (tree == null) throw new ArgumentNullException(nameof(tree));

      var constants = new List<double>();
      foreach (var node in tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        var constantTreeNode = node as ConstantTreeNode;
        var variableTreeNodeBase = node as VariableTreeNodeBase;
        var factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null) {
          constants.Add(constantTreeNode.Value);
        } else if (variableTreeNodeBase != null) {
          constants.Add(variableTreeNodeBase.Weight);
        } else if (factorVarTreeNode != null) {
          constants.AddRange(factorVarTreeNode.Weights);
        } else {
          throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
        }
      }
      return constants.ToArray();
    }

    /// <summary>
    /// Writes <paramref name="constants"/> back into the terminal nodes of <paramref name="tree"/>,
    /// consuming values in the same prefix order that <see cref="ExtractConstants"/> produces.
    /// </summary>
    /// <remarks>
    /// Values beyond those consumed are ignored. Nodes are updated in place as they are visited, so an
    /// exception part-way through leaves the tree partially updated.
    /// </remarks>
    /// <param name="tree">Tree to update in place.</param>
    /// <param name="constants">New values, ordered as returned by <see cref="ExtractConstants"/>.</param>
    /// <exception cref="ArgumentNullException">If an argument is null.</exception>
    /// <exception cref="IndexOutOfRangeException">If <paramref name="constants"/> holds too few values.</exception>
    /// <exception cref="NotSupportedException">If a terminal node of another type is encountered.</exception>
    public static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      if (tree == null) throw new ArgumentNullException(nameof(tree));
      if (constants == null) throw new ArgumentNullException(nameof(constants));

      int i = 0;
      // Walk the same sequence as ExtractConstants (tree.IterateNodesPrefix) so both stay in lockstep.
      foreach (var node in tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        var constantTreeNode = node as ConstantTreeNode;
        var variableTreeNodeBase = node as VariableTreeNodeBase;
        var factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null) {
          constantTreeNode.Value = constants[i++];
        } else if (variableTreeNodeBase != null) {
          variableTreeNodeBase.Weight = constants[i++];
        } else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = constants[i++];
        } else {
          throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
        }
      }
    }
  }
}