Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 4026

Last change on this file since 4026 was 3985, checked in by gkronber, 14 years ago

Fixed statements that modify the list of sub-trees of a SymbolicExpressionTreeNodes directly. #938

File size: 16.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
28using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
29using System.Diagnostics;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Simplistic simplifier for arithmetic expressions
  /// Rules:
  ///  * Constants are always the last argument to any function
  ///  * f(c1, c2) => c3 (constant expression folding)
  /// </summary>
  /// <remarks>
  /// Function invocations are inlined first (macro expansion), then the expression is rebuilt bottom-up.
  /// The Make*, Negate and Invert helpers modify and reuse the nodes they receive, therefore a node of an
  /// intermediate result must never occur at more than one position of the tree that is being built.
  /// </remarks>
  public class SymbolicSimplifier {
    private readonly Addition addSymbol = new Addition();
    private readonly Multiplication mulSymbol = new Multiplication();
    private readonly Division divSymbol = new Division();
    private readonly Constant constSymbol = new Constant();
    private readonly Variable varSymbol = new Variable();

    /// <summary>
    /// Returns a simplified copy of <paramref name="originalTree"/>; the original tree is not modified.
    /// </summary>
    /// <param name="originalTree">The tree to simplify. Invocations of defined functions are inlined.</param>
    /// <returns>A new tree whose root is the simplified expression of the first branch below the original root.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="originalTree"/> is null.</exception>
    public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
      if (originalTree == null) throw new ArgumentNullException("originalTree");
      // work on a deep copy because macro expansion detaches and re-attaches the nodes it visits
      var clone = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
      // macro expand (initially no argument trees)
      var macroExpandedTree = MacroExpand(clone, clone.SubTrees[0], new List<SymbolicExpressionTreeNode>());
      return new SymbolicExpressionTree(GetSimplifiedTree(macroExpandedTree));
    }

    /// <summary>
    /// Recursively inlines function invocations and argument references below <paramref name="node"/>.
    /// </summary>
    /// <param name="root">Root node whose sub-trees hold the function definitions (defun branches).</param>
    /// <param name="node">Node to expand. Its sub-trees are detached and replaced by their expanded versions,
    /// so the node must not be shared with any tree that is still needed unchanged.</param>
    /// <param name="argumentTrees">Already expanded trees used as arguments for the invocation currently being expanded.</param>
    /// <returns>The expanded node; for invocations, argument references and start nodes this is a different node than <paramref name="node"/>.</returns>
    private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
      List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
      while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
      if (node.Symbol is InvokeFunction) {
        var invokeSym = (InvokeFunction)node.Symbol;
        // Expand a private copy of the function body. Expansion rewrites the nodes in place, so expanding
        // the definition itself would substitute the arguments of the first invocation into every later
        // invocation of the same function and would place one node instance at several call sites.
        var defunBody = (SymbolicExpressionTreeNode)FindFunctionDefinition(root, invokeSym.FunctionName).Clone();
        var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
        foreach (var subtree in subtrees) {
          macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
        }
        return MacroExpand(root, defunBody, macroExpandedArguments);
      } else if (node.Symbol is Argument) {
        var argSym = (Argument)node.Symbol;
        // return the correct argument sub-tree (already macro-expanded);
        // cloned because an argument can be referenced more than once
        return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
      } else if (node.Symbol is StartSymbol) {
        return MacroExpand(root, subtrees[0], argumentTrees);
      } else {
        // recursive application
        foreach (var subtree in subtrees) {
          node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
        }
        return node;
      }
    }

    /// <summary>
    /// Looks up the body of the function named <paramref name="functionName"/> among the defun branches of <paramref name="root"/>.
    /// </summary>
    /// <returns>The first sub-tree of the matching defun node (not a copy).</returns>
    /// <exception cref="ArgumentException">If no defun node with that name exists.</exception>
    private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
      foreach (var subtree in root.SubTrees.OfType<DefunTreeNode>()) {
        if (subtree.FunctionName == functionName) return subtree.SubTrees[0];
      }

      throw new ArgumentException("Definition of function " + functionName + " not found.");
    }

    /// <summary>
    /// Creates a new simplified tree
    /// </summary>
    /// <param name="original">Root of the expression to simplify. A new tree is built; <paramref name="original"/> is left unchanged.</param>
    /// <returns>The root of the simplified tree.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="original"/> is null.</exception>
    public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
      if (original == null) throw new ArgumentNullException("original");
      if (IsConstant(original) || IsVariable(original)) {
        return (SymbolicExpressionTreeNode)original.Clone();
      } else if (IsAddition(original)) {
        if (original.SubTrees.Count == 1) {
          return GetSimplifiedTree(original.SubTrees[0]);
        } else {
          // simplify expression x0..xn
          // make addition (x0..xn)
          Trace.Assert(original.SubTrees.Count > 1);
          return original.SubTrees
            .Select(x => GetSimplifiedTree(x))
            .Aggregate((a, b) => MakeAddition(a, b));
        }
      } else if (IsSubtraction(original)) {
        if (original.SubTrees.Count == 1) {
          return Negate(GetSimplifiedTree(original.SubTrees[0]));
        } else {
          // simplify expressions x0..xn
          // make addition (x0,-x1..-xn)
          Trace.Assert(original.SubTrees.Count > 1);
          // materialize once: enumerating the lazy query twice would simplify x0 twice on every
          // level, which grows exponentially with the nesting depth
          var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
          return simplifiedTrees.Take(1)
            .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
            .Aggregate((a, b) => MakeAddition(a, b));
        }
      } else if (IsMultiplication(original)) {
        if (original.SubTrees.Count == 1) {
          return GetSimplifiedTree(original.SubTrees[0]);
        } else {
          Trace.Assert(original.SubTrees.Count > 1);
          return original.SubTrees
            .Select(x => GetSimplifiedTree(x))
            .Aggregate((a, b) => MakeMultiplication(a, b));
        }
      } else if (IsDivision(original)) {
        if (original.SubTrees.Count == 1) {
          return Invert(GetSimplifiedTree(original.SubTrees[0]));
        } else {
          // simplify expressions x0..xn
          // make multiplication (x0 * 1/(x1 * x2 * .. * xn))
          Trace.Assert(original.SubTrees.Count > 1);
          // materialize once (see subtraction above)
          var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
          var denominator = simplifiedTrees.Skip(1).Aggregate((a, b) => MakeMultiplication(a, b));
          return MakeMultiplication(simplifiedTrees[0], Invert(denominator));
        }
      } else if (IsAverage(original)) {
        if (original.SubTrees.Count == 1) {
          return GetSimplifiedTree(original.SubTrees[0]);
        } else {
          // simplify expressions x0..xn
          // make sum(x0..xn) / n
          Trace.Assert(original.SubTrees.Count > 1);
          var sum = original.SubTrees
            .Select(x => GetSimplifiedTree(x))
            .Aggregate((a, b) => MakeAddition(a, b));
          return MakeDivision(sum, MakeConstant(original.SubTrees.Count));
        }
      } else {
        // can't simplify this function but simplify all subtrees
        // TODO evaluate the function if original is a constant expression
        // Clone() copies the whole sub-tree, so the children are detached first to copy only this node.
        // They are re-attached immediately (even if cloning fails) so the input tree stays intact.
        List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
        while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
        SymbolicExpressionTreeNode clone;
        try {
          clone = (SymbolicExpressionTreeNode)original.Clone();
        } finally {
          foreach (var subTree in subTrees) original.AddSubTree(subTree);
        }
        foreach (var subTree in subTrees) {
          clone.AddSubTree(GetSimplifiedTree(subTree));
        }
        return clone;
      }
    }

    /// <summary>
    /// x => x * -1
    /// Constants, variables, sums, products and quotients are negated in place and x itself is returned;
    /// any other function is wrapped into a new multiplication with the constant -1.
    /// </summary>
    /// <param name="x">The (already simplified) tree to negate; it may be modified.</param>
    /// <returns>-x</returns>
    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        ((ConstantTreeNode)x).Value *= -1;
      } else if (IsVariable(x)) {
        var variableTree = (VariableTreeNode)x;
        variableTree.Weight *= -1.0;
      } else if (IsAddition(x)) {
        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
        // Negate can return a new node (for "other" functions), so each sub-tree is replaced by the result.
        var subTrees = new List<SymbolicExpressionTreeNode>(x.SubTrees);
        while (x.SubTrees.Count > 0) x.RemoveSubTree(0);
        foreach (var subTree in subTrees) {
          x.AddSubTree(Negate(subTree));
        }
      } else if (IsMultiplication(x) || IsDivision(x)) {
        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
        // last is maybe a constant, prefer to negate the constant.
        // The negated node replaces the old last sub-tree at the same (last) position.
        int lastIndex = x.SubTrees.Count - 1;
        var last = x.SubTrees[lastIndex];
        x.RemoveSubTree(lastIndex);
        x.AddSubTree(Negate(last));
      } else {
        // any other function
        return MakeMultiplication(x, MakeConstant(-1));
      }
      return x;
    }

    /// <summary>
    /// x => 1/x
    /// Constants are inverted in place; any other node is wrapped into a new division 1 / x.
    /// </summary>
    /// <param name="x">The (already simplified) tree to invert; it may be modified.</param>
    /// <returns>1/x</returns>
    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        ((ConstantTreeNode)x).Value = 1.0 / ((ConstantTreeNode)x).Value;
      } else {
        // any other function
        return MakeDivision(MakeConstant(1), x);
      }
      return x;
    }

    /// <summary>
    /// a / b. Two constants are folded into a new constant, a variable divided by a constant gets its
    /// weight scaled (a is modified and returned); otherwise a new division node is created.
    /// </summary>
    private SymbolicExpressionTreeNode MakeDivision(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
      } else if (IsVariable(a) && IsConstant(b)) {
        var constB = ((ConstantTreeNode)b).Value;
        ((VariableTreeNode)a).Weight /= constB;
        return a;
      } else {
        var div = divSymbol.CreateTreeNode();
        div.AddSubTree(a);
        div.AddSubTree(b);
        return div;
      }
    }

    /// <summary>
    /// a + b. Folds constants (keeping the constant as last argument), drops "+ 0", flattens nested
    /// additions and merges variables with the same name. a, b and their sub-trees may be modified and reused.
    /// </summary>
    private SymbolicExpressionTreeNode MakeAddition(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // merge constants
        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        // c + x => x + c
        // b is not constant => make sure constant is on the right
        return MakeAddition(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
        // x + 0 => x
        return a;
      } else if (IsAddition(a) && IsAddition(b)) {
        // merge additions
        var add = addSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
          add.AddSubTree(MakeAddition(a.SubTrees.Last(), b.SubTrees.Last()));
        } else if (IsConstant(a.SubTrees.Last())) {
          add.AddSubTree(b.SubTrees.Last());
          add.AddSubTree(a.SubTrees.Last());
        } else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b.SubTrees.Last());
        }
        MergeVariables(add);
        return add;
      } else if (IsAddition(b)) {
        return MakeAddition(b, a);
      } else if (IsAddition(a) && IsConstant(b)) {
        var add = addSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        if (IsConstant(a.SubTrees.Last()))
          add.AddSubTree(MakeAddition(a.SubTrees.Last(), b));
        else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b);
        }
        return add;
      } else if (IsAddition(a)) {
        var add = addSymbol.CreateTreeNode();
        add.AddSubTree(b);
        foreach (var subTree in a.SubTrees) {
          add.AddSubTree(subTree);
        }
        MergeVariables(add);
        return add;
      } else {
        var add = addSymbol.CreateTreeNode();
        add.AddSubTree(a);
        add.AddSubTree(b);
        MergeVariables(add);
        return add;
      }
    }

    /// <summary>
    /// Replaces all variable sub-trees of <paramref name="add"/> that share a variable name by a single
    /// variable node (the first of the group) whose weight is the sum of the group's weights.
    /// The merged variables are placed first, followed by all other sub-trees in their previous order.
    /// </summary>
    private void MergeVariables(SymbolicExpressionTreeNode add) {
      var subtrees = new List<SymbolicExpressionTreeNode>(add.SubTrees);
      while (add.SubTrees.Count > 0) add.RemoveSubTree(0);
      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                            group node by node.VariableName into g
                            select g;
      var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));

      foreach (var variableNodeGroup in groupedVarNodes) {
        var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
        var representative = variableNodeGroup.First();
        representative.Weight = weightSum;
        add.AddSubTree(representative);
      }
      foreach (var unchangedSubtree in unchangedSubTrees)
        add.AddSubTree(unchangedSubtree);
    }

    /// <summary>
    /// a * b. Folds constants (keeping the constant as last argument), drops "* 1", scales variable weights,
    /// distributes constants over sums, combines quotients and flattens nested multiplications.
    /// a, b and their sub-trees may be modified and reused.
    /// </summary>
    private SymbolicExpressionTreeNode MakeMultiplication(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        return MakeMultiplication(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
        return a;
      } else if (IsConstant(b) && IsVariable(a)) {
        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(b) && IsAddition(a)) {
        // (x0 + .. + xn) * c => x0 * c + .. + xn * c
        // Every term gets its own copy of c: MakeMultiplication may embed c into the result, and a single
        // shared node would later be mutated (e.g. by Negate) for all terms at once.
        return a.SubTrees.ToList()
          .Select(term => MakeMultiplication(term, (SymbolicExpressionTreeNode)b.Clone()))
          .Aggregate((c, d) => MakeAddition(c, d));
      } else if (IsDivision(a) && IsDivision(b)) {
        // (a0 / a1) * (b0 / b1) => (a0 * b0) / (a1 * b1)
        Trace.Assert(a.SubTrees.Count == 2);
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeDivision(MakeMultiplication(a.SubTrees[0], b.SubTrees[0]), MakeMultiplication(a.SubTrees[1], b.SubTrees[1]));
      } else if (IsDivision(a)) {
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeDivision(MakeMultiplication(a.SubTrees[0], b), a.SubTrees[1]);
      } else if (IsDivision(b)) {
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeDivision(MakeMultiplication(b.SubTrees[0], a), b.SubTrees[1]);
      } else if (IsMultiplication(a) && IsMultiplication(b)) {
        var mul = mulSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
        mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b.SubTrees.Last()));
        return mul;
      } else if (IsMultiplication(a)) {
        var mul = mulSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
        mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b));
        return mul;
      } else if (IsMultiplication(b)) {
        var mul = mulSymbol.CreateTreeNode();
        for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
        mul.AddSubTree(MakeMultiplication(b.SubTrees.Last(), a));
        return mul;
      } else {
        var mul = mulSymbol.CreateTreeNode();
        // use AddSubTree instead of modifying the sub-tree list directly
        mul.AddSubTree(a);
        mul.AddSubTree(b);
        return mul;
      }
    }

    #region is symbol ?
    private bool IsDivision(SymbolicExpressionTreeNode original) {
      return original.Symbol is Division;
    }

    private bool IsMultiplication(SymbolicExpressionTreeNode original) {
      return original.Symbol is Multiplication;
    }

    private bool IsSubtraction(SymbolicExpressionTreeNode original) {
      return original.Symbol is Subtraction;
    }

    private bool IsAddition(SymbolicExpressionTreeNode original) {
      return original.Symbol is Addition;
    }

    private bool IsVariable(SymbolicExpressionTreeNode original) {
      return original.Symbol is Variable;
    }

    private bool IsConstant(SymbolicExpressionTreeNode original) {
      return original.Symbol is Constant;
    }

    private bool IsAverage(SymbolicExpressionTreeNode original) {
      return original.Symbol is Average;
    }
    #endregion

    /// <summary>Creates a new constant node with the given value.</summary>
    private SymbolicExpressionTreeNode MakeConstant(double value) {
      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(constSymbol.CreateTreeNode());
      constantTreeNode.Value = value;
      return (SymbolicExpressionTreeNode)constantTreeNode;
    }

    /// <summary>Creates a new variable node with the given weight and variable name.</summary>
    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
      var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
      tree.Weight = weight;
      tree.VariableName = name;
      return tree;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.