Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization/Util.cs @ 16931

Last change on this file since 16931 was 16522, checked in by mkommend, 6 years ago

#2974: Fixed bug in constants extraction when adding linear scaling coefficients.

File size: 9.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization {
  /// <summary>
  /// Helper methods for constants optimization: extracting the variables, input data and numeric
  /// coefficients of a symbolic expression tree, and writing (optimized) coefficients back into it.
  /// </summary>
  public static class Util {
    /// <summary>
    /// Extracts all variable information in a symbolic expression tree. The variable information is necessary to convert a tree in an AutoDiff term.
    /// </summary>
    /// <remarks>
    /// Each distinct <see cref="VariableData"/> (according to its equality semantics) is returned once, in the order
    /// of its first occurrence in a prefix traversal of the tree. A factor variable node contributes one entry per value
    /// returned by its symbol's GetVariableValues; every other variable node contributes a single entry with an empty
    /// variable value. The lag is taken from nodes implementing <see cref="ILaggedTreeNode"/>, otherwise it is 0.
    /// </remarks>
    /// <param name="tree">The tree referencing the variables.</param>
    /// <returns>The data for variables occurring in the tree.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="tree"/> is null.</exception>
    public static List<VariableData> ExtractVariables(ISymbolicExpressionTree tree) {
      if (tree == null) throw new ArgumentNullException("tree");

      // The set filters duplicates; the list fixes a deterministic first-occurrence order
      // (the column order of ExtractData follows the order of the returned list).
      var seen = new HashSet<VariableData>();
      var variables = new List<VariableData>();
      foreach (var node in tree.IterateNodesPrefix().OfType<IVariableTreeNode>()) {
        string variableName = node.VariableName;
        int lag = 0;
        var laggedNode = node as ILaggedTreeNode;
        if (laggedNode != null) lag = laggedNode.Lag;

        var factorNode = node as FactorVariableTreeNode;
        if (factorNode != null) {
          foreach (var factorValue in factorNode.Symbol.GetVariableValues(variableName)) {
            var data = new VariableData(variableName, factorValue, lag);
            if (seen.Add(data)) variables.Add(data);
          }
        } else {
          // NOTE(review): if a non-factor node refers to a string-typed variable, ExtractData compares each
          // string value against this empty variable value — confirm that this is intended.
          var data = new VariableData(variableName, string.Empty, lag);
          if (seen.Add(data)) variables.Add(data);
        }
      }
      return variables;
    }

    /// <summary>
    /// Extracts the necessary data for constants optimization with AutoDiff.
    /// </summary>
    /// <remarks>
    /// Row i of the result corresponds to the i-th entry of <paramref name="rows"/>, column j to the j-th entry of
    /// <paramref name="variables"/>. The requested rows are shifted by each variable's lag. Double-valued variables
    /// are copied as-is; string-valued variables are encoded as 1.0 where the value equals the variable value and
    /// 0.0 otherwise.
    /// </remarks>
    /// <param name="dataset">The dataset holding the data.</param>
    /// <param name="variables">The variables for which the data from the dataset should be extracted.</param>
    /// <param name="rows">The rows for which the data should be extracted.</param>
    /// <returns>A two-dimensional double array containing the input data.</returns>
    /// <exception cref="ArgumentNullException">If any argument is null.</exception>
    /// <exception cref="NotSupportedException">If a variable is neither double- nor string-valued.</exception>
    public static double[,] ExtractData(IDataset dataset, IEnumerable<VariableData> variables, IEnumerable<int> rows) {
      if (dataset == null) throw new ArgumentNullException("dataset");
      if (variables == null) throw new ArgumentNullException("variables");
      if (rows == null) throw new ArgumentNullException("rows");

      // Materialize both sequences once: they are needed for the array dimensions and are then
      // enumerated again (rows once per variable), which must not re-evaluate a lazy query.
      int[] rowIndices = rows.ToArray();
      List<VariableData> variableList = variables.ToList();

      var x = new double[rowIndices.Length, variableList.Count];

      for (int col = 0; col < variableList.Count; col++) {
        var variable = variableList[col];

        // Apply the lag uniformly, independent of the variable's type.
        IEnumerable<int> sourceRows;
        if (variable.lag == 0) sourceRows = rowIndices;
        else sourceRows = rowIndices.Select(r => r + variable.lag);

        if (dataset.VariableHasType<double>(variable.variableName)) {
          int row = 0;
          foreach (var value in dataset.GetDoubleValues(variable.variableName, sourceRows)) {
            x[row, col] = value;
            row++;
          }
        } else if (dataset.VariableHasType<string>(variable.variableName)) {
          int row = 0;
          foreach (var value in dataset.GetStringValues(variable.variableName, sourceRows)) {
            // Indicator encoding of a single factor value.
            x[row, col] = value == variable.variableValue ? 1.0 : 0.0;
            row++;
          }
        } else throw new NotSupportedException("found a variable of unknown type");
      }

      return x;
    }

    /// <summary>
    /// Extracts all numeric nodes from a symbolic expression tree that can be optimized by the constants optimization.
    /// </summary>
    /// <remarks>
    /// Terminal nodes are visited in prefix order; constant nodes, variable nodes (<see cref="VariableTreeNodeBase"/>)
    /// and factor variable nodes are returned in that order.
    /// </remarks>
    /// <param name="tree">The tree from which the numeric nodes should be extracted.</param>
    /// <returns>A list containing all nodes with numeric coefficients.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="tree"/> is null.</exception>
    /// <exception cref="NotSupportedException">If the tree contains a terminal node of any other type.</exception>
    public static List<ISymbolicExpressionTreeNode> ExtractNumericNodes(ISymbolicExpressionTree tree) {
      if (tree == null) throw new ArgumentNullException("tree");

      var nodes = new List<ISymbolicExpressionTreeNode>();
      foreach (var node in tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        // Add the node itself (previously the factor-variable branch added a null reference).
        if (node is ConstantTreeNode || node is VariableTreeNodeBase || node is FactorVariableTreeNode)
          nodes.Add(node);
        else throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
      }
      return nodes;
    }

    /// <summary>
    /// Extracts all numeric constants from a symbolic expression tree.
    /// </summary>
    /// <param name="tree">The tree from which the numeric constants should be extracted.</param>
    /// <param name="addLinearScalingConstants">Flag to determine whether the linear scaling constants have to be added at the end:
    /// α * f(x) + β with α = 1.0 and β = 0.0.</param>
    /// <returns>An array containing the numeric constants of the tree's terminal nodes in prefix order.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="tree"/> is null.</exception>
    public static double[] ExtractConstants(ISymbolicExpressionTree tree, bool addLinearScalingConstants) {
      if (tree == null) throw new ArgumentNullException("tree");
      return ExtractConstants(tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>(), addLinearScalingConstants);
    }

    /// <summary>
    /// Extracts all numeric constants from a list of nodes.
    /// </summary>
    /// <remarks>
    /// A constant node contributes its value, a variable node its weight, and a factor variable node all of its
    /// weights (in index order).
    /// </remarks>
    /// <param name="nodes">The list of nodes for which the numeric constants should be extracted.</param>
    /// <param name="addLinearScalingConstants">Flag to determine whether the linear scaling constants have to be added at the end:
    /// α * f(x) + β with α = 1.0 and β = 0.0.</param>
    /// <returns>An array containing the numeric constants.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="nodes"/> is null.</exception>
    /// <exception cref="NotSupportedException">If a node of an unsupported type is encountered.</exception>
    public static double[] ExtractConstants(IEnumerable<ISymbolicExpressionTreeNode> nodes, bool addLinearScalingConstants) {
      if (nodes == null) throw new ArgumentNullException("nodes");

      var constants = new List<double>();
      foreach (var node in nodes) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null)
          constants.Add(constantTreeNode.Value);
        else if (variableTreeNodeBase != null)
          constants.Add(variableTreeNodeBase.Weight);
        else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            constants.Add(factorVarTreeNode.Weights[j]);
        } else throw new NotSupportedException(string.Format("Nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
      }
      if (addLinearScalingConstants) {
        constants.Add(1.0); // α (scale)
        constants.Add(0.0); // β (offset)
      }
      return constants.ToArray();
    }

    /// <summary>
    /// Sets the numeric constants of the nodes to the provided values.
    /// </summary>
    /// <remarks>
    /// Constants are consumed in the same order in which ExtractConstants produces them. Surplus constants
    /// (e.g. appended linear scaling constants) are ignored. Nodes are updated in place; if the constants run out,
    /// the nodes visited before that point have already been updated.
    /// </remarks>
    /// <param name="nodes">The nodes whose constants should be updated.</param>
    /// <param name="constants">The numeric constants which should be set.</param>
    /// <exception cref="ArgumentNullException">If an argument is null.</exception>
    /// <exception cref="ArgumentException">If <paramref name="constants"/> holds fewer values than the nodes require.</exception>
    /// <exception cref="NotSupportedException">If a node of an unsupported type is encountered.</exception>
    public static void UpdateConstants(IEnumerable<ISymbolicExpressionTreeNode> nodes, double[] constants) {
      if (nodes == null) throw new ArgumentNullException("nodes");
      if (constants == null) throw new ArgumentNullException("constants");

      int i = 0;
      foreach (var node in nodes) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = TakeConstant(constants, ref i);
        else if (variableTreeNodeBase != null)
          variableTreeNodeBase.Weight = TakeConstant(constants, ref i);
        else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = TakeConstant(constants, ref i);
        } else throw new NotSupportedException(string.Format("Terminal nodes of type {0} are not supported.", node.GetType().GetPrettyName()));
      }
    }

    /// <summary>
    /// Sets all numeric constants of the symbolic expression tree to the provided values.
    /// </summary>
    /// <param name="tree">The tree for which the numeric constants should be updated.</param>
    /// <param name="constants">The numeric constants which should be set, in the order produced by ExtractConstants.</param>
    /// <exception cref="ArgumentNullException">If an argument is null.</exception>
    public static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      if (tree == null) throw new ArgumentNullException("tree");
      if (constants == null) throw new ArgumentNullException("constants");
      UpdateConstants(tree.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>(), constants);
    }

    // Returns constants[index] and advances index; reports a descriptive ArgumentException instead of an
    // IndexOutOfRangeException when the nodes require more constants than were provided.
    private static double TakeConstant(double[] constants, ref int index) {
      if (index >= constants.Length)
        throw new ArgumentException("The number of provided constants is smaller than the number of constants in the nodes.", "constants");
      return constants[index++];
    }
  }
}
Note: See TracBrowser for help on using the repository browser.