Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 4220

Last change on this file since 4220 was 4220, checked in by gkronber, 14 years ago

Added better support for simplification of fractions and products and cleaned code a little bit. #1026

File size: 21.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  /// <summary>
33  /// Simplistic simplifier for arithmetic expressions
34  /// </summary>
35  public class SymbolicSimplifier {
36    private Addition addSymbol = new Addition();
37    private Multiplication mulSymbol = new Multiplication();
38    private Division divSymbol = new Division();
39    private Constant constSymbol = new Constant();
40    private Variable varSymbol = new Variable();
41
42    public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
43      var clone = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
44      // macro expand (initially no argument trees)
45      var macroExpandedTree = MacroExpand(clone, clone.SubTrees[0], new List<SymbolicExpressionTreeNode>());
46      return new SymbolicExpressionTree(GetSimplifiedTree(macroExpandedTree));
47    }
48
49    // the argumentTrees list contains already expanded trees used as arguments for invocations
50    private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
51      List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
52      while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
53      if (node.Symbol is InvokeFunction) {
54        var invokeSym = node.Symbol as InvokeFunction;
55        var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
56        var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
57        foreach (var subtree in subtrees) {
58          macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
59        }
60        return MacroExpand(root, defunNode, macroExpandedArguments);
61      } else if (node.Symbol is Argument) {
62        var argSym = node.Symbol as Argument;
63        // return the correct argument sub-tree (already macro-expanded)
64        return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
65      } else if (node.Symbol is StartSymbol) {
66        return MacroExpand(root, subtrees[0], argumentTrees);
67      } else {
68        // recursive application
69        foreach (var subtree in subtrees) {
70          node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
71        }
72        return node;
73      }
74    }
75
76    private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
77      foreach (var subtree in root.SubTrees.OfType<DefunTreeNode>()) {
78        if (subtree.FunctionName == functionName) return subtree.SubTrees[0];
79      }
80
81      throw new ArgumentException("Definition of function " + functionName + " not found.");
82    }
83
84
85    #region symbol predicates
86    private bool IsDivision(SymbolicExpressionTreeNode original) {
87      return original.Symbol is Division;
88    }
89
90    private bool IsMultiplication(SymbolicExpressionTreeNode original) {
91      return original.Symbol is Multiplication;
92    }
93
94    private bool IsSubtraction(SymbolicExpressionTreeNode original) {
95      return original.Symbol is Subtraction;
96    }
97
98    private bool IsAddition(SymbolicExpressionTreeNode original) {
99      return original.Symbol is Addition;
100    }
101
102    private bool IsVariable(SymbolicExpressionTreeNode original) {
103      return original.Symbol is Variable;
104    }
105
106    private bool IsConstant(SymbolicExpressionTreeNode original) {
107      return original.Symbol is Constant;
108    }
109
110    private bool IsAverage(SymbolicExpressionTreeNode original) {
111      return original.Symbol is Average;
112    }
113    private bool IsLog(SymbolicExpressionTreeNode original) {
114      return original.Symbol is Logarithm;
115    }
116    #endregion
117
118    /// <summary>
119    /// Creates a new simplified tree
120    /// </summary>
121    /// <param name="original"></param>
122    /// <returns></returns>
123    public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
124      if (IsConstant(original) || IsVariable(original)) {
125        return (SymbolicExpressionTreeNode)original.Clone();
126      } else if (IsAddition(original)) {
127        return SimplifyAddition(original);
128      } else if (IsSubtraction(original)) {
129        return SimplifySubtraction(original);
130      } else if (IsMultiplication(original)) {
131        return SimplifyMultiplication(original);
132      } else if (IsDivision(original)) {
133        return SimplifyDivision(original);
134      } else if (IsAverage(original)) {
135        return SimplifyAverage(original);
136      } else if (IsLog(original)) {
137        // TODO simplify logarditm
138        return SimplifyAny(original);
139      } else if (IsAverage(original)) {
140        return SimplifyAverage(original);
141      } else {
142        return SimplifyAny(original);
143      }
144    }
145
146    #region specific simplification routines
147    private SymbolicExpressionTreeNode SimplifyAny(SymbolicExpressionTreeNode original) {
148      // can't simplify this function but simplify all subtrees
149      List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
150      while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
151      var clone = (SymbolicExpressionTreeNode)original.Clone();
152      List<SymbolicExpressionTreeNode> simplifiedSubTrees = new List<SymbolicExpressionTreeNode>();
153      foreach (var subTree in subTrees) {
154        simplifiedSubTrees.Add(GetSimplifiedTree(subTree));
155        original.AddSubTree(subTree);
156      }
157      foreach (var simplifiedSubtree in simplifiedSubTrees) {
158        clone.AddSubTree(simplifiedSubtree);
159      }
160      if (simplifiedSubTrees.TrueForAll(t => IsConstant(t))) {
161        SimplifyConstantExpression(clone);
162      }
163      return clone;
164    }
165
166    private SymbolicExpressionTreeNode SimplifyConstantExpression(SymbolicExpressionTreeNode original) {
167      // not yet implemented
168      return original;
169    }
170
171    private SymbolicExpressionTreeNode SimplifyAverage(SymbolicExpressionTreeNode original) {
172      if (original.SubTrees.Count == 1) {
173        return GetSimplifiedTree(original.SubTrees[0]);
174      } else {
175        // simplify expressions x0..xn
176        // make sum(x0..xn) / n
177        Trace.Assert(original.SubTrees.Count > 1);
178        var sum = original.SubTrees
179          .Select(x => GetSimplifiedTree(x))
180          .Aggregate((a, b) => MakeSum(a, b));
181        return MakeFraction(sum, MakeConstant(original.SubTrees.Count));
182      }
183    }
184
185    private SymbolicExpressionTreeNode SimplifyDivision(SymbolicExpressionTreeNode original) {
186      if (original.SubTrees.Count == 1) {
187        return Invert(GetSimplifiedTree(original.SubTrees[0]));
188      } else {
189        // simplify expressions x0..xn
190        // make multiplication (x0 * 1/(x1 * x1 * .. * xn))
191        Trace.Assert(original.SubTrees.Count > 1);
192        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x));
193        return
194          MakeProduct(simplifiedTrees.First(), Invert(simplifiedTrees.Skip(1).Aggregate((a, b) => MakeProduct(a, b))));
195      }
196    }
197
198    private SymbolicExpressionTreeNode SimplifyMultiplication(SymbolicExpressionTreeNode original) {
199      if (original.SubTrees.Count == 1) {
200        return GetSimplifiedTree(original.SubTrees[0]);
201      } else {
202        Trace.Assert(original.SubTrees.Count > 1);
203        return original.SubTrees
204          .Select(x => GetSimplifiedTree(x))
205          .Aggregate((a, b) => MakeProduct(a, b));
206      }
207    }
208
209    private SymbolicExpressionTreeNode SimplifySubtraction(SymbolicExpressionTreeNode original) {
210      if (original.SubTrees.Count == 1) {
211        return Negate(GetSimplifiedTree(original.SubTrees[0]));
212      } else {
213        // simplify expressions x0..xn
214        // make addition (x0,-x1..-xn)
215        Trace.Assert(original.SubTrees.Count > 1);
216        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x));
217        return simplifiedTrees.Take(1)
218          .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
219          .Aggregate((a, b) => MakeSum(a, b));
220      }
221    }
222
223    private SymbolicExpressionTreeNode SimplifyAddition(SymbolicExpressionTreeNode original) {
224      if (original.SubTrees.Count == 1) {
225        return GetSimplifiedTree(original.SubTrees[0]);
226      } else {
227        // simplify expression x0..xn
228        // make addition (x0..xn)
229        Trace.Assert(original.SubTrees.Count > 1);
230        return original.SubTrees
231          .Select(x => GetSimplifiedTree(x))
232          .Aggregate((a, b) => MakeSum(a, b));
233      }
234    }
235    #endregion
236
237
238
239    #region low level tree restructuring
240    // each make* method must return a simplified tree
241
242    private SymbolicExpressionTreeNode MakeFraction(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
243      if (IsConstant(a) && IsConstant(b)) {
244        // fold constants
245        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
246      } if (IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
247        return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
248      } else if (IsVariable(a) && IsConstant(b)) {
249        // merge constant values into variable weights
250        var constB = ((ConstantTreeNode)b).Value;
251        ((VariableTreeNode)a).Weight /= constB;
252        return a;
253      } else if (IsAddition(a) && IsConstant(b)) {
254        return a.SubTrees
255         .Select(x => MakeFraction(x, b))
256         .Aggregate((c, d) => MakeSum(c, d));
257      } else if (IsMultiplication(a) && IsConstant(b)) {
258        return MakeProduct(a, Invert(b));
259      } else if (IsDivision(a) && IsConstant(b)) {
260        // (a1 / a2) / c => (a1 / (a2 * c))
261        Trace.Assert(a.SubTrees.Count == 2);
262        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
263      } else if (IsDivision(a) && IsDivision(b)) {
264        // (a1 / a2) / (b1 / b2) =>
265        Trace.Assert(a.SubTrees.Count == 2);
266        Trace.Assert(b.SubTrees.Count == 2);
267        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[1]), MakeProduct(a.SubTrees[1], b.SubTrees[0]));
268      } else if (IsDivision(a)) {
269        // (a1 / a2) / b => (a1 / (a2 * b))
270        Trace.Assert(a.SubTrees.Count == 2);
271        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
272      } else if (IsDivision(b)) {
273        // a / (b1 / b2) => (a * b2) / b1
274        Trace.Assert(b.SubTrees.Count == 2);
275        return MakeFraction(MakeProduct(a, b.SubTrees[1]), b.SubTrees[0]);
276      } else {
277        var div = divSymbol.CreateTreeNode();
278        div.AddSubTree(a);
279        div.AddSubTree(b);
280        return div;
281      }
282    }
283
284    private SymbolicExpressionTreeNode MakeSum(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
285      if (IsConstant(a) && IsConstant(b)) {
286        // fold constants
287        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
288        return a;
289      } else if (IsConstant(a)) {
290        // c + x => x + c
291        // b is not constant => make sure constant is on the right
292        return MakeSum(b, a);
293      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
294        // x + 0 => x
295        return a;
296      } else if (IsAddition(a) && IsAddition(b)) {
297        // merge additions
298        var add = addSymbol.CreateTreeNode();
299        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
300        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
301        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
302          add.AddSubTree(MakeSum(a.SubTrees.Last(), b.SubTrees.Last()));
303        } else if (IsConstant(a.SubTrees.Last())) {
304          add.AddSubTree(b.SubTrees.Last());
305          add.AddSubTree(a.SubTrees.Last());
306        } else {
307          add.AddSubTree(a.SubTrees.Last());
308          add.AddSubTree(b.SubTrees.Last());
309        }
310        MergeVariablesInSum(add);
311        return add;
312      } else if (IsAddition(b)) {
313        return MakeSum(b, a);
314      } else if (IsAddition(a) && IsConstant(b)) {
315        // a is an addition and b is a constant => append b to a and make sure the constants are merged
316        var add = addSymbol.CreateTreeNode();
317        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
318        if (IsConstant(a.SubTrees.Last()))
319          add.AddSubTree(MakeSum(a.SubTrees.Last(), b));
320        else {
321          add.AddSubTree(a.SubTrees.Last());
322          add.AddSubTree(b);
323        }
324        return add;
325      } else if (IsAddition(a)) {
326        // a is already an addition => append b
327        var add = addSymbol.CreateTreeNode();
328        add.AddSubTree(b);
329        foreach (var subTree in a.SubTrees) {
330          add.AddSubTree(subTree);
331        }
332        MergeVariablesInSum(add);
333        return add;
334      } else {
335        var add = addSymbol.CreateTreeNode();
336        add.AddSubTree(a);
337        add.AddSubTree(b);
338        MergeVariablesInSum(add);
339        return add;
340      }
341    }
342
343    private void MergeVariablesInSum(SymbolicExpressionTreeNode sum) {
344      var subtrees = new List<SymbolicExpressionTreeNode>(sum.SubTrees);
345      while (sum.SubTrees.Count > 0) sum.RemoveSubTree(0);
346      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
347                            group node by node.VariableName into g
348                            select g;
349      var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));
350
351      foreach (var variableNodeGroup in groupedVarNodes) {
352        var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
353        var representative = variableNodeGroup.First();
354        representative.Weight = weightSum;
355        sum.AddSubTree(representative);
356      }
357      foreach (var unchangedSubtree in unchangedSubTrees)
358        sum.AddSubTree(unchangedSubtree);
359    }
360
361
362    private SymbolicExpressionTreeNode MakeProduct(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
363      if (IsConstant(a) && IsConstant(b)) {
364        // fold constants
365        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
366        return a;
367      } else if (IsConstant(a)) {
368        // a * $ => $ * a
369        return MakeProduct(b, a);
370      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
371        // $ * 1.0 => $
372        return a;
373      } else if (IsConstant(b) && IsVariable(a)) {
374        // multiply constants into variables weights
375        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
376        return a;
377      } else if (IsConstant(b) && IsAddition(a)) {
378        // multiply constants into additions
379        return a.SubTrees.Select(x => MakeProduct(x, b)).Aggregate((c, d) => MakeSum(c, d));
380      } else if (IsDivision(a) && IsDivision(b)) {
381        // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
382        Trace.Assert(a.SubTrees.Count == 2);
383        Trace.Assert(b.SubTrees.Count == 2);
384        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[0]), MakeProduct(a.SubTrees[1], b.SubTrees[1]));
385      } else if (IsDivision(a)) {
386        // (a1 / a2) * b => (a1 * b) / a2
387        Trace.Assert(a.SubTrees.Count == 2);
388        return MakeFraction(MakeProduct(a.SubTrees[0], b), a.SubTrees[1]);
389      } else if (IsDivision(b)) {
390        // a * (b1 / b2) => (b1 * a) / b2
391        Trace.Assert(b.SubTrees.Count == 2);
392        return MakeFraction(MakeProduct(b.SubTrees[0], a), b.SubTrees[1]);
393      } else if (IsMultiplication(a) && IsMultiplication(b)) {
394        // merge multiplications (make sure constants are merged)
395        var mul = mulSymbol.CreateTreeNode();
396        for (int i = 0; i < a.SubTrees.Count; i++) mul.AddSubTree(a.SubTrees[i]);
397        for (int i = 0; i < b.SubTrees.Count; i++) mul.AddSubTree(b.SubTrees[i]);
398        MergeVariablesAndConstantsInProduct(mul);
399        return mul;
400      } else if (IsMultiplication(b)) {
401        return MakeProduct(b, a);
402      } else if (IsMultiplication(a)) {
403        // a is already an multiplication => append b
404        a.AddSubTree(b);
405        MergeVariablesAndConstantsInProduct(a);
406        return a;
407      } else {
408        var mul = mulSymbol.CreateTreeNode();
409        mul.SubTrees.Add(a);
410        mul.SubTrees.Add(b);
411        MergeVariablesAndConstantsInProduct(mul);
412        return mul;
413      }
414    }
415    #endregion
416
417    private void MergeVariablesAndConstantsInProduct(SymbolicExpressionTreeNode prod) {
418      var subtrees = new List<SymbolicExpressionTreeNode>(prod.SubTrees);
419      while (prod.SubTrees.Count > 0) prod.RemoveSubTree(0);
420      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
421                            group node by node.VariableName into g
422                            orderby g.Count()
423                            select g;
424      var constantProduct = (from node in subtrees.OfType<VariableTreeNode>()
425                             select node.Weight)
426                            .Concat(from node in subtrees.OfType<ConstantTreeNode>()
427                                    select node.Value)
428                            .DefaultIfEmpty(1.0)
429                            .Aggregate((c1, c2) => c1 * c2);
430
431      var unchangedSubTrees = from tree in subtrees
432                              where !(tree is VariableTreeNode)
433                              where !(tree is ConstantTreeNode)
434                              select tree;
435
436      foreach (var variableNodeGroup in groupedVarNodes) {
437        var representative = variableNodeGroup.First();
438        representative.Weight = 1.0;
439        if (variableNodeGroup.Count() > 1) {
440          var poly = mulSymbol.CreateTreeNode();
441          for (int p = 0; p < variableNodeGroup.Count(); p++) {
442            poly.AddSubTree((SymbolicExpressionTreeNode)representative.Clone());
443          }
444          prod.AddSubTree(poly);
445        } else {
446          prod.AddSubTree(representative);
447        }
448      }
449
450      foreach (var unchangedSubtree in unchangedSubTrees)
451        prod.AddSubTree(unchangedSubtree);
452
453      if (!constantProduct.IsAlmost(1.0)) {
454        prod.AddSubTree(MakeConstant(constantProduct));
455      }
456    }
457
458
459    #region helper functions
460    /// <summary>
461    /// x => x * -1
462    /// Doesn't create new trees and manipulates x
463    /// </summary>
464    /// <param name="x"></param>
465    /// <returns>-x</returns>
466    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
467      if (IsConstant(x)) {
468        ((ConstantTreeNode)x).Value *= -1;
469      } else if (IsVariable(x)) {
470        var variableTree = (VariableTreeNode)x;
471        variableTree.Weight *= -1.0;
472      } else if (IsAddition(x)) {
473        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)       
474        foreach (var subTree in x.SubTrees) {
475          Negate(subTree);
476        }
477      } else if (IsMultiplication(x) || IsDivision(x)) {
478        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
479        Negate(x.SubTrees.Last()); // last is maybe a constant, prefer to negate the constant
480      } else {
481        // any other function
482        return MakeProduct(x, MakeConstant(-1));
483      }
484      return x;
485    }
486
487    /// <summary>
488    /// x => 1/x
489    /// Doesn't create new trees and manipulates x
490    /// </summary>
491    /// <param name="x"></param>
492    /// <returns></returns>
493    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
494      if (IsConstant(x)) {
495        ((ConstantTreeNode)x).Value = 1.0 / ((ConstantTreeNode)x).Value;
496      } else if (IsDivision(x)) {
497        Trace.Assert(x.SubTrees.Count == 2);
498        return MakeFraction(x.SubTrees[1], x.SubTrees[0]);
499      } else {
500        // any other function
501        return MakeFraction(MakeConstant(1), x);
502      }
503      return x;
504    }
505
506    private SymbolicExpressionTreeNode MakeConstant(double value) {
507      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(constSymbol.CreateTreeNode());
508      constantTreeNode.Value = value;
509      return (SymbolicExpressionTreeNode)constantTreeNode;
510    }
511
512    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
513      var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
514      tree.Weight = weight;
515      tree.VariableName = name;
516      return tree;
517    }
518    #endregion
519  }
520}
Note: See TracBrowser for help on using the repository browser.