Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 4249

Last change on this file since 4249 was 4239, checked in by gkronber, 14 years ago

Merged improvements of symbolic simplifier (revisions: r4220, r4226, r4235:4238) back into trunk. #1026

File size: 21.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  /// <summary>
33  /// Simplistic simplifier for arithmetic expressions
34  /// </summary>
35  public class SymbolicSimplifier {
// Prototype symbols; the simplifier calls CreateTreeNode() on them to build new nodes.
// They are assigned once here and never rebound, hence readonly.
private readonly Addition addSymbol = new Addition();
private readonly Multiplication mulSymbol = new Multiplication();
private readonly Division divSymbol = new Division();
private readonly Constant constSymbol = new Constant();
private readonly Variable varSymbol = new Variable();
41
/// <summary>
/// Builds a simplified version of the given tree. The root of the tree is cloned first,
/// function invocations below the clone's first subtree are macro-expanded, and the
/// expanded expression is then simplified.
/// </summary>
/// <param name="originalTree">The tree to simplify; the work is done on a clone of its root.</param>
/// <returns>A new tree wrapping the simplified expression.</returns>
public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
  var rootCopy = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
  var startNode = rootCopy.SubTrees[0];
  // the top-level expansion starts without any argument trees
  var noArguments = new List<SymbolicExpressionTreeNode>();
  var expanded = MacroExpand(rootCopy, startNode, noArguments);
  var simplified = GetSimplifiedTree(expanded);
  return new SymbolicExpressionTree(simplified);
}
48
/// <summary>
/// Recursively replaces function invocations by the body of the invoked function and
/// argument nodes by a clone of the matching argument tree.
/// </summary>
/// <param name="root">Node whose DefunTreeNode subtrees hold the function definitions.</param>
/// <param name="node">Node to expand; its subtrees are detached and, for ordinary nodes, re-attached in expanded form.</param>
/// <param name="argumentTrees">Already expanded trees used as arguments for invocations.</param>
/// <returns>The expanded node (not necessarily the same instance as <paramref name="node"/>).</returns>
private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
  List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
  while (node.SubTrees.Count > 0) node.RemoveSubTree(0);

  var invokeSym = node.Symbol as InvokeFunction;
  if (invokeSym != null) {
    var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
    var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
    foreach (var subtree in subtrees) {
      macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
    }
    // Expand a copy of the function body. Expanding the definition itself would substitute
    // this call's arguments into the definition, so any later invocation of the same
    // function would reuse the first call's arguments and share its node instances.
    var bodyCopy = (SymbolicExpressionTreeNode)defunNode.Clone();
    return MacroExpand(root, bodyCopy, macroExpandedArguments);
  }

  var argSym = node.Symbol as Argument;
  if (argSym != null) {
    // return the correct argument sub-tree (already macro-expanded)
    return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
  }

  if (node.Symbol is StartSymbol) {
    // the start node itself is dropped; continue with its first subtree
    return MacroExpand(root, subtrees[0], argumentTrees);
  }

  // ordinary node: recursive application to all subtrees
  foreach (var subtree in subtrees) {
    node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
  }
  return node;
}
75
/// <summary>
/// Searches the DefunTreeNode subtrees of <paramref name="root"/> for a definition with the
/// given name and returns the first subtree of the first match.
/// </summary>
/// <exception cref="ArgumentException">No definition with that name exists below root.</exception>
private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
  var definition = root.SubTrees
    .OfType<DefunTreeNode>()
    .FirstOrDefault(defun => defun.FunctionName == functionName);

  if (definition == null) {
    throw new ArgumentException("Definition of function " + functionName + " not found.");
  }
  return definition.SubTrees[0];
}
83
84
#region symbol predicates
// Shared test behind all predicates: is the symbol attached to the node of the given type?
private static bool HasSymbol<TSymbol>(SymbolicExpressionTreeNode node) {
  return node.Symbol is TSymbol;
}

// true if the node's symbol is a Division
private bool IsDivision(SymbolicExpressionTreeNode original) {
  return HasSymbol<Division>(original);
}

// true if the node's symbol is a Multiplication
private bool IsMultiplication(SymbolicExpressionTreeNode original) {
  return HasSymbol<Multiplication>(original);
}

// true if the node's symbol is a Subtraction
private bool IsSubtraction(SymbolicExpressionTreeNode original) {
  return HasSymbol<Subtraction>(original);
}

// true if the node's symbol is an Addition
private bool IsAddition(SymbolicExpressionTreeNode original) {
  return HasSymbol<Addition>(original);
}

