Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 3733

Last change on this file since 3733 was 3494, checked in by gkronber, 14 years ago

Changed SymbolicSimplifier to remove StartSymbols. #938 (Data types and operators for regression problems)

File size: 15.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
28using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
29using System.Diagnostics;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  /// <summary>
33  /// Simplistic simplifier for arithmetic expressions
34  /// Rules:
35  ///  * Constants are always the last argument to any function
36  ///  * f(c1, c2) => c3 (constant expression folding)
37  ///  * c1 / ( c2 * Var) => ( var * ( c2 / c1))
38  /// </summary>
39  public class SymbolicSimplifier {
// Symbol instances used as factories (via CreateTreeNode) for the nodes this simplifier builds.
// They are assigned once here and never replaced, hence readonly.
private readonly Addition addSymbol = new Addition();
private readonly Multiplication mulSymbol = new Multiplication();
private readonly Division divSymbol = new Division();
private readonly Constant constSymbol = new Constant();
private readonly Variable varSymbol = new Variable();
45
/// <summary>
/// Macro-expands and simplifies the given tree. All work is done on a clone of the tree's root.
/// </summary>
/// <param name="originalTree">Tree to simplify; its root must have at least one sub-tree.</param>
/// <returns>A new tree whose root is the simplified, macro-expanded expression.</returns>
public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
  SymbolicExpressionTreeNode rootCopy = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
  SymbolicExpressionTreeNode entryPoint = rootCopy.SubTrees[0];
  // no invocation is active at the top level, so the list of argument trees starts out empty
  IList<SymbolicExpressionTreeNode> noArguments = new List<SymbolicExpressionTreeNode>();
  SymbolicExpressionTreeNode expanded = MacroExpand(rootCopy, entryPoint, noArguments);
  SymbolicExpressionTreeNode simplified = GetSimplifiedTree(expanded);
  return new SymbolicExpressionTree(simplified);
}
52
/// <summary>
/// Recursively replaces function invocations by the body of the invoked function and
/// argument symbols by the corresponding (already expanded) argument tree.
/// Start symbols are skipped: the expansion of their first sub-tree is returned instead.
/// </summary>
/// <param name="root">Node whose sub-trees are searched for function definitions.</param>
/// <param name="node">Node to expand. Its sub-tree list is emptied and, for ordinary symbols, refilled with the expanded sub-trees.</param>
/// <param name="argumentTrees">Already expanded trees used as arguments for the invocation currently being expanded, indexed by argument index.</param>
/// <returns>The root node of the expanded tree.</returns>
private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
  // detach all sub-trees; they are re-attached in expanded form only where they are needed
  List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
  while (node.SubTrees.Count > 0) node.SubTrees.RemoveAt(0);

  var invokeSym = node.Symbol as InvokeFunction;
  if (invokeSym != null) {
    var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
    var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
    foreach (var subtree in subtrees) {
      macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
    }
    // Expand a copy of the function body. Expanding the definition itself would strip its
    // sub-trees and bake this call's arguments into it, so a second invocation of the same
    // function would no longer see the argument symbols and nodes would be shared between call sites.
    var defunBodyCopy = (SymbolicExpressionTreeNode)defunNode.Clone();
    return MacroExpand(root, defunBodyCopy, macroExpandedArguments);
  }

  var argSym = node.Symbol as Argument;
  if (argSym != null) {
    // return a copy of the matching argument tree (already macro-expanded)
    return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
  }

  if (node.Symbol is StartSymbol) {
    // drop the start symbol and continue with its first sub-tree
    return MacroExpand(root, subtrees[0], argumentTrees);
  }

  // any other symbol: keep the node and expand each of its sub-trees
  foreach (var subtree in subtrees) {
    node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
  }
  return node;
}
79
/// <summary>
/// Looks up the function definition named <paramref name="functionName"/> among the direct sub-trees of <paramref name="root"/>.
/// </summary>
/// <param name="root">Node whose direct sub-trees are searched.</param>
/// <param name="functionName">Name of the function to find.</param>
/// <returns>The first sub-tree of the first matching definition node.</returns>
/// <exception cref="ArgumentException">No definition with the given name exists.</exception>
private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
  DefunTreeNode definition = root.SubTrees
    .OfType<DefunTreeNode>()
    .FirstOrDefault(defun => defun.FunctionName == functionName);

  if (definition == null) {
    throw new ArgumentException("Definition of function " + functionName + " not found.");
  }
  return definition.SubTrees[0];
}
87
/// <summary>
/// Creates a new simplified tree for the expression rooted at <paramref name="original"/>.
/// Constants and variables are cloned; additions, subtractions, multiplications and divisions
/// are rebuilt from their simplified operands; for any other symbol only the sub-trees are simplified.
/// </summary>
/// <param name="original">Root of the expression to simplify.</param>
/// <returns>The root of the simplified expression.</returns>
public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
  if (IsConstant(original) || IsVariable(original)) {
    return (SymbolicExpressionTreeNode)original.Clone();
  } else if (IsAddition(original)) {
    if (original.SubTrees.Count == 1) {
      return GetSimplifiedTree(original.SubTrees[0]);
    } else {
      // simplify expressions x0..xn
      // make addition (x0..xn)
      Trace.Assert(original.SubTrees.Count > 1);
      return original.SubTrees
        .Select(x => GetSimplifiedTree(x))
        .Aggregate((a, b) => MakeAddition(a, b));
    }
  } else if (IsSubtraction(original)) {
    if (original.SubTrees.Count == 1) {
      return Negate(GetSimplifiedTree(original.SubTrees[0]));
    } else {
      // simplify expressions x0..xn
      // make addition (x0,-x1..-xn)
      Trace.Assert(original.SubTrees.Count > 1);
      // materialize once: the sequence is consumed twice below, and a deferred query would
      // re-run the recursive simplification on every enumeration
      var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
      return simplifiedTrees.Take(1)
        .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
        .Aggregate((a, b) => MakeAddition(a, b));
    }
  } else if (IsMultiplication(original)) {
    if (original.SubTrees.Count == 1) {
      return GetSimplifiedTree(original.SubTrees[0]);
    } else {
      Trace.Assert(original.SubTrees.Count > 1);
      return original.SubTrees
        .Select(x => GetSimplifiedTree(x))
        .Aggregate((a, b) => MakeMultiplication(a, b));
    }
  } else if (IsDivision(original)) {
    if (original.SubTrees.Count == 1) {
      return Invert(GetSimplifiedTree(original.SubTrees[0]));
    } else {
      // simplify expressions x0..xn
      // make multiplication (x0 * 1/(x1 * x2 * .. * xn))
      Trace.Assert(original.SubTrees.Count > 1);
      // materialize once (see subtraction above): both the dividend and the divisors are read from it
      var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
      var dividend = simplifiedTrees[0];
      var divisor = simplifiedTrees.Skip(1).Aggregate((a, b) => MakeMultiplication(a, b));
      return MakeMultiplication(dividend, Invert(divisor));
    }
  } else {
    // can't simplify this function but simplify all subtrees
    // TODO evaluate the function if original is a constant expression
    List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
    // temporarily detach the sub-trees so that Clone() copies only the node itself
    while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
    var clone = (SymbolicExpressionTreeNode)original.Clone();
    // restore the original node completely before recursing, so that an exception raised while
    // simplifying a sub-tree cannot leave the input partially dismantled
    foreach (var subTree in subTrees) {
      original.AddSubTree(subTree);
    }
    foreach (var subTree in subTrees) {
      clone.AddSubTree(GetSimplifiedTree(subTree));
    }
    return clone;
  }
}
152
/// <summary>
/// x => x * -1
/// Constants, variables, sums, products and quotients are negated in place (x is manipulated
/// and returned). For any other node a new multiplication (x * -1) is created and returned,
/// so callers must always use the return value.
/// </summary>
/// <param name="x">Expression to negate.</param>
/// <returns>-x</returns>
private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
  if (IsConstant(x)) {
    ((ConstantTreeNode)x).Value *= -1;
  } else if (IsVariable(x)) {
    var variableTree = (VariableTreeNode)x;
    variableTree.Weight *= -1.0;
  } else if (IsAddition(x)) {
    // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
    // Negate may return a different node than it was given, so every operand is
    // detached and the negated result is attached in its place.
    var subTrees = new List<SymbolicExpressionTreeNode>(x.SubTrees);
    while (x.SubTrees.Count > 0) x.RemoveSubTree(0);
    foreach (var subTree in subTrees) {
      x.AddSubTree(Negate(subTree));
    }
  } else if (IsMultiplication(x) || IsDivision(x)) {
    // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
    // last is maybe a constant, prefer to negate the constant
    var lastSubTree = x.SubTrees.Last();
    x.RemoveSubTree(x.SubTrees.Count - 1);
    // re-attach the result: for operands that cannot be negated in place it is a new node
    x.AddSubTree(Negate(lastSubTree));
  } else {
    // any other function
    return MakeMultiplication(x, MakeConstant(-1));
  }
  return x;
}
179
/// <summary>
/// x => 1/x
/// A constant is inverted in place; for every other node the result of
/// MakeDivision(1, x) is returned.
/// </summary>
/// <param name="x">Expression to invert.</param>
/// <returns>1/x</returns>
private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
  if (!IsConstant(x)) {
    // not a constant: express the reciprocal as a division with dividend 1
    return MakeDivision(MakeConstant(1), x);
  }
  var constant = (ConstantTreeNode)x;
  constant.Value = 1.0 / constant.Value;
  return x;
}
195
/// <summary>
/// Builds a / b.
/// Two constants are folded into a new constant, a variable divided by a constant has
/// its weight scaled in place, and every other combination becomes a new division node.
/// </summary>
/// <param name="a">Dividend.</param>
/// <param name="b">Divisor.</param>
/// <returns>A node representing a / b.</returns>
private SymbolicExpressionTreeNode MakeDivision(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  bool dividendIsConstant = IsConstant(a);
  if (dividendIsConstant && IsConstant(b)) {
    var dividend = ((ConstantTreeNode)a).Value;
    var divisor = ((ConstantTreeNode)b).Value;
    return MakeConstant(dividend / divisor);
  }

  if (IsVariable(a) && IsConstant(b)) {
    var variable = (VariableTreeNode)a;
    variable.Weight /= ((ConstantTreeNode)b).Value;
    return a;
  }

  var quotient = divSymbol.CreateTreeNode();
  quotient.SubTrees.Add(a);
  quotient.SubTrees.Add(b);
  return quotient;
}
210
/// <summary>
/// Builds a + b.
/// Constants are merged and kept as the last operand, adding (almost) zero is dropped,
/// nested additions are flattened into one node and variables with the same name are merged.
/// If merging variables leaves a sum with a single operand, that operand is returned
/// directly instead of a one-element addition (e.g. x + x => 2x).
/// Operands may be manipulated in place or moved into the result.
/// </summary>
/// <param name="a">First summand.</param>
/// <param name="b">Second summand.</param>
/// <returns>A node representing a + b.</returns>
private SymbolicExpressionTreeNode MakeAddition(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // merge constants
    ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
    return a;
  } else if (IsConstant(a)) {
    // c + x => x + c
    // b is not constant => make sure constant is on the right
    return MakeAddition(b, a);
  } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
    // x + 0 => x
    return a;
  } else if (IsAddition(a) && IsAddition(b)) {
    // merge additions
    var add = addSymbol.CreateTreeNode();
    for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
    for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
    if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
      add.AddSubTree(MakeAddition(a.SubTrees.Last(), b.SubTrees.Last()));
    } else if (IsConstant(a.SubTrees.Last())) {
      // keep the constant as the last operand
      add.AddSubTree(b.SubTrees.Last());
      add.AddSubTree(a.SubTrees.Last());
    } else {
      add.AddSubTree(a.SubTrees.Last());
      add.AddSubTree(b.SubTrees.Last());
    }
    MergeVariables(add);
    // merging may have collapsed the sum to one operand
    return add.SubTrees.Count == 1 ? add.SubTrees[0] : add;
  } else if (IsAddition(b)) {
    // only b is an addition => swap so that the addition is the first operand
    return MakeAddition(b, a);
  } else if (IsAddition(a) && IsConstant(b)) {
    // append the constant to the addition, merging it with a trailing constant if there is one
    var add = addSymbol.CreateTreeNode();
    for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
    if (IsConstant(a.SubTrees.Last()))
      add.AddSubTree(MakeAddition(a.SubTrees.Last(), b));
    else {
      add.AddSubTree(a.SubTrees.Last());
      add.AddSubTree(b);
    }
    return add;
  } else if (IsAddition(a)) {
    // b is neither a constant nor an addition: put it in front so a trailing constant of a stays last
    var add = addSymbol.CreateTreeNode();
    add.AddSubTree(b);
    foreach (var subTree in a.SubTrees) {
      add.AddSubTree(subTree);
    }
    MergeVariables(add);
    return add.SubTrees.Count == 1 ? add.SubTrees[0] : add;
  } else {
    var add = addSymbol.CreateTreeNode();
    add.SubTrees.Add(a);
    add.SubTrees.Add(b);
    MergeVariables(add);
    // e.g. two variables with the same name have been merged into one node
    return add.SubTrees.Count == 1 ? add.SubTrees[0] : add;
  }
}
267
/// <summary>
/// Rebuilds the operand list of <paramref name="add"/> so that variable nodes sharing the
/// same variable name are represented by a single node carrying the sum of their weights.
/// The merged variable nodes come first, followed by all remaining operands in their original order.
/// </summary>
/// <param name="add">Node whose sub-trees are merged in place.</param>
private void MergeVariables(SymbolicExpressionTreeNode add) {
  // split the current operands into variable nodes and everything else
  var variableNodes = new List<VariableTreeNode>();
  var otherNodes = new List<SymbolicExpressionTreeNode>();
  foreach (var operand in new List<SymbolicExpressionTreeNode>(add.SubTrees)) {
    var variableNode = operand as VariableTreeNode;
    if (variableNode != null) {
      variableNodes.Add(variableNode);
    } else {
      otherNodes.Add(operand);
    }
  }

  while (add.SubTrees.Count > 0) add.RemoveSubTree(0);

  // one node per variable name; the first node of each group is reused as the representative
  foreach (var sameNameNodes in variableNodes.GroupBy(v => v.VariableName)) {
    var totalWeight = sameNameNodes.Sum(v => v.Weight);
    var keptNode = sameNameNodes.First();
    keptNode.Weight = totalWeight;
    add.AddSubTree(keptNode);
  }

  foreach (var other in otherNodes) {
    add.AddSubTree(other);
  }
}
285
/// <summary>
/// Builds a * b.
/// Constants are folded and kept as the last factor, multiplying by (almost) one is dropped,
/// a constant factor is folded into a variable's weight or distributed over the terms of a sum,
/// divisions are pulled outwards and nested multiplications are flattened.
/// Operands may be manipulated in place or moved into the result.
/// </summary>
/// <param name="a">First factor.</param>
/// <param name="b">Second factor.</param>
/// <returns>A node representing a * b.</returns>
private SymbolicExpressionTreeNode MakeMultiplication(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // merge constants
    ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
    return a;
  } else if (IsConstant(a)) {
    // c * x => x * c (constant on the right)
    return MakeMultiplication(b, a);
  } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
    // x * 1 => x
    return a;
  } else if (IsConstant(b) && IsVariable(a)) {
    // fold the constant into the variable's weight
    ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
    return a;
  } else if (IsConstant(b) && IsAddition(a)) {
    // (x0 + .. + xn) * c => x0 * c + .. + xn * c
    // Every term gets its own copy of the constant: the node may end up inside the product
    // built for a term, and sharing one instance between several products would make a later
    // in-place change (e.g. Negate) of one product affect the others as well.
    return a.SubTrees
      .Select(x => MakeMultiplication(x, (SymbolicExpressionTreeNode)b.Clone()))
      .Aggregate((c, d) => MakeAddition(c, d));
  } else if (IsDivision(a)) {
    // (a0 / a1) * b => (a0 * b) / a1
    Trace.Assert(a.SubTrees.Count == 2);
    return MakeDivision(MakeMultiplication(a.SubTrees[0], b), a.SubTrees[1]);
  } else if (IsDivision(b)) {
    // a * (b0 / b1) => (b0 * a) / b1
    Trace.Assert(b.SubTrees.Count == 2);
    return MakeDivision(MakeMultiplication(b.SubTrees[0], a), b.SubTrees[1]);
  } else if (IsMultiplication(a) && IsMultiplication(b)) {
    // flatten both products; the two last factors (possibly constants) are multiplied with each other
    var mul = mulSymbol.CreateTreeNode();
    for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
    for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
    mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b.SubTrees.Last()));
    return mul;
  } else if (IsMultiplication(a)) {
    // multiply b into the last factor of a
    var mul = mulSymbol.CreateTreeNode();
    for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
    mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b));
    return mul;
  } else if (IsMultiplication(b)) {
    // multiply a into the last factor of b
    var mul = mulSymbol.CreateTreeNode();
    for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
    mul.AddSubTree(MakeMultiplication(b.SubTrees.Last(), a));
    return mul;
  } else {
    var mul = mulSymbol.CreateTreeNode();
    mul.SubTrees.Add(a);
    mul.SubTrees.Add(b);
    return mul;
  }
}
328
#region is symbol ?
// Each predicate tests only the runtime type of the symbol attached to the given node.

