Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 4462

Last change on this file since 4462 was 4462, checked in by gkronber, 12 years ago

Changed symbolic simplifier to work for multi-variate models and return a symbolic expression tree that can be directly evaluated. #1142

File size: 21.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Simplistic simplifier for arithmetic expressions.
  /// Invocations of defined functions are macro-expanded first; then additions, subtractions,
  /// multiplications, divisions and averages are rebuilt with folded constants and merged
  /// variable weights. Other symbols are kept, only their subtrees are simplified.
  /// </summary>
  public class SymbolicSimplifier {
    private readonly Addition addSymbol = new Addition();
    private readonly Multiplication mulSymbol = new Multiplication();
    private readonly Division divSymbol = new Division();
    private readonly Constant constSymbol = new Constant();
    private readonly Variable varSymbol = new Variable();

    /// <summary>
    /// Creates a simplified version of <paramref name="originalTree"/>.
    /// The original tree is not modified; all work is done on a deep clone of its root.
    /// </summary>
    /// <param name="originalTree">The tree to simplify.</param>
    /// <returns>
    /// A new tree whose program root has exactly one subtree: the macro-expanded and simplified
    /// first branch of the original root (function definition branches are not copied).
    /// </returns>
    public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
      var clone = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
      // macro expand (initially no argument trees)
      var macroExpandedTree = MacroExpand(clone, clone.SubTrees[0], new List<SymbolicExpressionTreeNode>());
      SymbolicExpressionTreeNode rootNode = (new ProgramRootSymbol()).CreateTreeNode();
      rootNode.AddSubTree(GetSimplifiedTree(macroExpandedTree));
      return new SymbolicExpressionTree(rootNode);
    }

    /// <summary>
    /// Replaces function invocations and argument references below <paramref name="node"/> by the
    /// corresponding (expanded) trees. <paramref name="node"/> is modified in place.
    /// </summary>
    /// <param name="root">Root that holds the function definition branches.</param>
    /// <param name="node">The node to expand.</param>
    /// <param name="argumentTrees">Already expanded trees used as arguments for the current invocation.</param>
    /// <returns>The expanded tree.</returns>
    private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
      List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
      while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
      if (node.Symbol is InvokeFunction) {
        var invokeSym = node.Symbol as InvokeFunction;
        // FindFunctionDefinition returns a private copy of the body, so expanding it is safe
        var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
        var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
        foreach (var subtree in subtrees) {
          macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
        }
        return MacroExpand(root, defunNode, macroExpandedArguments);
      } else if (node.Symbol is Argument) {
        var argSym = node.Symbol as Argument;
        // return the correct argument sub-tree (already macro-expanded)
        return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
      } else {
        // recursive application
        foreach (var subtree in subtrees) {
          node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
        }
        return node;
      }
    }

    /// <summary>
    /// Looks up the body of the function definition named <paramref name="functionName"/>.
    /// </summary>
    /// <returns>
    /// A deep copy of the definition body. MacroExpand modifies the tree it expands, so handing out
    /// the shared body would corrupt it for every later invocation of the same function.
    /// </returns>
    /// <exception cref="ArgumentException">No definition with the given name exists below <paramref name="root"/>.</exception>
    private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
      foreach (var subtree in root.SubTrees.OfType<DefunTreeNode>()) {
        if (subtree.FunctionName == functionName) return (SymbolicExpressionTreeNode)subtree.SubTrees[0].Clone();
      }

      throw new ArgumentException("Definition of function " + functionName + " not found.");
    }


    #region symbol predicates
    private bool IsDivision(SymbolicExpressionTreeNode original) {
      return original.Symbol is Division;
    }

    private bool IsMultiplication(SymbolicExpressionTreeNode original) {
      return original.Symbol is Multiplication;
    }

    private bool IsSubtraction(SymbolicExpressionTreeNode original) {
      return original.Symbol is Subtraction;
    }

    private bool IsAddition(SymbolicExpressionTreeNode original) {
      return original.Symbol is Addition;
    }

    private bool IsVariable(SymbolicExpressionTreeNode original) {
      return original.Symbol is Variable;
    }

    private bool IsConstant(SymbolicExpressionTreeNode original) {
      return original.Symbol is Constant;
    }

    private bool IsAverage(SymbolicExpressionTreeNode original) {
      return original.Symbol is Average;
    }
    private bool IsLog(SymbolicExpressionTreeNode original) {
      return original.Symbol is Logarithm;
    }
    private bool IsIfThenElse(SymbolicExpressionTreeNode original) {
      return original.Symbol is IfThenElse;
    }
    #endregion

    /// <summary>
    /// Creates a new simplified tree for <paramref name="original"/>.
    /// </summary>
    /// <param name="original">
    /// The tree to simplify. Its subtrees may be detached and re-attached temporarily, but its
    /// structure is the same after the call.
    /// </param>
    /// <returns>A newly built simplified tree that shares no nodes with <paramref name="original"/>.</returns>
    public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
      if (IsConstant(original) || IsVariable(original)) {
        return (SymbolicExpressionTreeNode)original.Clone();
      } else if (IsAddition(original)) {
        return SimplifyAddition(original);
      } else if (IsSubtraction(original)) {
        return SimplifySubtraction(original);
      } else if (IsMultiplication(original)) {
        return SimplifyMultiplication(original);
      } else if (IsDivision(original)) {
        return SimplifyDivision(original);
      } else if (IsAverage(original)) {
        return SimplifyAverage(original);
      } else if (IsLog(original)) {
        // TODO simplify logarithm
        return SimplifyAny(original);
      } else if (IsIfThenElse(original)) {
        // TODO simplify conditionals
        return SimplifyAny(original);
      } else {
        return SimplifyAny(original);
      }
    }

    #region specific simplification routines
    /// <summary>
    /// Copies a node with an unknown symbol and simplifies all of its subtrees.
    /// </summary>
    private SymbolicExpressionTreeNode SimplifyAny(SymbolicExpressionTreeNode original) {
      // can't simplify this function but simplify all subtrees
      // The subtrees are detached before cloning so that Clone() copies only the node itself;
      // they are re-attached to the original right after being simplified.
      List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
      while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
      var clone = (SymbolicExpressionTreeNode)original.Clone();
      List<SymbolicExpressionTreeNode> simplifiedSubTrees = new List<SymbolicExpressionTreeNode>();
      foreach (var subTree in subTrees) {
        simplifiedSubTrees.Add(GetSimplifiedTree(subTree));
        original.AddSubTree(subTree);
      }
      foreach (var simplifiedSubtree in simplifiedSubTrees) {
        clone.AddSubTree(simplifiedSubtree);
      }
      if (simplifiedSubTrees.TrueForAll(t => IsConstant(t))) {
        return SimplifyConstantExpression(clone);
      }
      return clone;
    }

    /// <summary>
    /// Placeholder for folding a function whose arguments are all constants.
    /// Currently returns <paramref name="original"/> unchanged.
    /// </summary>
    private SymbolicExpressionTreeNode SimplifyConstantExpression(SymbolicExpressionTreeNode original) {
      // not yet implemented
      return original;
    }

    /// <summary>
    /// avg(x) => x; avg(x0..xn) => (x0 + .. + xn) / count
    /// </summary>
    private SymbolicExpressionTreeNode SimplifyAverage(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return GetSimplifiedTree(original.SubTrees[0]);
      } else {
        // simplify expressions x0..xn
        // make sum(x0..xn) / n
        Trace.Assert(original.SubTrees.Count > 1);
        var sum = original.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Aggregate((a, b) => MakeSum(a, b));
        return MakeFraction(sum, MakeConstant(original.SubTrees.Count));
      }
    }

    /// <summary>
    /// /(x) => 1/x; /(x0, x1..xn) => x0 * 1/(x1 * .. * xn)
    /// </summary>
    private SymbolicExpressionTreeNode SimplifyDivision(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return Invert(GetSimplifiedTree(original.SubTrees[0]));
      } else {
        // simplify expressions x0..xn
        // make multiplication (x0 * 1/(x1 * x1 * .. * xn))
        Trace.Assert(original.SubTrees.Count > 1);
        // materialize once: the query is used twice below and would otherwise simplify x0 twice
        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
        return
          MakeProduct(simplifiedTrees[0], Invert(simplifiedTrees.Skip(1).Aggregate((a, b) => MakeProduct(a, b))));
      }
    }

    /// <summary>
    /// *(x) => x; *(x0..xn) => x0 * .. * xn
    /// </summary>
    private SymbolicExpressionTreeNode SimplifyMultiplication(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return GetSimplifiedTree(original.SubTrees[0]);
      } else {
        Trace.Assert(original.SubTrees.Count > 1);
        return original.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Aggregate((a, b) => MakeProduct(a, b));
      }
    }

    /// <summary>
    /// -(x) => -x; -(x0, x1..xn) => x0 + -x1 + .. + -xn
    /// </summary>
    private SymbolicExpressionTreeNode SimplifySubtraction(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return Negate(GetSimplifiedTree(original.SubTrees[0]));
      } else {
        // simplify expressions x0..xn
        // make addition (x0,-x1..-xn)
        Trace.Assert(original.SubTrees.Count > 1);
        // materialize once: the query is used twice below and would otherwise simplify x0 twice
        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
        return simplifiedTrees.Take(1)
          .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
          .Aggregate((a, b) => MakeSum(a, b));
      }
    }

    /// <summary>
    /// +(x) => x; +(x0..xn) => x0 + .. + xn
    /// </summary>
    private SymbolicExpressionTreeNode SimplifyAddition(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return GetSimplifiedTree(original.SubTrees[0]);
      } else {
        // simplify expression x0..xn
        // make addition (x0..xn)
        Trace.Assert(original.SubTrees.Count > 1);
        return original.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Aggregate((a, b) => MakeSum(a, b));
      }
    }
    #endregion



    #region low level tree restructuring
    // MakeFraction, MakeProduct and MakeSum take two already simplified trees and create a new simplified tree.
    // The argument trees (and their subtrees) may be modified and reused in the result,
    // so they must not be used by the caller afterwards.

    /// <summary>
    /// Builds a simplified tree for a / b.
    /// </summary>
    private SymbolicExpressionTreeNode MakeFraction(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // fold constants
        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
      } else if (IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
        // c / b => 1 / (b * 1/c): keep constant numerators at 1
        return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
      } else if (IsVariable(a) && IsConstant(b)) {
        // merge constant values into variable weights
        var constB = ((ConstantTreeNode)b).Value;
        ((VariableTreeNode)a).Weight /= constB;
        return a;
      } else if (IsAddition(a) && IsConstant(b)) {
        // (a0 + .. + an) / c => a0 / c + .. + an / c
        return a.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Select(x => MakeFraction(x, b))
          .Aggregate((c, d) => MakeSum(c, d));
      } else if (IsMultiplication(a) && IsConstant(b)) {
        return MakeProduct(a, Invert(b));
      } else if (IsDivision(a) && IsConstant(b)) {
        // (a1 / a2) / c => (a1 / (a2 * c))
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
      } else if (IsDivision(a) && IsDivision(b)) {
        // (a1 / a2) / (b1 / b2) => (a1 * b2) / (a2 * b1)
        Trace.Assert(a.SubTrees.Count == 2);
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[1]), MakeProduct(a.SubTrees[1], b.SubTrees[0]));
      } else if (IsDivision(a)) {
        // (a1 / a2) / b => (a1 / (a2 * b))
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
      } else if (IsDivision(b)) {
        // a / (b1 / b2) => (a * b2) / b1
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a, b.SubTrees[1]), b.SubTrees[0]);
      } else {
        var div = divSymbol.CreateTreeNode();
        div.AddSubTree(a);
        div.AddSubTree(b);
        return div;
      }
    }

    /// <summary>
    /// Builds a simplified tree for a + b. Constant summands are kept as the last subtree of an
    /// addition so that they can be folded when additions are merged.
    /// </summary>
    private SymbolicExpressionTreeNode MakeSum(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // fold constants
        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        // c + x => x + c
        // b is not constant => make sure constant is on the right
        return MakeSum(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
        // x + 0 => x
        return a;
      } else if (IsAddition(a) && IsAddition(b)) {
        // merge additions
        var add = addSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
          add.AddSubTree(MakeSum(a.SubTrees.Last(), b.SubTrees.Last()));
        } else if (IsConstant(a.SubTrees.Last())) {
          add.AddSubTree(b.SubTrees.Last());
          add.AddSubTree(a.SubTrees.Last());
        } else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b.SubTrees.Last());
        }
        MergeVariablesInSum(add);
        return add;
      } else if (IsAddition(b)) {
        return MakeSum(b, a);
      } else if (IsAddition(a) && IsConstant(b)) {
        // a is an addition and b is a constant => append b to a and make sure the constants are merged
        var add = addSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        if (IsConstant(a.SubTrees.Last()))
          add.AddSubTree(MakeSum(a.SubTrees.Last(), b));
        else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b);
        }
        return add;
      } else if (IsAddition(a)) {
        // a is already an addition => add b
        // b goes first so that a constant at the end of a stays the last subtree
        var add = addSymbol.CreateTreeNode();
        add.AddSubTree(b);
        foreach (var subTree in a.SubTrees) {
          add.AddSubTree(subTree);
        }
        MergeVariablesInSum(add);
        return add;
      } else {
        var add = addSymbol.CreateTreeNode();
        add.AddSubTree(a);
        add.AddSubTree(b);
        MergeVariablesInSum(add);
        return add;
      }
    }

    // makes sure variable symbols in sums are combined:
    // all variable nodes with the same name are replaced by one node carrying the sum of their weights.
    // The merged variables come first, followed by the remaining subtrees in their original order.
    // possible improvement: combine sums of products where the products only reference the same variable
    private void MergeVariablesInSum(SymbolicExpressionTreeNode sum) {
      var subtrees = new List<SymbolicExpressionTreeNode>(sum.SubTrees);
      while (sum.SubTrees.Count > 0) sum.RemoveSubTree(0);
      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                            group node by node.VariableName into g
                            select g;
      var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));

      foreach (var variableNodeGroup in groupedVarNodes) {
        var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
        var representative = variableNodeGroup.First();
        representative.Weight = weightSum;
        sum.AddSubTree(representative);
      }
      foreach (var unchangedSubtree in unchangedSubTrees)
        sum.AddSubTree(unchangedSubtree);
    }


    /// <summary>
    /// Builds a simplified tree for a * b. Constant factors are folded into one constant
    /// (placed last) and constants are pushed into variable weights where possible.
    /// </summary>
    private SymbolicExpressionTreeNode MakeProduct(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // fold constants
        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        // a * $ => $ * a
        return MakeProduct(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
        // $ * 1.0 => $
        return a;
      } else if (IsConstant(b) && IsVariable(a)) {
        // multiply constants into variables weights
        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(b) && IsAddition(a)) {
        // multiply constants into additions
        return a.SubTrees.Select(x => MakeProduct(x, b)).Aggregate((c, d) => MakeSum(c, d));
      } else if (IsDivision(a) && IsDivision(b)) {
        // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
        Trace.Assert(a.SubTrees.Count == 2);
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[0]), MakeProduct(a.SubTrees[1], b.SubTrees[1]));
      } else if (IsDivision(a)) {
        // (a1 / a2) * b => (a1 * b) / a2
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a.SubTrees[0], b), a.SubTrees[1]);
      } else if (IsDivision(b)) {
        // a * (b1 / b2) => (b1 * a) / b2
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(b.SubTrees[0], a), b.SubTrees[1]);
      } else if (IsMultiplication(a) && IsMultiplication(b)) {
        // merge multiplications (make sure constants are merged)
        var mul = mulSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count; i++) mul.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count; i++) mul.AddSubTree(b.SubTrees[i]);
        MergeVariablesAndConstantsInProduct(mul);
        return mul;
      } else if (IsMultiplication(b)) {
        return MakeProduct(b, a);
      } else if (IsMultiplication(a)) {
        // a is already an multiplication => append b
        a.AddSubTree(b);
        MergeVariablesAndConstantsInProduct(a);
        return a;
      } else {
        var mul = mulSymbol.CreateTreeNode();
        mul.AddSubTree(a);
        mul.AddSubTree(b);
        MergeVariablesAndConstantsInProduct(mul);
        return mul;
      }
    }
    #endregion

    // helper to combine the constant factors in products and to combine variables (powers of 2, 3...):
    // all variable weights and constants are multiplied into a single constant that is appended last
    // (omitted when it is almost 1.0); variables with the same name are grouped into a nested
    // multiplication of weight-1 copies.
    private void MergeVariablesAndConstantsInProduct(SymbolicExpressionTreeNode prod) {
      var subtrees = new List<SymbolicExpressionTreeNode>(prod.SubTrees);
      while (prod.SubTrees.Count > 0) prod.RemoveSubTree(0);
      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                            group node by node.VariableName into g
                            orderby g.Count()
                            select g;
      // evaluated eagerly (Aggregate) before the variable weights are reset to 1.0 below
      var constantProduct = (from node in subtrees.OfType<VariableTreeNode>()
                             select node.Weight)
                            .Concat(from node in subtrees.OfType<ConstantTreeNode>()
                                    select node.Value)
                            .DefaultIfEmpty(1.0)
                            .Aggregate((c1, c2) => c1 * c2);

      var unchangedSubTrees = from tree in subtrees
                              where !(tree is VariableTreeNode)
                              where !(tree is ConstantTreeNode)
                              select tree;

      foreach (var variableNodeGroup in groupedVarNodes) {
        var representative = variableNodeGroup.First();
        representative.Weight = 1.0;
        if (variableNodeGroup.Count() > 1) {
          var poly = mulSymbol.CreateTreeNode();
          for (int p = 0; p < variableNodeGroup.Count(); p++) {
            poly.AddSubTree((SymbolicExpressionTreeNode)representative.Clone());
          }
          prod.AddSubTree(poly);
        } else {
          prod.AddSubTree(representative);
        }
      }

      foreach (var unchangedSubtree in unchangedSubTrees)
        prod.AddSubTree(unchangedSubtree);

      if (!constantProduct.IsAlmost(1.0)) {
        prod.AddSubTree(MakeConstant(constantProduct));
      }
    }


    #region helper functions
    /// <summary>
    /// x => x * -1
    /// Constants, variable weights, all summands of an addition and the last subtree of a
    /// multiplication/division are negated in place. Any other node is wrapped in a new product with -1.
    /// </summary>
    /// <param name="x">An already simplified tree; it may be modified.</param>
    /// <returns>-x (either <paramref name="x"/> itself or a new product node)</returns>
    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        ((ConstantTreeNode)x).Value *= -1;
      } else if (IsVariable(x)) {
        var variableTree = (VariableTreeNode)x;
        variableTree.Weight *= -1.0;
      } else if (IsAddition(x)) {
        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
        // Negate may return a new node (e.g. for function symbols), so every summand is
        // replaced by the returned tree instead of relying on an in-place change.
        var summands = new List<SymbolicExpressionTreeNode>(x.SubTrees);
        while (x.SubTrees.Count > 0) x.RemoveSubTree(0);
        foreach (var summand in summands) {
          x.AddSubTree(Negate(summand));
        }
      } else if (IsMultiplication(x) || IsDivision(x)) {
        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
        // (x0 / x1) * -1 => x0 / -x1
        // last is maybe a constant, prefer to negate the constant; it is replaced by the
        // returned tree because Negate may create a new node for it
        int lastIndex = x.SubTrees.Count - 1;
        var last = x.SubTrees[lastIndex];
        x.RemoveSubTree(lastIndex);
        x.AddSubTree(Negate(last));
      } else {
        // any other function
        return MakeProduct(x, MakeConstant(-1));
      }
      return x;
    }

    /// <summary>
    /// x => 1/x
    /// Builds the inverted tree via MakeConstant/MakeFraction; <paramref name="x"/> or its
    /// subtrees may be reused in the result and must not be used elsewhere afterwards.
    /// </summary>
    /// <param name="x">An already simplified tree.</param>
    /// <returns>1/x</returns>
    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        return MakeConstant(1.0 / ((ConstantTreeNode)x).Value);
      } else if (IsDivision(x)) {
        Trace.Assert(x.SubTrees.Count == 2);
        return MakeFraction(x.SubTrees[1], x.SubTrees[0]);
      } else {
        // any other function
        return MakeFraction(MakeConstant(1), x);
      }
    }

    /// <summary>Creates a new constant node with the given value.</summary>
    private SymbolicExpressionTreeNode MakeConstant(double value) {
      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(constSymbol.CreateTreeNode());
      constantTreeNode.Value = value;
      return (SymbolicExpressionTreeNode)constantTreeNode;
    }

    /// <summary>Creates a new variable node with the given weight and variable name.</summary>
    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
      var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
      tree.Weight = weight;
      tree.VariableName = name;
      return tree;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.