Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 5461

Last change on this file since 5461 was 5461, checked in by gkronber, 12 years ago

#1227 implemented transformations in simplifier to successfully run current set of test cases.

File size: 34.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  /// <summary>
33  /// Simplistic simplifier for arithmetic expressions
34  /// </summary>
35  public class SymbolicSimplifier {
// Prototype symbols from which the nodes of simplified trees are created.
// arithmetic
private Addition addSymbol = new Addition();
private Subtraction subSymbol = new Subtraction();
private Multiplication mulSymbol = new Multiplication();
private Division divSymbol = new Division();
// terminals
private Constant constSymbol = new Constant();
private Variable varSymbol = new Variable();
// exponential
private Logarithm logSymbol = new Logarithm();
private Exponential expSymbol = new Exponential();
// trigonometric
private Sine sineSymbol = new Sine();
private Cosine cosineSymbol = new Cosine();
private Tangent tanSymbol = new Tangent();
// conditional
private IfThenElse ifThenElseSymbol = new IfThenElse();
// boolean connectives
private And andSymbol = new And();
private Or orSymbol = new Or();
private Not notSymbol = new Not();
// comparisons
private GreaterThan gtSymbol = new GreaterThan();
private LessThan ltSymbol = new LessThan();
53
/// <summary>
/// Returns a simplified version of <paramref name="originalTree"/>.
/// </summary>
/// <remarks>
/// Works on a clone of the root. Function invocations are macro-expanded first, then the
/// first branch below the root is simplified and placed under a freshly created program root.
/// </remarks>
/// <param name="originalTree">Tree to simplify.</param>
/// <returns>A new tree holding the simplified expression.</returns>
public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
  var rootCopy = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
  // at the top level there are no function arguments yet
  var noArguments = new List<SymbolicExpressionTreeNode>();
  var expanded = MacroExpand(rootCopy, rootCopy.SubTrees[0], noArguments);

  SymbolicExpressionTreeNode newRoot = (new ProgramRootSymbol()).CreateTreeNode();
  var simplified = GetSimplifiedTree(expanded);
  newRoot.AddSubTree(simplified);
  return new SymbolicExpressionTree(newRoot);
}
62
63    // the argumentTrees list contains already expanded trees used as arguments for invocations
/// <summary>
/// Recursively inlines function invocations and argument references below <paramref name="node"/>.
/// </summary>
/// <param name="root">Root node whose <see cref="DefunTreeNode"/> children hold the function definitions.</param>
/// <param name="node">Node to expand. Its subtrees are detached and replaced by their expansions (in place).</param>
/// <param name="argumentTrees">Already expanded trees bound to the arguments of the function body currently
/// being expanded (empty outside of any function body).</param>
/// <returns>The expanded tree that replaces <paramref name="node"/>.</returns>
private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
  List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
  while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
  if (node.Symbol is InvokeFunction) {
    var invokeSym = node.Symbol as InvokeFunction;
    var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
    var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
    foreach (var subtree in subtrees) {
      macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
    }
    // Expansion rewrites the nodes it visits in place (argument nodes are replaced by the argument trees).
    // Expanding the definition itself would bake the first call's arguments into the function body, so every
    // later invocation would reuse them and all call sites would share one node. Expand a copy instead.
    var bodyCopy = (SymbolicExpressionTreeNode)defunNode.Clone();
    return MacroExpand(root, bodyCopy, macroExpandedArguments);
  } else if (node.Symbol is Argument) {
    var argSym = node.Symbol as Argument;
    // return the correct argument sub-tree (already macro-expanded)
    return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
  } else {
    // recursive application
    foreach (var subtree in subtrees) {
      node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
    }
    return node;
  }
}
77        // return the correct argument sub-tree (already macro-expanded)
78        return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
79      } else {
80        // recursive application
81        foreach (var subtree in subtrees) {
82          node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
83        }
84        return node;
85      }
86    }
87
88    private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
89      foreach (var subtree in root.SubTrees.OfType<DefunTreeNode>()) {
90        if (subtree.FunctionName == functionName) return subtree.SubTrees[0];
91      }
92
93      throw new ArgumentException("Definition of function " + functionName + " not found.");
94    }
95
96
97    #region symbol predicates
98    // arithmetic
99    private bool IsDivision(SymbolicExpressionTreeNode node) {
100      return node.Symbol is Division;
101    }
102
103    private bool IsMultiplication(SymbolicExpressionTreeNode node) {
104      return node.Symbol is Multiplication;
105    }
106
107    private bool IsSubtraction(SymbolicExpressionTreeNode node) {
108      return node.Symbol is Subtraction;
109    }
110
111    private bool IsAddition(SymbolicExpressionTreeNode node) {
112      return node.Symbol is Addition;
113    }
114
115    private bool IsAverage(SymbolicExpressionTreeNode node) {
116      return node.Symbol is Average;
117    }
118    // exponential
119    private bool IsLog(SymbolicExpressionTreeNode node) {
120      return node.Symbol is Logarithm;
121    }
122    private bool IsExp(SymbolicExpressionTreeNode node) {
123      return node.Symbol is Exponential;
124    }
125    // trigonometric
126    private bool IsSine(SymbolicExpressionTreeNode node) {
127      return node.Symbol is Sine;
128    }
129    private bool IsCosine(SymbolicExpressionTreeNode node) {
130      return node.Symbol is Cosine;
131    }
132    private bool IsTangent(SymbolicExpressionTreeNode node) {
133      return node.Symbol is Tangent;
134    }
135    // boolean
136    private bool IsIfThenElse(SymbolicExpressionTreeNode node) {
137      return node.Symbol is IfThenElse;
138    }
139    private bool IsAnd(SymbolicExpressionTreeNode node) {
140      return node.Symbol is And;
141    }
142    private bool IsOr(SymbolicExpressionTreeNode node) {
143      return node.Symbol is Or;
144    }
145    private bool IsNot(SymbolicExpressionTreeNode node) {
146      return node.Symbol is Not;
147    }
148    // comparison
149    private bool IsGreaterThan(SymbolicExpressionTreeNode node) {
150      return node.Symbol is GreaterThan;
151    }
152    private bool IsLessThan(SymbolicExpressionTreeNode node) {
153      return node.Symbol is LessThan;
154    }
155
156    private bool IsBoolean(SymbolicExpressionTreeNode node) {
157      return
158        node.Symbol is GreaterThan ||
159        node.Symbol is LessThan ||
160        node.Symbol is And ||
161        node.Symbol is Or;
162    }
163
164    // terminals
165    private bool IsVariable(SymbolicExpressionTreeNode node) {
166      return node.Symbol is Variable;
167    }
168
169    private bool IsConstant(SymbolicExpressionTreeNode node) {
170      return node.Symbol is Constant;
171    }
172
173    #endregion
174
175    /// <summary>
176    /// Creates a new simplified tree
177    /// </summary>
178    /// <param name="original"></param>
179    /// <returns></returns>
180    public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
181      if (IsConstant(original) || IsVariable(original)) {
182        return (SymbolicExpressionTreeNode)original.Clone();
183      } else if (IsAddition(original)) {
184        return SimplifyAddition(original);
185      } else if (IsSubtraction(original)) {
186        return SimplifySubtraction(original);
187      } else if (IsMultiplication(original)) {
188        return SimplifyMultiplication(original);
189      } else if (IsDivision(original)) {
190        return SimplifyDivision(original);
191      } else if (IsAverage(original)) {
192        return SimplifyAverage(original);
193      } else if (IsLog(original)) {
194        return SimplifyLog(original);
195      } else if (IsExp(original)) {
196        return SimplifyExp(original);
197      } else if (IsSine(original)) {
198        return SimplifySine(original);
199      } else if (IsCosine(original)) {
200        return SimplifyCosine(original);
201      } else if (IsTangent(original)) {
202        return SimplifyTangent(original);
203      } else if (IsIfThenElse(original)) {
204        return SimplifyIfThenElse(original);
205      } else if (IsGreaterThan(original)) {
206        return SimplifyGreaterThan(original);
207      } else if (IsLessThan(original)) {
208        return SimplifyLessThan(original);
209      } else if (IsAnd(original)) {
210        return SimplifyAnd(original);
211      } else if (IsOr(original)) {
212        return SimplifyOr(original);
213      } else if (IsNot(original)) {
214        return SimplifyNot(original);
215      } else {
216        return SimplifyAny(original);
217      }
218    }
219
220
221    #region specific simplification routines
222    private SymbolicExpressionTreeNode SimplifyAny(SymbolicExpressionTreeNode original) {
223      // can't simplify this function but simplify all subtrees
224      List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
225      while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
226      var clone = (SymbolicExpressionTreeNode)original.Clone();
227      List<SymbolicExpressionTreeNode> simplifiedSubTrees = new List<SymbolicExpressionTreeNode>();
228      foreach (var subTree in subTrees) {
229        simplifiedSubTrees.Add(GetSimplifiedTree(subTree));
230        original.AddSubTree(subTree);
231      }
232      foreach (var simplifiedSubtree in simplifiedSubTrees) {
233        clone.AddSubTree(simplifiedSubtree);
234      }
235      if (simplifiedSubTrees.TrueForAll(t => IsConstant(t))) {
236        SimplifyConstantExpression(clone);
237      }
238      return clone;
239    }
240
241    private SymbolicExpressionTreeNode SimplifyConstantExpression(SymbolicExpressionTreeNode original) {
242      // not yet implemented
243      return original;
244    }
245
246    private SymbolicExpressionTreeNode SimplifyAverage(SymbolicExpressionTreeNode original) {
247      if (original.SubTrees.Count == 1) {
248        return GetSimplifiedTree(original.SubTrees[0]);
249      } else {
250        // simplify expressions x0..xn
251        // make sum(x0..xn) / n
252        Trace.Assert(original.SubTrees.Count > 1);
253        var sum = original.SubTrees
254          .Select(x => GetSimplifiedTree(x))
255          .Aggregate((a, b) => MakeSum(a, b));
256        return MakeFraction(sum, MakeConstant(original.SubTrees.Count));
257      }
258    }
259
260    private SymbolicExpressionTreeNode SimplifyDivision(SymbolicExpressionTreeNode original) {
261      if (original.SubTrees.Count == 1) {
262        return Invert(GetSimplifiedTree(original.SubTrees[0]));
263      } else {
264        // simplify expressions x0..xn
265        // make multiplication (x0 * 1/(x1 * x1 * .. * xn))
266        Trace.Assert(original.SubTrees.Count > 1);
267        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x));
268        return
269          MakeProduct(simplifiedTrees.First(), Invert(simplifiedTrees.Skip(1).Aggregate((a, b) => MakeProduct(a, b))));
270      }
271    }
272
273    private SymbolicExpressionTreeNode SimplifyMultiplication(SymbolicExpressionTreeNode original) {
274      if (original.SubTrees.Count == 1) {
275        return GetSimplifiedTree(original.SubTrees[0]);
276      } else {
277        Trace.Assert(original.SubTrees.Count > 1);
278        return original.SubTrees
279          .Select(x => GetSimplifiedTree(x))
280          .Aggregate((a, b) => MakeProduct(a, b));
281      }
282    }
283
284    private SymbolicExpressionTreeNode SimplifySubtraction(SymbolicExpressionTreeNode original) {
285      if (original.SubTrees.Count == 1) {
286        return Negate(GetSimplifiedTree(original.SubTrees[0]));
287      } else {
288        // simplify expressions x0..xn
289        // make addition (x0,-x1..-xn)
290        Trace.Assert(original.SubTrees.Count > 1);
291        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x));
292        return simplifiedTrees.Take(1)
293          .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
294          .Aggregate((a, b) => MakeSum(a, b));
295      }
296    }
297
298    private SymbolicExpressionTreeNode SimplifyAddition(SymbolicExpressionTreeNode original) {
299      if (original.SubTrees.Count == 1) {
300        return GetSimplifiedTree(original.SubTrees[0]);
301      } else {
302        // simplify expression x0..xn
303        // make addition (x0..xn)
304        Trace.Assert(original.SubTrees.Count > 1);
305        return original.SubTrees
306          .Select(x => GetSimplifiedTree(x))
307          .Aggregate((a, b) => MakeSum(a, b));
308      }
309    }
310
311    private SymbolicExpressionTreeNode SimplifyNot(SymbolicExpressionTreeNode original) {
312      return MakeNot(GetSimplifiedTree(original.SubTrees[0]));
313    }
314    private SymbolicExpressionTreeNode SimplifyOr(SymbolicExpressionTreeNode original) {
315      return original.SubTrees
316        .Select(x => GetSimplifiedTree(x))
317        .Aggregate((a, b) => MakeOr(a, b));
318    }
319    private SymbolicExpressionTreeNode SimplifyAnd(SymbolicExpressionTreeNode original) {
320      return original.SubTrees
321        .Select(x => GetSimplifiedTree(x))
322        .Aggregate((a, b) => MakeAnd(a, b));
323    }
324    private SymbolicExpressionTreeNode SimplifyLessThan(SymbolicExpressionTreeNode original) {
325      return MakeLessThan(GetSimplifiedTree(original.SubTrees[0]), GetSimplifiedTree(original.SubTrees[1]));
326    }
327    private SymbolicExpressionTreeNode SimplifyGreaterThan(SymbolicExpressionTreeNode original) {
328      return MakeGreaterThan(GetSimplifiedTree(original.SubTrees[0]), GetSimplifiedTree(original.SubTrees[1]));
329    }
330    private SymbolicExpressionTreeNode SimplifyIfThenElse(SymbolicExpressionTreeNode original) {
331      return MakeIfThenElse(GetSimplifiedTree(original.SubTrees[0]), GetSimplifiedTree(original.SubTrees[1]), GetSimplifiedTree(original.SubTrees[2]));
332    }
333    private SymbolicExpressionTreeNode SimplifyTangent(SymbolicExpressionTreeNode original) {
334      return MakeTangent(GetSimplifiedTree(original.SubTrees[0]));
335    }
336    private SymbolicExpressionTreeNode SimplifyCosine(SymbolicExpressionTreeNode original) {
337      return MakeCosine(GetSimplifiedTree(original.SubTrees[0]));
338    }
339    private SymbolicExpressionTreeNode SimplifySine(SymbolicExpressionTreeNode original) {
340      return MakeSine(GetSimplifiedTree(original.SubTrees[0]));
341    }
342    private SymbolicExpressionTreeNode SimplifyExp(SymbolicExpressionTreeNode original) {
343      return MakeExp(GetSimplifiedTree(original.SubTrees[0]));
344    }
345
346    private SymbolicExpressionTreeNode SimplifyLog(SymbolicExpressionTreeNode original) {
347      return MakeLog(GetSimplifiedTree(original.SubTrees[0]));
348    }
349
350    #endregion
351
352
353
354    #region low level tree restructuring
355    private SymbolicExpressionTreeNode MakeNot(SymbolicExpressionTreeNode t) {
356      return MakeProduct(t, MakeConstant(-1.0));
357    }
358
359    private SymbolicExpressionTreeNode MakeOr(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
360      if (IsConstant(a) && IsConstant(b)) {
361        var constA = a as ConstantTreeNode;
362        var constB = b as ConstantTreeNode;
363        if (constA.Value > 0.0 || constB.Value > 0.0) {
364          return MakeConstant(1.0);
365        } else {
366          return MakeConstant(-1.0);
367        }
368      } else if (IsConstant(a)) {
369        return MakeOr(b, a);
370      } else if (IsConstant(b)) {
371        var constT = b as ConstantTreeNode;
372        if (constT.Value > 0.0) {
373          // boolean expression is necessarily true
374          return MakeConstant(1.0);
375        } else {
376          // the constant value has no effect on the result of the boolean condition so we can drop the constant term
377          var orNode = orSymbol.CreateTreeNode();
378          orNode.AddSubTree(a);
379          return orNode;
380        }
381      } else {
382        var orNode = orSymbol.CreateTreeNode();
383        orNode.AddSubTree(a);
384        orNode.AddSubTree(b);
385        return orNode;
386      }
387    }
388    private SymbolicExpressionTreeNode MakeAnd(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
389      if (IsConstant(a) && IsConstant(b)) {
390        var constA = a as ConstantTreeNode;
391        var constB = b as ConstantTreeNode;
392        if (constA.Value > 0.0 && constB.Value > 0.0) {
393          return MakeConstant(1.0);
394        } else {
395          return MakeConstant(-1.0);
396        }
397      } else if (IsConstant(a)) {
398        return MakeAnd(b, a);
399      } else if (IsConstant(b)) {
400        var constB = b as ConstantTreeNode;
401        if (constB.Value > 0.0) {
402          // the constant value has no effect on the result of the boolean condition so we can drop the constant term
403          var andNode = andSymbol.CreateTreeNode();
404          andNode.AddSubTree(a);
405          return andNode;
406        } else {
407          // boolean expression is necessarily false
408          return MakeConstant(-1.0);
409        }
410      } else {
411        var andNode = andSymbol.CreateTreeNode();
412        andNode.AddSubTree(a);
413        andNode.AddSubTree(b);
414        return andNode;
415      }
416    }
417    private SymbolicExpressionTreeNode MakeLessThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
418      if (IsConstant(leftSide) && IsConstant(rightSide)) {
419        var lsConst = leftSide as ConstantTreeNode;
420        var rsConst = rightSide as ConstantTreeNode;
421        if (lsConst.Value < rsConst.Value) return MakeConstant(1.0);
422        else return MakeConstant(-1.0);
423      } else {
424        var ltNode = ltSymbol.CreateTreeNode();
425        ltNode.AddSubTree(leftSide);
426        ltNode.AddSubTree(rightSide);
427        return ltNode;
428      }
429    }
430    private SymbolicExpressionTreeNode MakeGreaterThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
431      if (IsConstant(leftSide) && IsConstant(rightSide)) {
432        var lsConst = leftSide as ConstantTreeNode;
433        var rsConst = rightSide as ConstantTreeNode;
434        if (lsConst.Value > rsConst.Value) return MakeConstant(1.0);
435        else return MakeConstant(-1.0);
436      } else {
437        var gtNode = gtSymbol.CreateTreeNode();
438        gtNode.AddSubTree(leftSide);
439        gtNode.AddSubTree(rightSide);
440        return gtNode;
441      }
442    }
443    private SymbolicExpressionTreeNode MakeIfThenElse(SymbolicExpressionTreeNode condition, SymbolicExpressionTreeNode trueBranch, SymbolicExpressionTreeNode falseBranch) {
444      if (IsConstant(condition)) {
445        var constT = condition as ConstantTreeNode;
446        if (constT.Value > 0.0) return trueBranch;
447        else return falseBranch;
448      } else {
449        var ifNode = ifThenElseSymbol.CreateTreeNode();
450        if (IsBoolean(condition)) {
451          ifNode.AddSubTree(condition);
452        } else {
453          var gtNode = gtSymbol.CreateTreeNode();
454          gtNode.AddSubTree(condition); gtNode.AddSubTree(MakeConstant(0.0));
455          ifNode.AddSubTree(gtNode);
456        }
457        ifNode.AddSubTree(trueBranch);
458        ifNode.AddSubTree(falseBranch);
459        return ifNode;
460      }
461    }
462
463    private SymbolicExpressionTreeNode MakeSine(SymbolicExpressionTreeNode node) {
464      // todo implement more transformation rules
465      if (IsConstant(node)) {
466        var constT = node as ConstantTreeNode;
467        return MakeConstant(Math.Sin(constT.Value));
468      } else {
469        var sineNode = sineSymbol.CreateTreeNode();
470        sineNode.AddSubTree(node);
471        return sineNode;
472      }
473    }
474    private SymbolicExpressionTreeNode MakeTangent(SymbolicExpressionTreeNode node) {
475      // todo implement more transformation rules
476      if (IsConstant(node)) {
477        var constT = node as ConstantTreeNode;
478        return MakeConstant(Math.Tan(constT.Value));
479      } else {
480        var tanNode = tanSymbol.CreateTreeNode();
481        tanNode.AddSubTree(node);
482        return tanNode;
483      }
484    }
485    private SymbolicExpressionTreeNode MakeCosine(SymbolicExpressionTreeNode node) {
486      // todo implement more transformation rules
487      if (IsConstant(node)) {
488        var constT = node as ConstantTreeNode;
489        return MakeConstant(Math.Cos(constT.Value));
490      } else {
491        var cosNode = cosineSymbol.CreateTreeNode();
492        cosNode.AddSubTree(node);
493        return cosNode;
494      }
495    }
496    private SymbolicExpressionTreeNode MakeExp(SymbolicExpressionTreeNode node) {
497      // todo implement more transformation rules
498      if (IsConstant(node)) {
499        var constT = node as ConstantTreeNode;
500        return MakeConstant(Math.Exp(constT.Value));
501      } else if (IsLog(node)) {
502        return node.SubTrees[0];
503      } else {
504        var expNode = expSymbol.CreateTreeNode();
505        expNode.AddSubTree(node);
506        return expNode;
507      }
508    }
509    private SymbolicExpressionTreeNode MakeLog(SymbolicExpressionTreeNode node) {
510      // todo implement more transformation rules
511      if (IsConstant(node)) {
512        var constT = node as ConstantTreeNode;
513        return MakeConstant(Math.Log(constT.Value));
514      } else if (IsExp(node)) {
515        return node.SubTrees[0];
516      } else if (IsMultiplication(node)) {
517        return node.SubTrees.Select(s => MakeLog(s)).Aggregate((x, y) => MakeSum(x, y));
518      } else if (IsDivision(node)) {
519        var subtractionNode = subSymbol.CreateTreeNode();
520        foreach (var subTree in node.SubTrees) {
521          subtractionNode.AddSubTree(MakeLog(subTree));
522        }
523        return subtractionNode;
524      } else {
525        var logNode = logSymbol.CreateTreeNode();
526        logNode.AddSubTree(node);
527        return logNode;
528      }
529    }
530
531
532    // MakeFraction, MakeProduct and MakeSum take two already simplified trees and create a new simplified tree
533    private SymbolicExpressionTreeNode MakeFraction(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
534      if (IsConstant(a) && IsConstant(b)) {
535        // fold constants
536        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
537      } if (IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
538        return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
539      } else if (IsVariable(a) && IsConstant(b)) {
540        // merge constant values into variable weights
541        var constB = ((ConstantTreeNode)b).Value;
542        ((VariableTreeNode)a).Weight /= constB;
543        return a;
544      } else if (IsVariable(a) && IsVariable(b) && AreSameVariable(a, b)) {
545        // cancel variables
546        var aVar = a as VariableTreeNode;
547        var bVar = b as VariableTreeNode;
548        return MakeConstant(aVar.Weight / bVar.Weight);
549      } else if (IsAddition(a) && IsConstant(b)) {
550        return a.SubTrees
551          .Select(x => GetSimplifiedTree(x))
552         .Select(x => MakeFraction(x, b))
553         .Aggregate((c, d) => MakeSum(c, d));
554      } else if (IsMultiplication(a) && IsConstant(b)) {
555        return MakeProduct(a, Invert(b));
556      } else if (IsDivision(a) && IsConstant(b)) {
557        // (a1 / a2) / c => (a1 / (a2 * c))
558        Trace.Assert(a.SubTrees.Count == 2);
559        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
560      } else if (IsDivision(a) && IsDivision(b)) {
561        // (a1 / a2) / (b1 / b2) =>
562        Trace.Assert(a.SubTrees.Count == 2);
563        Trace.Assert(b.SubTrees.Count == 2);
564        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[1]), MakeProduct(a.SubTrees[1], b.SubTrees[0]));
565      } else if (IsDivision(a)) {
566        // (a1 / a2) / b => (a1 / (a2 * b))
567        Trace.Assert(a.SubTrees.Count == 2);
568        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
569      } else if (IsDivision(b)) {
570        // a / (b1 / b2) => (a * b2) / b1
571        Trace.Assert(b.SubTrees.Count == 2);
572        return MakeFraction(MakeProduct(a, b.SubTrees[1]), b.SubTrees[0]);
573      } else {
574        var div = divSymbol.CreateTreeNode();
575        div.AddSubTree(a);
576        div.AddSubTree(b);
577        return div;
578      }
579    }
580
581    private SymbolicExpressionTreeNode MakeSum(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
582      if (IsConstant(a) && IsConstant(b)) {
583        // fold constants
584        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
585        return a;
586      } else if (IsConstant(a)) {
587        // c + x => x + c
588        // b is not constant => make sure constant is on the right
589        return MakeSum(b, a);
590      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
591        // x + 0 => x
592        return a;
593      } else if (IsAddition(a) && IsAddition(b)) {
594        // merge additions
595        var add = addSymbol.CreateTreeNode();
596        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
597        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
598        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
599          add.AddSubTree(MakeSum(a.SubTrees.Last(), b.SubTrees.Last()));
600        } else if (IsConstant(a.SubTrees.Last())) {
601          add.AddSubTree(b.SubTrees.Last());
602          add.AddSubTree(a.SubTrees.Last());
603        } else {
604          add.AddSubTree(a.SubTrees.Last());
605          add.AddSubTree(b.SubTrees.Last());
606        }
607        MergeVariablesInSum(add);
608        if (add.SubTrees.Count == 1) {
609          return add.SubTrees[0];
610        } else {
611          return add;
612        }
613      } else if (IsAddition(b)) {
614        return MakeSum(b, a);
615      } else if (IsAddition(a) && IsConstant(b)) {
616        // a is an addition and b is a constant => append b to a and make sure the constants are merged
617        var add = addSymbol.CreateTreeNode();
618        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
619        if (IsConstant(a.SubTrees.Last()))
620          add.AddSubTree(MakeSum(a.SubTrees.Last(), b));
621        else {
622          add.AddSubTree(a.SubTrees.Last());
623          add.AddSubTree(b);
624        }
625        return add;
626      } else if (IsAddition(a)) {
627        // a is already an addition => append b
628        var add = addSymbol.CreateTreeNode();
629        add.AddSubTree(b);
630        foreach (var subTree in a.SubTrees) {
631          add.AddSubTree(subTree);
632        }
633        MergeVariablesInSum(add);
634        if (add.SubTrees.Count == 1) {
635          return add.SubTrees[0];
636        } else {
637          return add;
638        }
639      } else {
640        var add = addSymbol.CreateTreeNode();
641        add.AddSubTree(a);
642        add.AddSubTree(b);
643        MergeVariablesInSum(add);
644        if (add.SubTrees.Count == 1) {
645          return add.SubTrees[0];
646        } else {
647          return add;
648        }
649      }
650    }
651
652    // makes sure variable symbols in sums are combined
653    // possible improvement: combine sums of products where the products only reference the same variable
654    private void MergeVariablesInSum(SymbolicExpressionTreeNode sum) {
655      var subtrees = new List<SymbolicExpressionTreeNode>(sum.SubTrees);
656      while (sum.SubTrees.Count > 0) sum.RemoveSubTree(0);
657      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
658                            let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
659                            group node by node.VariableName + lag into g
660                            select g;
661      var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));
662
663      foreach (var variableNodeGroup in groupedVarNodes) {
664        var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
665        var representative = variableNodeGroup.First();
666        representative.Weight = weightSum;
667        sum.AddSubTree(representative);
668      }
669      foreach (var unchangedSubtree in unchangedSubTrees)
670        sum.AddSubTree(unchangedSubtree);
671    }
672
673
674    private SymbolicExpressionTreeNode MakeProduct(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
675      if (IsConstant(a) && IsConstant(b)) {
676        // fold constants
677        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
678        return a;
679      } else if (IsConstant(a)) {
680        // a * $ => $ * a
681        return MakeProduct(b, a);
682      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
683        // $ * 1.0 => $
684        return a;
685      } else if (IsConstant(b) && IsVariable(a)) {
686        // multiply constants into variables weights
687        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
688        return a;
689      } else if (IsConstant(b) && IsAddition(a)) {
690        // multiply constants into additions
691        return a.SubTrees.Select(x => MakeProduct(x, b)).Aggregate((c, d) => MakeSum(c, d));
692      } else if (IsDivision(a) && IsDivision(b)) {
693        // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
694        Trace.Assert(a.SubTrees.Count == 2);
695        Trace.Assert(b.SubTrees.Count == 2);
696        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[0]), MakeProduct(a.SubTrees[1], b.SubTrees[1]));
697      } else if (IsDivision(a)) {
698        // (a1 / a2) * b => (a1 * b) / a2
699        Trace.Assert(a.SubTrees.Count == 2);
700        return MakeFraction(MakeProduct(a.SubTrees[0], b), a.SubTrees[1]);
701      } else if (IsDivision(b)) {
702        // a * (b1 / b2) => (b1 * a) / b2
703        Trace.Assert(b.SubTrees.Count == 2);
704        return MakeFraction(MakeProduct(b.SubTrees[0], a), b.SubTrees[1]);
705      } else if (IsMultiplication(a) && IsMultiplication(b)) {
706        // merge multiplications (make sure constants are merged)
707        var mul = mulSymbol.CreateTreeNode();
708        for (int i = 0; i < a.SubTrees.Count; i++) mul.AddSubTree(a.SubTrees[i]);
709        for (int i = 0; i < b.SubTrees.Count; i++) mul.AddSubTree(b.SubTrees[i]);
710        MergeVariablesAndConstantsInProduct(mul);
711        return mul;
712      } else if (IsMultiplication(b)) {
713        return MakeProduct(b, a);
714      } else if (IsMultiplication(a)) {
715        // a is already an multiplication => append b
716        a.AddSubTree(b);
717        MergeVariablesAndConstantsInProduct(a);
718        return a;
719      } else {
720        var mul = mulSymbol.CreateTreeNode();
721        mul.SubTrees.Add(a);
722        mul.SubTrees.Add(b);
723        MergeVariablesAndConstantsInProduct(mul);
724        return mul;
725      }
726    }
727    #endregion
728
729
730    #region helper functions
731
732    private bool AreSameVariable(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
733      var aLaggedVar = a as LaggedVariableTreeNode;
734      var bLaggedVar = b as LaggedVariableTreeNode;
735      if (aLaggedVar != null && bLaggedVar != null) {
736        return aLaggedVar.VariableName == bLaggedVar.VariableName &&
737          aLaggedVar.Lag == bLaggedVar.Lag;
738      }
739      var aVar = a as VariableTreeNode;
740      var bVar = b as VariableTreeNode;
741      if (aVar != null && bVar != null) {
742        return aVar.VariableName == bVar.VariableName;
743      }
744      return false;
745    }
746
747    // helper to combine the constant factors in products and to combine variables (powers of 2, 3...)
748    private void MergeVariablesAndConstantsInProduct(SymbolicExpressionTreeNode prod) {
749      var subtrees = new List<SymbolicExpressionTreeNode>(prod.SubTrees);
750      while (prod.SubTrees.Count > 0) prod.RemoveSubTree(0);
751      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
752                            let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
753                            group node by node.VariableName + lag into g
754                            orderby g.Count()
755                            select g;
756      var constantProduct = (from node in subtrees.OfType<VariableTreeNode>()
757                             select node.Weight)
758                            .Concat(from node in subtrees.OfType<ConstantTreeNode>()
759                                    select node.Value)
760                            .DefaultIfEmpty(1.0)
761                            .Aggregate((c1, c2) => c1 * c2);
762
763      var unchangedSubTrees = from tree in subtrees
764                              where !(tree is VariableTreeNode)
765                              where !(tree is ConstantTreeNode)
766                              select tree;
767
768      foreach (var variableNodeGroup in groupedVarNodes) {
769        var representative = variableNodeGroup.First();
770        representative.Weight = 1.0;
771        if (variableNodeGroup.Count() > 1) {
772          var poly = mulSymbol.CreateTreeNode();
773          for (int p = 0; p < variableNodeGroup.Count(); p++) {
774            poly.AddSubTree((SymbolicExpressionTreeNode)representative.Clone());
775          }
776          prod.AddSubTree(poly);
777        } else {
778          prod.AddSubTree(representative);
779        }
780      }
781
782      foreach (var unchangedSubtree in unchangedSubTrees)
783        prod.AddSubTree(unchangedSubtree);
784
785      if (!constantProduct.IsAlmost(1.0)) {
786        prod.AddSubTree(MakeConstant(constantProduct));
787      }
788    }
789
790
791    /// <summary>
792    /// x => x * -1
793    /// Doesn't create new trees and manipulates x
794    /// </summary>
795    /// <param name="x"></param>
796    /// <returns>-x</returns>
797    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
798      if (IsConstant(x)) {
799        ((ConstantTreeNode)x).Value *= -1;
800      } else if (IsVariable(x)) {
801        var variableTree = (VariableTreeNode)x;
802        variableTree.Weight *= -1.0;
803      } else if (IsAddition(x)) {
804        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)       
805        for (int i = 0; i < x.SubTrees.Count; i++)
806          x.SubTrees[i] = Negate(x.SubTrees[i]);
807      } else if (IsMultiplication(x) || IsDivision(x)) {
808        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
809        x.SubTrees[x.SubTrees.Count - 1] = Negate(x.SubTrees.Last()); // last is maybe a constant, prefer to negate the constant
810      } else {
811        // any other function
812        return MakeProduct(x, MakeConstant(-1));
813      }
814      return x;
815    }
816
817    /// <summary>
818    /// x => 1/x
819    /// Doesn't create new trees and manipulates x
820    /// </summary>
821    /// <param name="x"></param>
822    /// <returns></returns>
823    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
824      if (IsConstant(x)) {
825        return MakeConstant(1.0 / ((ConstantTreeNode)x).Value);
826      } else if (IsDivision(x)) {
827        Trace.Assert(x.SubTrees.Count == 2);
828        return MakeFraction(x.SubTrees[1], x.SubTrees[0]);
829      } else {
830        // any other function
831        return MakeFraction(MakeConstant(1), x);
832      }
833    }
834
835    private SymbolicExpressionTreeNode MakeConstant(double value) {
836      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(constSymbol.CreateTreeNode());
837      constantTreeNode.Value = value;
838      return (SymbolicExpressionTreeNode)constantTreeNode;
839    }
840
841    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
842      var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
843      tree.Weight = weight;
844      tree.VariableName = name;
845      return tree;
846    }
847    #endregion
848  }
849}
Note: See TracBrowser for help on using the repository browser.