Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 3901

Last change on this file since 3901 was 3876, checked in by gkronber, 15 years ago

Added support for simplification of average functions and improved simplification of division function. #1026.

File size: 16.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
28using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
29using System.Diagnostics;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  /// <summary>
33  /// Simplistic simplifier for arithmetic expressions
34  /// Rules:
35  ///  * Constants are always the last argument to any function
36  ///  * f(c1, c2) => c3 (constant expression folding)
37  ///  * c1 / ( c2 * Var) => ( var * ( c2 / c1))
38  /// </summary>
39  public class SymbolicSimplifier {
40    private Addition addSymbol = new Addition();
41    private Multiplication mulSymbol = new Multiplication();
42    private Division divSymbol = new Division();
43    private Constant constSymbol = new Constant();
44    private Variable varSymbol = new Variable();
45
46    public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
47      var clone = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
48      // macro expand (initially no argument trees)
49      var macroExpandedTree = MacroExpand(clone, clone.SubTrees[0], new List<SymbolicExpressionTreeNode>());
50      return new SymbolicExpressionTree(GetSimplifiedTree(macroExpandedTree));
51    }
52
53    // the argumentTrees list contains already expanded trees used as arguments for invocations
54    private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
55      List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
56      while (node.SubTrees.Count > 0) node.SubTrees.RemoveAt(0);
57      if (node.Symbol is InvokeFunction) {
58        var invokeSym = node.Symbol as InvokeFunction;
59        var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
60        var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
61        foreach (var subtree in subtrees) {
62          macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
63        }
64        return MacroExpand(root, defunNode, macroExpandedArguments);
65      } else if (node.Symbol is Argument) {
66        var argSym = node.Symbol as Argument;
67        // return the correct argument sub-tree (already macro-expanded)
68        return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
69      } else if (node.Symbol is StartSymbol) {
70        return MacroExpand(root, subtrees[0], argumentTrees);
71      } else {
72        // recursive application
73        foreach (var subtree in subtrees) {
74          node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
75        }
76        return node;
77      }
78    }
79
80    private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
81      foreach (var subtree in root.SubTrees.OfType<DefunTreeNode>()) {
82        if (subtree.FunctionName == functionName) return subtree.SubTrees[0];
83      }
84
85      throw new ArgumentException("Definition of function " + functionName + " not found.");
86    }
87
88    /// <summary>
89    /// Creates a new simplified tree
90    /// </summary>
91    /// <param name="original"></param>
92    /// <returns></returns>
93    public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
94      if (IsConstant(original) || IsVariable(original)) {
95        return (SymbolicExpressionTreeNode)original.Clone();
96      } else if (IsAddition(original)) {
97        if (original.SubTrees.Count == 1) {
98          return GetSimplifiedTree(original.SubTrees[0]);
99        } else {
100          // simplify expression x0..xn
101          // make addition (x0..xn)
102          Trace.Assert(original.SubTrees.Count > 1);
103          return original.SubTrees
104            .Select(x => GetSimplifiedTree(x))
105            .Aggregate((a, b) => MakeAddition(a, b));
106        }
107      } else if (IsSubtraction(original)) {
108        if (original.SubTrees.Count == 1) {
109          return Negate(GetSimplifiedTree(original.SubTrees[0]));
110        } else {
111          // simplify expressions x0..xn
112          // make addition (x0,-x1..-xn)
113          Trace.Assert(original.SubTrees.Count > 1);
114          var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x));
115          return simplifiedTrees.Take(1)
116            .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
117            .Aggregate((a, b) => MakeAddition(a, b));
118        }
119      } else if (IsMultiplication(original)) {
120        if (original.SubTrees.Count == 1) {
121          return GetSimplifiedTree(original.SubTrees[0]);
122        } else {
123          Trace.Assert(original.SubTrees.Count > 1);
124          return original.SubTrees
125            .Select(x => GetSimplifiedTree(x))
126            .Aggregate((a, b) => MakeMultiplication(a, b));
127        }
128      } else if (IsDivision(original)) {
129        if (original.SubTrees.Count == 1) {
130          return Invert(GetSimplifiedTree(original.SubTrees[0]));
131        } else {
132          // simplify expressions x0..xn
133          // make multiplication (x0 * 1/(x1 * x1 * .. * xn))
134          Trace.Assert(original.SubTrees.Count > 1);
135          var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x));
136          return
137            MakeMultiplication(simplifiedTrees.First(), Invert(simplifiedTrees.Skip(1).Aggregate((a, b) => MakeMultiplication(a, b))));
138        }
139      } else if (IsAverage(original)) {
140        if (original.SubTrees.Count == 1) {
141          return GetSimplifiedTree(original.SubTrees[0]);
142        } else {
143          // simpliy expressions x0..xn
144          // make sum(x0..xn) / n
145          Trace.Assert(original.SubTrees.Count > 1);
146          var sum = original.SubTrees
147            .Select(x => GetSimplifiedTree(x))
148            .Aggregate((a, b) => MakeAddition(a, b));
149          return MakeDivision(sum, MakeConstant(original.SubTrees.Count));
150        }
151      } else {
152        // can't simplify this function but simplify all subtrees
153        // TODO evaluate the function if original is a constant expression
154        List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
155        while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
156        var clone = (SymbolicExpressionTreeNode)original.Clone();
157        foreach (var subTree in subTrees) {
158          clone.AddSubTree(GetSimplifiedTree(subTree));
159          original.AddSubTree(subTree);
160        }
161        return clone;
162      }
163    }
164
165    /// <summary>
166    /// x => x * -1
167    /// Doesn't create new trees and manipulates x
168    /// </summary>
169    /// <param name="x"></param>
170    /// <returns>-x</returns>
171    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
172      if (IsConstant(x)) {
173        ((ConstantTreeNode)x).Value *= -1;
174      } else if (IsVariable(x)) {
175        var variableTree = (VariableTreeNode)x;
176        variableTree.Weight *= -1.0;
177      } else if (IsAddition(x)) {
178        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)       
179        foreach (var subTree in x.SubTrees) {
180          Negate(subTree);
181        }
182      } else if (IsMultiplication(x) || IsDivision(x)) {
183        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
184        Negate(x.SubTrees.Last()); // last is maybe a constant, prefer to negate the constant
185      } else {
186        // any other function
187        return MakeMultiplication(x, MakeConstant(-1));
188      }
189      return x;
190    }
191
192    /// <summary>
193    /// x => 1/x
194    /// Doesn't create new trees and manipulates x
195    /// </summary>
196    /// <param name="x"></param>
197    /// <returns></returns>
198    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
199      if (IsConstant(x)) {
200        ((ConstantTreeNode)x).Value = 1.0 / ((ConstantTreeNode)x).Value;
201      } else {
202        // any other function
203        return MakeDivision(MakeConstant(1), x);
204      }
205      return x;
206    }
207
208    private SymbolicExpressionTreeNode MakeDivision(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
209      if (IsConstant(a) && IsConstant(b)) {
210        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
211      } else if (IsVariable(a) && IsConstant(b)) {
212        var constB = ((ConstantTreeNode)b).Value;
213        ((VariableTreeNode)a).Weight /= constB;
214        return a;
215      } else {
216        var div = divSymbol.CreateTreeNode();
217        div.SubTrees.Add(a);
218        div.SubTrees.Add(b);
219        return div;
220      }
221    }
222
223    private SymbolicExpressionTreeNode MakeAddition(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
224      if (IsConstant(a) && IsConstant(b)) {
225        // merge constants
226        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
227        return a;
228      } else if (IsConstant(a)) {
229        // c + x => x + c
230        // b is not constant => make sure constant is on the right
231        return MakeAddition(b, a);
232      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
233        // x + 0 => x
234        return a;
235      } else if (IsAddition(a) && IsAddition(b)) {
236        // merge additions
237        var add = addSymbol.CreateTreeNode();
238        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
239        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
240        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
241          add.AddSubTree(MakeAddition(a.SubTrees.Last(), b.SubTrees.Last()));
242        } else if (IsConstant(a.SubTrees.Last())) {
243          add.AddSubTree(b.SubTrees.Last());
244          add.AddSubTree(a.SubTrees.Last());
245        } else {
246          add.AddSubTree(a.SubTrees.Last());
247          add.AddSubTree(b.SubTrees.Last());
248        }
249        MergeVariables(add);
250        return add;
251      } else if (IsAddition(b)) {
252        return MakeAddition(b, a);
253      } else if (IsAddition(a) && IsConstant(b)) {
254        var add = addSymbol.CreateTreeNode();
255        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
256        if (IsConstant(a.SubTrees.Last()))
257          add.AddSubTree(MakeAddition(a.SubTrees.Last(), b));
258        else {
259          add.AddSubTree(a.SubTrees.Last());
260          add.AddSubTree(b);
261        }
262        return add;
263      } else if (IsAddition(a)) {
264        var add = addSymbol.CreateTreeNode();
265        add.AddSubTree(b);
266        foreach (var subTree in a.SubTrees) {
267          add.AddSubTree(subTree);
268        }
269        MergeVariables(add);
270        return add;
271      } else {
272        var add = addSymbol.CreateTreeNode();
273        add.SubTrees.Add(a);
274        add.SubTrees.Add(b);
275        MergeVariables(add);
276        return add;
277      }
278    }
279
280    private void MergeVariables(SymbolicExpressionTreeNode add) {
281      var subtrees = new List<SymbolicExpressionTreeNode>(add.SubTrees);
282      while (add.SubTrees.Count > 0) add.RemoveSubTree(0);
283      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
284                            group node by node.VariableName into g
285                            select g;
286      var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));
287
288      foreach (var variableNodeGroup in groupedVarNodes) {
289        var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
290        var representative = variableNodeGroup.First();
291        representative.Weight = weightSum;
292        add.AddSubTree(representative);
293      }
294      foreach (var unchangedSubtree in unchangedSubTrees)
295        add.AddSubTree(unchangedSubtree);
296    }
297
298    private SymbolicExpressionTreeNode MakeMultiplication(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
299      if (IsConstant(a) && IsConstant(b)) {
300        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
301        return a;
302      } else if (IsConstant(a)) {
303        return MakeMultiplication(b, a);
304      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
305        return a;
306      } else if (IsConstant(b) && IsVariable(a)) {
307        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
308        return a;
309      } else if (IsConstant(b) && IsAddition(a)) {
310        return a.SubTrees.Select(x => MakeMultiplication(x, b)).Aggregate((c, d) => MakeAddition(c, d));
311      } else if(IsDivision(a) && IsDivision(b)) {
312        return MakeDivision(MakeMultiplication(a.SubTrees[0], b.SubTrees[0]), MakeMultiplication(a.SubTrees[1], b.SubTrees[1]));
313      } else if (IsDivision(a)) {
314        Trace.Assert(a.SubTrees.Count == 2);
315        return MakeDivision(MakeMultiplication(a.SubTrees[0], b), a.SubTrees[1]);
316      } else if (IsDivision(b)) {
317        Trace.Assert(b.SubTrees.Count == 2);
318        return MakeDivision(MakeMultiplication(b.SubTrees[0], a), b.SubTrees[1]);
319      } else if (IsMultiplication(a) && IsMultiplication(b)) {
320        var mul = mulSymbol.CreateTreeNode();
321        for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
322        for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
323        mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b.SubTrees.Last()));
324        return mul;
325      } else if (IsMultiplication(a)) {
326        var mul = mulSymbol.CreateTreeNode();
327        for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
328        mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b));
329        return mul;
330      } else if (IsMultiplication(b)) {
331        var mul = mulSymbol.CreateTreeNode();
332        for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
333        mul.AddSubTree(MakeMultiplication(b.SubTrees.Last(), a));
334        return mul;
335      } else {
336        var mul = mulSymbol.CreateTreeNode();
337        mul.SubTrees.Add(a);
338        mul.SubTrees.Add(b);
339        return mul;
340      }
341    }
342
343    #region is symbol ?
344    private bool IsDivision(SymbolicExpressionTreeNode original) {
345      return original.Symbol is Division;
346    }
347
348    private bool IsMultiplication(SymbolicExpressionTreeNode original) {
349      return original.Symbol is Multiplication;
350    }
351
352    private bool IsSubtraction(SymbolicExpressionTreeNode original) {
353      return original.Symbol is Subtraction;
354    }
355
356    private bool IsAddition(SymbolicExpressionTreeNode original) {
357      return original.Symbol is Addition;
358    }
359
360    private bool IsVariable(SymbolicExpressionTreeNode original) {
361      return original.Symbol is Variable;
362    }
363
364    private bool IsConstant(SymbolicExpressionTreeNode original) {
365      return original.Symbol is Constant;
366    }
367
368    private bool IsAverage(SymbolicExpressionTreeNode original) {
369      return original.Symbol is Average;
370    }
371    #endregion
372
373    private SymbolicExpressionTreeNode MakeConstant(double value) {
374      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(constSymbol.CreateTreeNode());
375      constantTreeNode.Value = value;
376      return (SymbolicExpressionTreeNode)constantTreeNode;
377    }
378
379    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
380      var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
381      tree.Weight = weight;
382      tree.VariableName = name;
383      return tree;
384    }
385  }
386}
Note: See TracBrowser for help on using the repository browser.