/// <summary>True if the symbol of <paramref name="original"/> is a Division.</summary>
private bool IsDivision(SymbolicExpressionTreeNode original) {
  var symbol = original.Symbol;
  return symbol is Division;
}

/// <summary>True if the symbol of <paramref name="original"/> is a Multiplication.</summary>
private bool IsMultiplication(SymbolicExpressionTreeNode original) {
  var symbol = original.Symbol;
  return symbol is Multiplication;
}

/// <summary>True if the symbol of <paramref name="original"/> is a Subtraction.</summary>
private bool IsSubtraction(SymbolicExpressionTreeNode original) {
  var symbol = original.Symbol;
  return symbol is Subtraction;
}

/// <summary>True if the symbol of <paramref name="original"/> is an Addition.</summary>
private bool IsAddition(SymbolicExpressionTreeNode original) {
  var symbol = original.Symbol;
  return symbol is Addition;
}

/// <summary>True if the symbol of <paramref name="original"/> is a Variable.</summary>
private bool IsVariable(SymbolicExpressionTreeNode original) {
  var symbol = original.Symbol;
  return symbol is Variable;
}

/// <summary>True if the symbol of <paramref name="original"/> is a Constant.</summary>
private bool IsConstant(SymbolicExpressionTreeNode original) {
  var symbol = original.Symbol;
  return symbol is Constant;
}
#endregion
354
/// <summary>
/// Creates a new constant node holding <paramref name="value"/>.
/// </summary>
/// <param name="value">Value of the constant.</param>
/// <returns>The newly created constant node.</returns>
private SymbolicExpressionTreeNode MakeConstant(double value) {
  var node = (ConstantTreeNode)constSymbol.CreateTreeNode();
  node.Value = value;
  return node;
}
360
/// <summary>
/// Creates a new variable node with the given weight and variable name.
/// </summary>
/// <param name="weight">Weight assigned to the new node.</param>
/// <param name="name">Variable name assigned to the new node.</param>
/// <returns>The newly created variable node.</returns>
private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
  VariableTreeNode variableNode = (VariableTreeNode)varSymbol.CreateTreeNode();
  variableNode.Weight = weight;
  variableNode.VariableName = name;
  return variableNode;
}
367  }
368}
Note: See TracBrowser for help on using the repository browser.