// true if the node's symbol is a Variable
private bool IsVariable(SymbolicExpressionTreeNode original) {
  return HasSymbol<Variable>(original);
}

// true if the node's symbol is a Constant
private bool IsConstant(SymbolicExpressionTreeNode original) {
  return HasSymbol<Constant>(original);
}

// true if the node's symbol is an Average
private bool IsAverage(SymbolicExpressionTreeNode original) {
  return HasSymbol<Average>(original);
}

// true if the node's symbol is a Logarithm
private bool IsLog(SymbolicExpressionTreeNode original) {
  return HasSymbol<Logarithm>(original);
}

// true if the node's symbol is an IfThenElse
private bool IsIfThenElse(SymbolicExpressionTreeNode original) {
  return HasSymbol<IfThenElse>(original);
}
#endregion
120
/// <summary>
/// Creates a new simplified tree by dispatching on the symbol of the given node.
/// Constants and variables are cloned. Addition, subtraction, multiplication, division and
/// average have dedicated routines. Every other symbol is kept and only its subtrees are
/// simplified.
/// </summary>
/// <param name="original">Root node of the expression to simplify.</param>
/// <returns>Root node of the simplified expression.</returns>
public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
  if (IsConstant(original) || IsVariable(original)) {
    return (SymbolicExpressionTreeNode)original.Clone();
  } else if (IsAddition(original)) {
    return SimplifyAddition(original);
  } else if (IsSubtraction(original)) {
    return SimplifySubtraction(original);
  } else if (IsMultiplication(original)) {
    return SimplifyMultiplication(original);
  } else if (IsDivision(original)) {
    return SimplifyDivision(original);
  } else if (IsAverage(original)) {
    return SimplifyAverage(original);
  } else if (IsLog(original)) {
    // TODO simplify logarithm
    return SimplifyAny(original);
  } else if (IsIfThenElse(original)) {
    // TODO simplify conditionals
    return SimplifyAny(original);
  } else {
    // (a second, unreachable IsAverage branch used to sit here; averages are handled above)
    return SimplifyAny(original);
  }
}
151
152    #region specific simplification routines
/// <summary>
/// Fallback for symbols without a dedicated routine: the node itself is kept (as a clone)
/// and only its subtrees are simplified.
/// </summary>
/// <param name="original">Node to process; it has the same subtrees after the call as before.</param>
/// <returns>A clone of the node carrying the simplified subtrees.</returns>
private SymbolicExpressionTreeNode SimplifyAny(SymbolicExpressionTreeNode original) {
  // can't simplify this function but simplify all subtrees
  List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);

  // detach the subtrees so that Clone() copies only the node itself, then re-attach them
  // immediately: original is intact again even if simplifying a subtree throws
  while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
  var clone = (SymbolicExpressionTreeNode)original.Clone();
  foreach (var subTree in subTrees) {
    original.AddSubTree(subTree);
  }

  bool allConstant = true;
  foreach (var subTree in subTrees) {
    var simplifiedSubTree = GetSimplifiedTree(subTree);
    allConstant &= IsConstant(simplifiedSubTree);
    clone.AddSubTree(simplifiedSubTree);
  }

  if (allConstant) {
    // use the returned node so that a future implementation may hand back a folded constant
    return SimplifyConstantExpression(clone);
  }
  return clone;
}
171
/// <summary>
/// Placeholder for folding an expression whose operands are all constants.
/// Not yet implemented: the given node is handed back unchanged.
/// </summary>
private SymbolicExpressionTreeNode SimplifyConstantExpression(SymbolicExpressionTreeNode original) {
  return original;
}
176
/// <summary>
/// Simplifies an average node. A single operand is returned simplified as is; otherwise
/// avg(x0..xn) is rewritten as sum(x0..xn) / n.
/// </summary>
private SymbolicExpressionTreeNode SimplifyAverage(SymbolicExpressionTreeNode original) {
  int operandCount = original.SubTrees.Count;
  if (operandCount == 1) {
    return GetSimplifiedTree(original.SubTrees[0]);
  }

  Trace.Assert(operandCount > 1);
  // simplify every operand and add them up pairwise
  var simplifiedOperands = original.SubTrees.Select(operand => GetSimplifiedTree(operand));
  var total = simplifiedOperands.Aggregate((acc, next) => MakeSum(acc, next));
  // divide the sum by the number of operands
  var divisor = MakeConstant(operandCount);
  return MakeFraction(total, divisor);
}
190
/// <summary>
/// Simplifies a division node. A single operand x becomes 1/x; otherwise
/// x0 / x1 / .. / xn is rewritten as x0 * 1/(x1 * .. * xn).
/// </summary>
private SymbolicExpressionTreeNode SimplifyDivision(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return Invert(GetSimplifiedTree(original.SubTrees[0]));
  } else {
    // simplify expressions x0..xn
    // make multiplication (x0 * 1/(x1 * x1 * .. * xn))
    Trace.Assert(original.SubTrees.Count > 1);
    // Materialize once: the Select is deferred, so using it for both First() and Skip(1)
    // would simplify the first operand twice (exponential for nested divisions).
    var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
    var denominator = simplifiedTrees.Skip(1).Aggregate((a, b) => MakeProduct(a, b));
    return MakeProduct(simplifiedTrees.First(), Invert(denominator));
  }
}
203
/// <summary>
/// Simplifies a multiplication node: every operand is simplified and the results are
/// combined pairwise, left to right, with MakeProduct. A single operand is returned
/// simplified as is.
/// </summary>
private SymbolicExpressionTreeNode SimplifyMultiplication(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return GetSimplifiedTree(original.SubTrees[0]);
  }

  Trace.Assert(original.SubTrees.Count > 1);
  var simplifiedFactors = original.SubTrees.Select(factor => GetSimplifiedTree(factor));
  return simplifiedFactors.Aggregate((product, factor) => MakeProduct(product, factor));
}
214
/// <summary>
/// Simplifies a subtraction node. A single operand x becomes -x; otherwise
/// x0 - x1 - .. - xn is rewritten as the sum x0 + (-x1) + .. + (-xn).
/// </summary>
private SymbolicExpressionTreeNode SimplifySubtraction(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return Negate(GetSimplifiedTree(original.SubTrees[0]));
  } else {
    // simplify expressions x0..xn
    // make addition (x0,-x1..-xn)
    Trace.Assert(original.SubTrees.Count > 1);
    // Materialize once: the Select is deferred, so using it for both Take(1) and Skip(1)
    // would simplify the first operand twice (exponential for nested subtractions).
    var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
    return simplifiedTrees.Take(1)
      .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
      .Aggregate((a, b) => MakeSum(a, b));
  }
}
228
/// <summary>
/// Simplifies an addition node: every operand is simplified and the results are combined
/// pairwise, left to right, with MakeSum. A single operand is returned simplified as is.
/// </summary>
private SymbolicExpressionTreeNode SimplifyAddition(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return GetSimplifiedTree(original.SubTrees[0]);
  }

  Trace.Assert(original.SubTrees.Count > 1);
  var simplifiedTerms = original.SubTrees.Select(term => GetSimplifiedTree(term));
  return simplifiedTerms.Aggregate((sum, term) => MakeSum(sum, term));
}
241    #endregion
242
243
244
245    #region low level tree restructuring
246    // MakeFraction, MakeProduct and MakeSum take two already simplified trees and create a new simplified tree
247
/// <summary>
/// Builds a simplified tree for a / b from two already simplified trees.
/// The operands may be modified or reused as parts of the result.
/// </summary>
/// <param name="a">Numerator (already simplified).</param>
/// <param name="b">Denominator (already simplified).</param>
/// <returns>Simplified tree representing a / b.</returns>
private SymbolicExpressionTreeNode MakeFraction(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  bool numeratorIsConstant = IsConstant(a);
  bool denominatorIsConstant = IsConstant(b);

  // c1 / c2 => fold into one constant
  if (numeratorIsConstant && denominatorIsConstant) {
    double quotient = ((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value;
    return MakeConstant(quotient);
  }

  // c / b with c != 1 => 1 / (b * 1/c)
  if (numeratorIsConstant && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
    var one = MakeConstant(1.0);
    var scaledDenominator = MakeProduct(b, Invert(a));
    return MakeFraction(one, scaledDenominator);
  }

  // variable / c => merge the constant into the variable weight
  if (IsVariable(a) && denominatorIsConstant) {
    var divisor = ((ConstantTreeNode)b).Value;
    ((VariableTreeNode)a).Weight /= divisor;
    return a;
  }

  // (a0 + .. + an) / c => a0/c + .. + an/c
  if (IsAddition(a) && denominatorIsConstant) {
    var dividedTerms = a.SubTrees.Select(term => MakeFraction(term, b));
    return dividedTerms.Aggregate((left, right) => MakeSum(left, right));
  }

  // (a0 * .. * an) / c => (a0 * .. * an) * 1/c
  if (IsMultiplication(a) && denominatorIsConstant) {
    return MakeProduct(a, Invert(b));
  }

  if (IsDivision(a)) {
    Trace.Assert(a.SubTrees.Count == 2);
    if (IsDivision(b)) {
      // (a1 / a2) / (b1 / b2) => (a1 * b2) / (a2 * b1)
      Trace.Assert(b.SubTrees.Count == 2);
      return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[1]), MakeProduct(a.SubTrees[1], b.SubTrees[0]));
    }
    // (a1 / a2) / b => a1 / (a2 * b), for constant and non-constant b alike
    return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
  }

  if (IsDivision(b)) {
    // a / (b1 / b2) => (a * b2) / b1
    Trace.Assert(b.SubTrees.Count == 2);
    return MakeFraction(MakeProduct(a, b.SubTrees[1]), b.SubTrees[0]);
  }

  // nothing to simplify: plain division node
  var fraction = divSymbol.CreateTreeNode();
  fraction.AddSubTree(a);
  fraction.AddSubTree(b);
  return fraction;
}
289
/// <summary>
/// Builds a simplified tree for a + b from two already simplified trees.
/// Constants are folded and kept as the last operand of an addition; variable terms are
/// combined via MergeVariablesInSum. The operands may be modified or reused in the result.
/// </summary>
/// <param name="a">Left operand (already simplified).</param>
/// <param name="b">Right operand (already simplified).</param>
/// <returns>Simplified tree representing a + b.</returns>
private SymbolicExpressionTreeNode MakeSum(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  bool leftIsConstant = IsConstant(a);
  bool rightIsConstant = IsConstant(b);

  // c1 + c2 => fold into the left constant
  if (leftIsConstant && rightIsConstant) {
    var target = (ConstantTreeNode)a;
    target.Value += ((ConstantTreeNode)b).Value;
    return a;
  }

  // c + x => x + c (b is not constant here; keep the constant on the right)
  if (leftIsConstant) {
    return MakeSum(b, a);
  }

  // x + 0 => x
  if (rightIsConstant && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
    return a;
  }

  bool leftIsAddition = IsAddition(a);
  bool rightIsAddition = IsAddition(b);

  if (leftIsAddition && rightIsAddition) {
    // merge both additions; only their last operands can be constants
    var merged = addSymbol.CreateTreeNode();
    foreach (var term in a.SubTrees.Take(a.SubTrees.Count - 1)) merged.AddSubTree(term);
    foreach (var term in b.SubTrees.Take(b.SubTrees.Count - 1)) merged.AddSubTree(term);

    var lastOfA = a.SubTrees.Last();
    var lastOfB = b.SubTrees.Last();
    if (IsConstant(lastOfA) && IsConstant(lastOfB)) {
      merged.AddSubTree(MakeSum(lastOfA, lastOfB));
    } else if (IsConstant(lastOfA)) {
      // the constant of a goes last
      merged.AddSubTree(lastOfB);
      merged.AddSubTree(lastOfA);
    } else {
      merged.AddSubTree(lastOfA);
      merged.AddSubTree(lastOfB);
    }
    MergeVariablesInSum(merged);
    return merged;
  }

  // make sure the addition is the left operand
  if (rightIsAddition) {
    return MakeSum(b, a);
  }

  if (leftIsAddition && rightIsConstant) {
    // append the constant b to the addition a; merge it with a trailing constant of a
    var extended = addSymbol.CreateTreeNode();
    foreach (var term in a.SubTrees.Take(a.SubTrees.Count - 1)) extended.AddSubTree(term);

    var lastOfA = a.SubTrees.Last();
    if (IsConstant(lastOfA)) {
      extended.AddSubTree(MakeSum(lastOfA, b));
    } else {
      extended.AddSubTree(lastOfA);
      extended.AddSubTree(b);
    }
    return extended;
  }

  if (leftIsAddition) {
    // a is already an addition => put b in front of its operands
    var extended = addSymbol.CreateTreeNode();
    extended.AddSubTree(b);
    foreach (var term in a.SubTrees) {
      extended.AddSubTree(term);
    }
    MergeVariablesInSum(extended);
    return extended;
  }

  // two plain operands
  var pair = addSymbol.CreateTreeNode();
  pair.AddSubTree(a);
  pair.AddSubTree(b);
  MergeVariablesInSum(pair);
  return pair;
}
348
// Combines the variable nodes of a sum in place: all variable nodes with the same variable
// name are replaced by the first of them, carrying the sum of their weights.
// Afterwards the merged variables come first, followed by all other operands in their
// previous order.
// possible improvement: combine sums of products where the products only reference the same variable
private void MergeVariablesInSum(SymbolicExpressionTreeNode sum) {
  var terms = new List<SymbolicExpressionTreeNode>(sum.SubTrees);
  while (sum.SubTrees.Count > 0) sum.RemoveSubTree(0);

  var variablesByName = terms
    .OfType<VariableTreeNode>()
    .GroupBy(variable => variable.VariableName);

  foreach (var sameVariable in variablesByName) {
    // compute the total before overwriting the weight of the kept node
    var totalWeight = sameVariable.Sum(variable => variable.Weight);
    var kept = sameVariable.First();
    kept.Weight = totalWeight;
    sum.AddSubTree(kept);
  }

  foreach (var other in terms.Where(term => !(term is VariableTreeNode))) {
    sum.AddSubTree(other);
  }
}
368
369
/// <summary>
/// Builds a simplified tree for a * b from two already simplified trees.
/// Constants are folded, multiplied into variable weights or distributed over additions;
/// divisions are combined into a single fraction; multiplications are merged.
/// The operands may be modified or reused as parts of the result.
/// </summary>
/// <param name="a">Left factor (already simplified).</param>
/// <param name="b">Right factor (already simplified).</param>
/// <returns>Simplified tree representing a * b.</returns>
private SymbolicExpressionTreeNode MakeProduct(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // fold constants
    ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
    return a;
  } else if (IsConstant(a)) {
    // a * $ => $ * a
    return MakeProduct(b, a);
  } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
    // $ * 1.0 => $
    return a;
  } else if (IsConstant(b) && IsVariable(a)) {
    // multiply constants into variables weights
    ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
    return a;
  } else if (IsConstant(b) && IsAddition(a)) {
    // multiply constants into additions
    return a.SubTrees.Select(x => MakeProduct(x, b)).Aggregate((c, d) => MakeSum(c, d));
  } else if (IsDivision(a) && IsDivision(b)) {
    // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
    Trace.Assert(a.SubTrees.Count == 2);
    Trace.Assert(b.SubTrees.Count == 2);
    return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[0]), MakeProduct(a.SubTrees[1], b.SubTrees[1]));
  } else if (IsDivision(a)) {
    // (a1 / a2) * b => (a1 * b) / a2
    Trace.Assert(a.SubTrees.Count == 2);
    return MakeFraction(MakeProduct(a.SubTrees[0], b), a.SubTrees[1]);
  } else if (IsDivision(b)) {
    // a * (b1 / b2) => (b1 * a) / b2
    Trace.Assert(b.SubTrees.Count == 2);
    return MakeFraction(MakeProduct(b.SubTrees[0], a), b.SubTrees[1]);
  } else if (IsMultiplication(a) && IsMultiplication(b)) {
    // merge multiplications (make sure constants are merged)
    var mul = mulSymbol.CreateTreeNode();
    for (int i = 0; i < a.SubTrees.Count; i++) mul.AddSubTree(a.SubTrees[i]);
    for (int i = 0; i < b.SubTrees.Count; i++) mul.AddSubTree(b.SubTrees[i]);
    MergeVariablesAndConstantsInProduct(mul);
    return mul;
  } else if (IsMultiplication(b)) {
    // make sure the multiplication is the left operand
    return MakeProduct(b, a);
  } else if (IsMultiplication(a)) {
    // a is already an multiplication => append b
    a.AddSubTree(b);
    MergeVariablesAndConstantsInProduct(a);
    return a;
  } else {
    var mul = mulSymbol.CreateTreeNode();
    // attach through AddSubTree like every other branch instead of writing to the
    // SubTrees list directly
    mul.AddSubTree(a);
    mul.AddSubTree(b);
    MergeVariablesAndConstantsInProduct(mul);
    return mul;
  }
}
423    #endregion
424
// Helper that rebuilds a product in place:
//  - all variable weights and constant factors are multiplied into one constant, which is
//    appended as the last factor unless it is (almost) 1.0
//  - variables are grouped by name; a variable occurring k > 1 times becomes a nested
//    multiplication of k weight-1.0 copies (powers of 2, 3...), a single occurrence is kept
//    with weight 1.0
//  - all remaining factors are kept in their previous order
private void MergeVariablesAndConstantsInProduct(SymbolicExpressionTreeNode prod) {
  var factors = new List<SymbolicExpressionTreeNode>(prod.SubTrees);
  while (prod.SubTrees.Count > 0) prod.RemoveSubTree(0);

  var variablesByName = factors
    .OfType<VariableTreeNode>()
    .GroupBy(variable => variable.VariableName)
    .OrderBy(sameVariable => sameVariable.Count());

  // evaluated eagerly here, i.e. before the variable weights are reset to 1.0 below
  var constantProduct = factors.OfType<VariableTreeNode>().Select(variable => variable.Weight)
    .Concat(factors.OfType<ConstantTreeNode>().Select(constant => constant.Value))
    .DefaultIfEmpty(1.0)
    .Aggregate((left, right) => left * right);

  var otherFactors = factors
    .Where(factor => !(factor is VariableTreeNode))
    .Where(factor => !(factor is ConstantTreeNode));

  foreach (var sameVariable in variablesByName) {
    var representative = sameVariable.First();
    representative.Weight = 1.0;
    int occurrences = sameVariable.Count();
    if (occurrences > 1) {
      // x * x * .. => nested product of identical copies
      var power = mulSymbol.CreateTreeNode();
      for (int p = 0; p < occurrences; p++) {
        power.AddSubTree((SymbolicExpressionTreeNode)representative.Clone());
      }
      prod.AddSubTree(power);
    } else {
      prod.AddSubTree(representative);
    }
  }

  foreach (var other in otherFactors) {
    prod.AddSubTree(other);
  }

  if (!constantProduct.IsAlmost(1.0)) {
    prod.AddSubTree(MakeConstant(constantProduct));
  }
}
466
467
468    #region helper functions
/// <summary>
/// x => x * -1
/// Constants, variables, additions, multiplications and divisions are negated in place and
/// x itself is returned. For any other function a new product (x * -1) is created and returned.
/// </summary>
/// <param name="x">Tree to negate; it may be modified.</param>
/// <returns>-x</returns>
private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
  if (IsConstant(x)) {
    ((ConstantTreeNode)x).Value *= -1;
  } else if (IsVariable(x)) {
    var variableTree = (VariableTreeNode)x;
    variableTree.Weight *= -1.0;
  } else if (IsAddition(x)) {
    // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
    // Negate may return a new node (see the last branch), so every operand is replaced
    // by the returned node instead of relying on in-place modification only.
    var subTrees = new List<SymbolicExpressionTreeNode>(x.SubTrees);
    while (x.SubTrees.Count > 0) x.RemoveSubTree(0);
    foreach (var subTree in subTrees) {
      x.AddSubTree(Negate(subTree));
    }
  } else if (IsMultiplication(x) || IsDivision(x)) {
    // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
    // last is maybe a constant, prefer to negate the constant;
    // re-attach the result because Negate may return a new node
    var lastSubTree = x.SubTrees.Last();
    x.RemoveSubTree(x.SubTrees.Count - 1);
    x.AddSubTree(Negate(lastSubTree));
  } else {
    // any other function
    return MakeProduct(x, MakeConstant(-1));
  }
  return x;
}
495
/// <summary>
/// x => 1/x
/// A constant yields a new constant node holding the reciprocal value, a division a1 / a2
/// yields the fraction a2 / a1, and anything else yields the fraction 1 / x.
/// </summary>
/// <param name="x">Tree to invert.</param>
/// <returns>Tree representing 1/x.</returns>
private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
  if (IsConstant(x)) {
    double value = ((ConstantTreeNode)x).Value;
    return MakeConstant(1.0 / value);
  }

  if (IsDivision(x)) {
    // swap numerator and denominator
    Trace.Assert(x.SubTrees.Count == 2);
    var numerator = x.SubTrees[0];
    var denominator = x.SubTrees[1];
    return MakeFraction(denominator, numerator);
  }

  // any other function
  return MakeFraction(MakeConstant(1), x);
}
513
/// <summary>
/// Creates a new constant node from the constant prototype symbol and assigns the given value.
/// </summary>
private SymbolicExpressionTreeNode MakeConstant(double value) {
  var node = (ConstantTreeNode)constSymbol.CreateTreeNode();
  node.Value = value;
  return node;
}
519
/// <summary>
/// Creates a new variable node from the variable prototype symbol and assigns the given
/// weight and variable name.
/// </summary>
private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
  var node = (VariableTreeNode)varSymbol.CreateTreeNode();
  node.VariableName = name;
  node.Weight = weight;
  return node;
}
526    #endregion
527  }
528}
Note: See TracBrowser for help on using the repository browser.