Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 5959

Last change to this file as of revision 5959: revision 5465, checked in by gkronber 14 years ago.

#1227 implemented test cases and transformation rules for root and power symbols.

File size: 37.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  /// <summary>
33  /// Simplistic simplifier for arithmetic expressions
34  /// </summary>
35  public class SymbolicSimplifier {
// Symbol instances used as node factories by the Make* helpers (e.g. orSymbol.CreateTreeNode(), subSymbol in MakeLog).
// NOTE(review): mulSymbol, constSymbol, varSymbol and notSymbol are not referenced in the part of the file
// visible here; presumably helpers further down (e.g. MakeProduct, MakeConstant) use them — confirm.
36    private Addition addSymbol = new Addition();
37    private Subtraction subSymbol = new Subtraction();
38    private Multiplication mulSymbol = new Multiplication();
39    private Division divSymbol = new Division();
40    private Constant constSymbol = new Constant();
41    private Variable varSymbol = new Variable();
42    private Logarithm logSymbol = new Logarithm();
43    private Exponential expSymbol = new Exponential();
44    private Root rootSymbol = new Root();
45    private Power powSymbol = new Power();
46    private Sine sineSymbol = new Sine();
47    private Cosine cosineSymbol = new Cosine();
48    private Tangent tanSymbol = new Tangent();
49    private IfThenElse ifThenElseSymbol = new IfThenElse();
50    private And andSymbol = new And();
51    private Or orSymbol = new Or();
52    private Not notSymbol = new Not();
53    private GreaterThan gtSymbol = new GreaterThan();
54    private LessThan ltSymbol = new LessThan();
55
/// <summary>
/// Returns a simplified copy of <paramref name="originalTree"/>. The root is cloned, invocations of
/// defined functions in the branch at index 0 are expanded inline, and the simplified result is placed
/// under a fresh program root node.
/// </summary>
/// <param name="originalTree">The tree to simplify; it is not modified.</param>
/// <returns>A new tree containing the simplified expression.</returns>
public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
  var rootCopy = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
  // expansion starts without any bound argument trees
  var expandedBody = MacroExpand(rootCopy, rootCopy.SubTrees[0], new List<SymbolicExpressionTreeNode>());
  var simplifiedBody = GetSimplifiedTree(expandedBody);
  SymbolicExpressionTreeNode newRoot = new ProgramRootSymbol().CreateTreeNode();
  newRoot.AddSubTree(simplifiedBody);
  return new SymbolicExpressionTree(newRoot);
}
64
/// <summary>
/// Expands calls of defined functions by inlining their bodies.
/// </summary>
/// <param name="root">Root node whose DefunTreeNode subtrees hold the function definitions.</param>
/// <param name="node">Node to expand. Its subtrees are detached and the expanded subtrees are attached in their place.</param>
/// <param name="argumentTrees">Already expanded trees that Argument nodes are replaced with (the arguments of the
/// invocation whose body is currently being expanded).</param>
/// <returns>An expanded copy of the function body for InvokeFunction nodes, a copy of the bound argument tree for
/// Argument nodes, and <paramref name="node"/> itself (with expanded subtrees) otherwise.</returns>
private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
  List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
  while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
  if (node.Symbol is InvokeFunction) {
    var invokeSym = node.Symbol as InvokeFunction;
    // Expand a private copy of the function body. Expansion detaches and replaces subtrees in place, so
    // expanding the shared definition itself would bake this call's arguments into it: every further call
    // of the same function would then return that same node, still carrying the first call's arguments.
    var defunNode = (SymbolicExpressionTreeNode)FindFunctionDefinition(root, invokeSym.FunctionName).Clone();
    var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
    foreach (var subtree in subtrees) {
      // arguments are expanded in the caller's context, i.e. with the caller's argument trees
      macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
    }
    return MacroExpand(root, defunNode, macroExpandedArguments);
  } else if (node.Symbol is Argument) {
    var argSym = node.Symbol as Argument;
    // return a copy of the bound (already expanded) argument tree so every use is a separate node
    return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
  } else {
    // recursive application
    foreach (var subtree in subtrees) {
      node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
    }
    return node;
  }
}
89
/// <summary>
/// Returns the body (first subtree) of the DefunTreeNode below <paramref name="root"/> whose
/// FunctionName equals <paramref name="functionName"/>.
/// </summary>
/// <exception cref="ArgumentException">No function with that name is defined below root.</exception>
private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
  var definition = root.SubTrees
    .OfType<DefunTreeNode>()
    .FirstOrDefault(defun => defun.FunctionName == functionName);
  if (definition == null) {
    throw new ArgumentException("Definition of function " + functionName + " not found.");
  }
  return definition.SubTrees[0];
}
97
98
#region symbol predicates
// Each predicate reports whether the node's symbol is an instance of the named symbol class
// (a type test, so derived symbol classes match as well).

// shared implementation of the type tests below
private static bool IsSymbolOfType<TSymbol>(SymbolicExpressionTreeNode node) {
  return node.Symbol is TSymbol;
}

// arithmetic
private bool IsDivision(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Division>(node);
}

private bool IsMultiplication(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Multiplication>(node);
}

private bool IsSubtraction(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Subtraction>(node);
}

private bool IsAddition(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Addition>(node);
}

private bool IsAverage(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Average>(node);
}

// exponential
private bool IsLog(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Logarithm>(node);
}
private bool IsExp(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Exponential>(node);
}
private bool IsRoot(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Root>(node);
}
private bool IsPower(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Power>(node);
}

// trigonometric
private bool IsSine(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Sine>(node);
}
private bool IsCosine(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Cosine>(node);
}
private bool IsTangent(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Tangent>(node);
}

// boolean
private bool IsIfThenElse(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<IfThenElse>(node);
}
private bool IsAnd(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<And>(node);
}
private bool IsOr(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Or>(node);
}
private bool IsNot(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Not>(node);
}

// comparison
private bool IsGreaterThan(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<GreaterThan>(node);
}
private bool IsLessThan(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<LessThan>(node);
}

// true for comparisons and the binary connectives And/Or (Not is not part of this list)
private bool IsBoolean(SymbolicExpressionTreeNode node) {
  return IsGreaterThan(node) || IsLessThan(node) || IsAnd(node) || IsOr(node);
}

// terminals
private bool IsVariable(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Variable>(node);
}

private bool IsConstant(SymbolicExpressionTreeNode node) {
  return IsSymbolOfType<Constant>(node);
}

#endregion
182
/// <summary>
/// Creates a new simplified tree from <paramref name="original"/>. Constants and variables are cloned,
/// every symbol with a dedicated routine is dispatched to it, and any other symbol is handled by
/// SimplifyAny, which simplifies its subtrees.
/// </summary>
/// <param name="original">Root of the tree to simplify.</param>
/// <returns>The root of the simplified tree.</returns>
public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
  // terminals are copied unchanged
  if (IsConstant(original) || IsVariable(original)) return (SymbolicExpressionTreeNode)original.Clone();
  // arithmetic
  if (IsAddition(original)) return SimplifyAddition(original);
  if (IsSubtraction(original)) return SimplifySubtraction(original);
  if (IsMultiplication(original)) return SimplifyMultiplication(original);
  if (IsDivision(original)) return SimplifyDivision(original);
  if (IsAverage(original)) return SimplifyAverage(original);
  // exponential and power functions
  if (IsLog(original)) return SimplifyLog(original);
  if (IsExp(original)) return SimplifyExp(original);
  if (IsRoot(original)) return SimplifyRoot(original);
  if (IsPower(original)) return SimplifyPower(original);
  // trigonometric
  if (IsSine(original)) return SimplifySine(original);
  if (IsCosine(original)) return SimplifyCosine(original);
  if (IsTangent(original)) return SimplifyTangent(original);
  // conditional, comparison and boolean
  if (IsIfThenElse(original)) return SimplifyIfThenElse(original);
  if (IsGreaterThan(original)) return SimplifyGreaterThan(original);
  if (IsLessThan(original)) return SimplifyLessThan(original);
  if (IsAnd(original)) return SimplifyAnd(original);
  if (IsOr(original)) return SimplifyOr(original);
  if (IsNot(original)) return SimplifyNot(original);
  // no dedicated rule for this symbol
  return SimplifyAny(original);
}
231
232
233    #region specific simplification routines
/// <summary>
/// Fallback for symbols without a dedicated rule: returns a copy of the <paramref name="original"/> node
/// whose subtrees are the simplified subtrees of the original. The original tree is left unchanged.
/// </summary>
/// <param name="original">Node to simplify.</param>
/// <returns>A new node with the same symbol data and simplified subtrees.</returns>
private SymbolicExpressionTreeNode SimplifyAny(SymbolicExpressionTreeNode original) {
  // Detach the subtrees temporarily so that Clone() copies only this node and not the whole tree.
  // The finally block restores the caller's tree even if Clone() throws.
  var subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
  while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
  SymbolicExpressionTreeNode clone;
  try {
    clone = (SymbolicExpressionTreeNode)original.Clone();
  } finally {
    foreach (var subTree in subTrees) original.AddSubTree(subTree);
  }

  var simplifiedSubTrees = subTrees.Select(t => GetSimplifiedTree(t)).ToList();
  foreach (var simplifiedSubTree in simplifiedSubTrees) {
    clone.AddSubTree(simplifiedSubTree);
  }
  // Only nodes that have arguments, all of them constant, are candidates for folding. A zero-arity
  // symbol without a dedicated rule is not a constant expression. Use the folding result instead of
  // discarding it.
  if (simplifiedSubTrees.Count > 0 && simplifiedSubTrees.TrueForAll(t => IsConstant(t))) {
    return SimplifyConstantExpression(clone);
  }
  return clone;
}
252
/// <summary>
/// Hook for folding a node whose arguments are all constants into a single constant.
/// Folding is not implemented yet, so the node is returned unchanged.
/// </summary>
private SymbolicExpressionTreeNode SimplifyConstantExpression(SymbolicExpressionTreeNode original) {
  // not yet implemented: no evaluation takes place
  SymbolicExpressionTreeNode result = original;
  return result;
}
257
/// <summary>
/// Simplifies avg(x0..xn). A single argument is simplified on its own; otherwise the result is
/// (x0 + ... + xn) / (number of arguments).
/// </summary>
private SymbolicExpressionTreeNode SimplifyAverage(SymbolicExpressionTreeNode original) {
  var arguments = original.SubTrees;
  if (arguments.Count == 1)
    return GetSimplifiedTree(arguments[0]);

  Trace.Assert(arguments.Count > 1);
  var simplifiedArguments = arguments.Select(argument => GetSimplifiedTree(argument));
  var total = simplifiedArguments.Aggregate((left, right) => MakeSum(left, right));
  return MakeFraction(total, MakeConstant(arguments.Count));
}
271
/// <summary>
/// Simplifies a division node. A unary division of x0 becomes 1/x0 (Invert). An n-ary division
/// x0 / x1 / ... / xn becomes x0 * 1/(x1 * x2 * ... * xn).
/// </summary>
private SymbolicExpressionTreeNode SimplifyDivision(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return Invert(GetSimplifiedTree(original.SubTrees[0]));
  } else {
    Trace.Assert(original.SubTrees.Count > 1);
    // Materialize the simplified subtrees once. The lazy sequence was enumerated twice (First() and
    // Skip(1)), which simplified the numerator twice per level. For nested divisions this cost grows
    // exponentially with the nesting depth.
    var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
    var denominator = simplifiedTrees.Skip(1).Aggregate((a, b) => MakeProduct(a, b));
    return MakeProduct(simplifiedTrees[0], Invert(denominator));
  }
}
284
/// <summary>
/// Simplifies a multiplication node. A single factor is simplified on its own; otherwise the
/// simplified factors are combined pairwise, left to right, with MakeProduct.
/// </summary>
private SymbolicExpressionTreeNode SimplifyMultiplication(SymbolicExpressionTreeNode original) {
  var factors = original.SubTrees;
  if (factors.Count == 1)
    return GetSimplifiedTree(factors[0]);

  Trace.Assert(factors.Count > 1);
  var simplifiedFactors = factors.Select(factor => GetSimplifiedTree(factor));
  return simplifiedFactors.Aggregate((product, factor) => MakeProduct(product, factor));
}
295
/// <summary>
/// Simplifies a subtraction node. A unary subtraction of x0 is the negation -x0. An n-ary subtraction
/// x0 - x1 - ... - xn becomes the sum x0 + (-x1) + ... + (-xn).
/// </summary>
private SymbolicExpressionTreeNode SimplifySubtraction(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return Negate(GetSimplifiedTree(original.SubTrees[0]));
  } else {
    Trace.Assert(original.SubTrees.Count > 1);
    // Materialize the simplified subtrees once. Take(1) and Skip(1) each re-enumerated the lazy sequence,
    // which simplified the first operand twice per level (exponential in the depth of nested subtractions).
    var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
    return simplifiedTrees
      .Skip(1)
      .Aggregate(simplifiedTrees[0], (sum, t) => MakeSum(sum, Negate(t)));
  }
}
309
/// <summary>
/// Simplifies an addition node. A single term is simplified on its own; otherwise the simplified
/// terms are combined pairwise, left to right, with MakeSum.
/// </summary>
private SymbolicExpressionTreeNode SimplifyAddition(SymbolicExpressionTreeNode original) {
  var terms = original.SubTrees;
  if (terms.Count == 1)
    return GetSimplifiedTree(terms[0]);

  Trace.Assert(terms.Count > 1);
  var simplifiedTerms = terms.Select(term => GetSimplifiedTree(term));
  return simplifiedTerms.Aggregate((total, term) => MakeSum(total, term));
}
322
/// <summary>Simplifies NOT(x): x is simplified first and then negated with MakeNot.</summary>
private SymbolicExpressionTreeNode SimplifyNot(SymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.SubTrees[0]);
  return MakeNot(operand);
}
/// <summary>
/// Simplifies an OR node by combining the simplified arguments pairwise with MakeOr.
/// A single argument keeps its OR node; only a constant argument is folded to 1.0 / -1.0, as in MakeOr.
/// Aggregate alone would return the bare argument. That is inconsistent with MakeOr, which deliberately
/// emits single-argument OR nodes instead of the bare argument.
/// </summary>
private SymbolicExpressionTreeNode SimplifyOr(SymbolicExpressionTreeNode original) {
  var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
  if (simplifiedTrees.Count == 1) {
    var operand = simplifiedTrees[0];
    if (IsConstant(operand)) {
      return MakeConstant(((ConstantTreeNode)operand).Value > 0.0 ? 1.0 : -1.0);
    }
    var orNode = orSymbol.CreateTreeNode();
    orNode.AddSubTree(operand);
    return orNode;
  }
  return simplifiedTrees.Aggregate((a, b) => MakeOr(a, b));
}
/// <summary>
/// Simplifies an AND node by combining the simplified arguments pairwise with MakeAnd.
/// A single argument keeps its AND node; only a constant argument is folded to 1.0 / -1.0, as in MakeAnd.
/// Aggregate alone would return the bare argument. That is inconsistent with MakeAnd, which deliberately
/// emits single-argument AND nodes instead of the bare argument.
/// </summary>
private SymbolicExpressionTreeNode SimplifyAnd(SymbolicExpressionTreeNode original) {
  var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
  if (simplifiedTrees.Count == 1) {
    var operand = simplifiedTrees[0];
    if (IsConstant(operand)) {
      return MakeConstant(((ConstantTreeNode)operand).Value > 0.0 ? 1.0 : -1.0);
    }
    var andNode = andSymbol.CreateTreeNode();
    andNode.AddSubTree(operand);
    return andNode;
  }
  return simplifiedTrees.Aggregate((a, b) => MakeAnd(a, b));
}
// Each routine below simplifies the arguments of one node kind (left to right) and rebuilds
// the node with the matching Make* helper.
private SymbolicExpressionTreeNode SimplifyLessThan(SymbolicExpressionTreeNode original) {
  var left = GetSimplifiedTree(original.SubTrees[0]);
  var right = GetSimplifiedTree(original.SubTrees[1]);
  return MakeLessThan(left, right);
}
private SymbolicExpressionTreeNode SimplifyGreaterThan(SymbolicExpressionTreeNode original) {
  var left = GetSimplifiedTree(original.SubTrees[0]);
  var right = GetSimplifiedTree(original.SubTrees[1]);
  return MakeGreaterThan(left, right);
}
private SymbolicExpressionTreeNode SimplifyIfThenElse(SymbolicExpressionTreeNode original) {
  var condition = GetSimplifiedTree(original.SubTrees[0]);
  var thenBranch = GetSimplifiedTree(original.SubTrees[1]);
  var elseBranch = GetSimplifiedTree(original.SubTrees[2]);
  return MakeIfThenElse(condition, thenBranch, elseBranch);
}
private SymbolicExpressionTreeNode SimplifyTangent(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeTangent(argument);
}
private SymbolicExpressionTreeNode SimplifyCosine(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeCosine(argument);
}
private SymbolicExpressionTreeNode SimplifySine(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeSine(argument);
}
private SymbolicExpressionTreeNode SimplifyExp(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeExp(argument);
}

private SymbolicExpressionTreeNode SimplifyLog(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeLog(argument);
}
private SymbolicExpressionTreeNode SimplifyRoot(SymbolicExpressionTreeNode original) {
  var radicand = GetSimplifiedTree(original.SubTrees[0]);
  var index = GetSimplifiedTree(original.SubTrees[1]);
  return MakeRoot(radicand, index);
}

private SymbolicExpressionTreeNode SimplifyPower(SymbolicExpressionTreeNode original) {
  var baseTree = GetSimplifiedTree(original.SubTrees[0]);
  var exponent = GetSimplifiedTree(original.SubTrees[1]);
  return MakePower(baseTree, exponent);
}
368    #endregion
369
370    #region low level tree restructuring
/// <summary>
/// Builds the logical negation of an already simplified tree. It uses the boolean encoding of this class:
/// values > 0 are true, and folded results are 1.0 (true) or -1.0 (false).
/// </summary>
private SymbolicExpressionTreeNode MakeNot(SymbolicExpressionTreeNode t) {
  if (IsConstant(t)) {
    // fold: a positive constant is true, so its negation is false (-1.0), and vice versa
    return MakeConstant(((ConstantTreeNode)t).Value > 0.0 ? -1.0 : 1.0);
  } else if (IsBoolean(t)) {
    // comparison/And/Or nodes are treated as ±1 (as their constant-folded forms in this class are),
    // so negating the value negates the truth value
    return MakeProduct(t, MakeConstant(-1.0));
  } else {
    // An arbitrary numeric value is first normalized via (t > 0). Plain multiplication by -1 would
    // turn e.g. NOT(3) into -3 instead of -1.
    return MakeProduct(MakeGreaterThan(t, MakeConstant(0.0)), MakeConstant(-1.0));
  }
}
374
/// <summary>
/// Builds a simplified OR of two already simplified trees (values > 0 are true; folded results are 1.0 / -1.0).
/// Two constants are folded. A true constant makes the result true. A false constant is dropped,
/// leaving a single-argument OR node.
/// </summary>
private SymbolicExpressionTreeNode MakeOr(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  bool aIsConstant = IsConstant(a);
  bool bIsConstant = IsConstant(b);
  if (aIsConstant && bIsConstant) {
    bool anyTrue = (a as ConstantTreeNode).Value > 0.0 || (b as ConstantTreeNode).Value > 0.0;
    return MakeConstant(anyTrue ? 1.0 : -1.0);
  }
  // keep a constant operand on the right-hand side
  if (aIsConstant) return MakeOr(b, a);
  // boolean expression is necessarily true
  if (bIsConstant && (b as ConstantTreeNode).Value > 0.0) return MakeConstant(1.0);

  // a false constant has no effect on the result, so only a is kept in that case
  var orNode = orSymbol.CreateTreeNode();
  orNode.AddSubTree(a);
  if (!bIsConstant) orNode.AddSubTree(b);
  return orNode;
}
/// <summary>
/// Builds a simplified AND of two already simplified trees (values > 0 are true; folded results are 1.0 / -1.0).
/// Two constants are folded. A constant that is not positive makes the result false. A true constant is
/// dropped, leaving a single-argument AND node.
/// </summary>
private SymbolicExpressionTreeNode MakeAnd(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  bool aIsConstant = IsConstant(a);
  bool bIsConstant = IsConstant(b);
  if (aIsConstant && bIsConstant) {
    bool bothTrue = (a as ConstantTreeNode).Value > 0.0 && (b as ConstantTreeNode).Value > 0.0;
    return MakeConstant(bothTrue ? 1.0 : -1.0);
  }
  // keep a constant operand on the right-hand side
  if (aIsConstant) return MakeAnd(b, a);
  // boolean expression is necessarily false
  if (bIsConstant && !((b as ConstantTreeNode).Value > 0.0)) return MakeConstant(-1.0);

  // a true constant has no effect on the result, so only a is kept in that case
  var andNode = andSymbol.CreateTreeNode();
  andNode.AddSubTree(a);
  if (!bIsConstant) andNode.AddSubTree(b);
  return andNode;
}
/// <summary>
/// Builds leftSide &lt; rightSide. Two constants are folded to 1.0 (true) or -1.0 (false).
/// </summary>
private SymbolicExpressionTreeNode MakeLessThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
  if (!IsConstant(leftSide) || !IsConstant(rightSide)) {
    var ltNode = ltSymbol.CreateTreeNode();
    ltNode.AddSubTree(leftSide);
    ltNode.AddSubTree(rightSide);
    return ltNode;
  }
  double left = (leftSide as ConstantTreeNode).Value;
  double right = (rightSide as ConstantTreeNode).Value;
  return MakeConstant(left < right ? 1.0 : -1.0);
}
/// <summary>
/// Builds leftSide &gt; rightSide. Two constants are folded to 1.0 (true) or -1.0 (false).
/// </summary>
private SymbolicExpressionTreeNode MakeGreaterThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
  if (!IsConstant(leftSide) || !IsConstant(rightSide)) {
    var gtNode = gtSymbol.CreateTreeNode();
    gtNode.AddSubTree(leftSide);
    gtNode.AddSubTree(rightSide);
    return gtNode;
  }
  double left = (leftSide as ConstantTreeNode).Value;
  double right = (rightSide as ConstantTreeNode).Value;
  return MakeConstant(left > right ? 1.0 : -1.0);
}
/// <summary>
/// Builds if-then-else from already simplified parts. A constant condition selects a branch directly
/// (values > 0 select the true branch). A condition that is not a comparison/And/Or node is wrapped
/// as (condition > 0).
/// </summary>
private SymbolicExpressionTreeNode MakeIfThenElse(SymbolicExpressionTreeNode condition, SymbolicExpressionTreeNode trueBranch, SymbolicExpressionTreeNode falseBranch) {
  if (IsConstant(condition)) {
    return (condition as ConstantTreeNode).Value > 0.0 ? trueBranch : falseBranch;
  }

  var ifNode = ifThenElseSymbol.CreateTreeNode();
  SymbolicExpressionTreeNode test = condition;
  if (!IsBoolean(condition)) {
    test = gtSymbol.CreateTreeNode();
    test.AddSubTree(condition);
    test.AddSubTree(MakeConstant(0.0));
  }
  ifNode.AddSubTree(test);
  ifNode.AddSubTree(trueBranch);
  ifNode.AddSubTree(falseBranch);
  return ifNode;
}
478
// The trigonometric builders fold a constant argument with the matching System.Math function and
// otherwise wrap the argument in a new node of their symbol.
private SymbolicExpressionTreeNode MakeSine(SymbolicExpressionTreeNode node) {
  if (!IsConstant(node)) {
    var sineNode = sineSymbol.CreateTreeNode();
    sineNode.AddSubTree(node);
    return sineNode;
  }
  return MakeConstant(Math.Sin((node as ConstantTreeNode).Value));
}
private SymbolicExpressionTreeNode MakeTangent(SymbolicExpressionTreeNode node) {
  if (!IsConstant(node)) {
    var tanNode = tanSymbol.CreateTreeNode();
    tanNode.AddSubTree(node);
    return tanNode;
  }
  return MakeConstant(Math.Tan((node as ConstantTreeNode).Value));
}
private SymbolicExpressionTreeNode MakeCosine(SymbolicExpressionTreeNode node) {
  if (!IsConstant(node)) {
    var cosNode = cosineSymbol.CreateTreeNode();
    cosNode.AddSubTree(node);
    return cosNode;
  }
  return MakeConstant(Math.Cos((node as ConstantTreeNode).Value));
}
/// <summary>
/// Builds a simplified tree for exp(node) from an already simplified tree.
/// Constants are folded and exp(log(x)) becomes x. Sums become products of exponentials and
/// differences become quotients of exponentials.
/// </summary>
private SymbolicExpressionTreeNode MakeExp(SymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constT = node as ConstantTreeNode;
    return MakeConstant(Math.Exp(constT.Value));
  } else if (IsLog(node)) {
    // exp(log(x)) => x
    return node.SubTrees[0];
  } else if (IsAddition(node)) {
    // exp(a + b + ...) => exp(a) * exp(b) * ...
    return node.SubTrees.Select(s => MakeExp(s)).Aggregate((s, t) => MakeProduct(s, t));
  } else if (IsSubtraction(node)) {
    // exp(a - b - ...) => exp(a) / exp(b) / ...; a unary subtraction is a negation: exp(-a) => 1 / exp(a).
    // (Multiplying by Negate(exp(b)) computed -exp(b) instead of exp(-b).)
    var exps = node.SubTrees.Select(s => MakeExp(s)).ToList();
    if (exps.Count == 1) return MakeFraction(MakeConstant(1.0), exps[0]);
    return exps.Skip(1).Aggregate(exps[0], (s, t) => MakeFraction(s, t));
  } else {
    var expNode = expSymbol.CreateTreeNode();
    expNode.AddSubTree(node);
    return expNode;
  }
}
/// <summary>
/// Builds a simplified tree for log(node) from an already simplified tree.
/// Constants are folded, log(exp(x)) becomes x, and the log of a product becomes the sum of the logs.
/// The log of a quotient becomes the first log minus the remaining logs.
/// </summary>
private SymbolicExpressionTreeNode MakeLog(SymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constT = node as ConstantTreeNode;
    return MakeConstant(Math.Log(constT.Value));
  } else if (IsExp(node)) {
    // log(exp(x)) => x
    return node.SubTrees[0];
  } else if (IsMultiplication(node)) {
    // log(a * b * ...) => log(a) + log(b) + ...
    return node.SubTrees.Select(s => MakeLog(s)).Aggregate((x, y) => MakeSum(x, y));
  } else if (IsDivision(node)) {
    // log(a0 / a1 / ... / an) => log(a0) + (-log(a1)) + ... + (-log(an)); a unary division 1/a0 gives -log(a0).
    // This is built with MakeSum and Negate, as SimplifySubtraction does, instead of emitting a raw
    // Subtraction node. The result then stays in the form the rest of the simplifier produces, and its
    // terms can be merged.
    var logs = node.SubTrees.Select(s => MakeLog(s)).ToList();
    if (logs.Count == 1) return Negate(logs[0]);
    return logs.Skip(1).Aggregate(logs[0], (sum, term) => MakeSum(sum, Negate(term)));
  } else {
    var logNode = logSymbol.CreateTreeNode();
    logNode.AddSubTree(node);
    return logNode;
  }
}
/// <summary>
/// Builds root(a, b), the b-th root of a. A constant root index is rounded to the nearest integer.
/// Indices 1, 0 and -1 are special-cased, and a negative index n becomes 1 / root(a, -n).
/// Two constants are folded with Math.Pow.
/// </summary>
private SymbolicExpressionTreeNode MakeRoot(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (!IsConstant(b)) {
    // symbolic root index: keep root(a, b) as it is
    var symbolicRoot = rootSymbol.CreateTreeNode();
    symbolicRoot.AddSubTree(a);
    symbolicRoot.AddSubTree(b);
    return symbolicRoot;
  }
  var index = Math.Round((b as ConstantTreeNode).Value);
  if (IsConstant(a)) return MakeConstant(Math.Pow((a as ConstantTreeNode).Value, 1.0 / index));
  if (index.IsAlmost(1.0)) return a;
  if (index.IsAlmost(0.0)) return MakeConstant(1.0);
  if (index.IsAlmost(-1.0)) return MakeFraction(MakeConstant(1.0), a);

  var rootNode = rootSymbol.CreateTreeNode();
  rootNode.AddSubTree(a);
  if (index < 0) {
    rootNode.AddSubTree(MakeConstant(-1.0 * index));
    return MakeFraction(MakeConstant(1.0), rootNode);
  }
  rootNode.AddSubTree(MakeConstant(index));
  return rootNode;
}
/// <summary>
/// Builds a ^ b. A constant exponent is rounded to the nearest integer. Exponents 0, 1 and -1 are
/// special-cased, and a negative exponent n becomes 1 / (a ^ -n). Two constants are folded with Math.Pow.
/// </summary>
private SymbolicExpressionTreeNode MakePower(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (!IsConstant(b)) {
    // symbolic exponent: keep a ^ b as it is
    var symbolicPower = powSymbol.CreateTreeNode();
    symbolicPower.AddSubTree(a);
    symbolicPower.AddSubTree(b);
    return symbolicPower;
  }
  double exponent = Math.Round((b as ConstantTreeNode).Value);
  if (IsConstant(a)) return MakeConstant(Math.Pow((a as ConstantTreeNode).Value, exponent));
  if (exponent.IsAlmost(0.0)) return MakeConstant(1.0);
  if (exponent.IsAlmost(1.0)) return a;
  if (exponent.IsAlmost(-1.0)) return MakeFraction(MakeConstant(1.0), a);

  var powNode = powSymbol.CreateTreeNode();
  powNode.AddSubTree(a);
  if (exponent < 0) {
    powNode.AddSubTree(MakeConstant(-1.0 * exponent));
    return MakeFraction(MakeConstant(1.0), powNode);
  }
  powNode.AddSubTree(MakeConstant(exponent));
  return powNode;
}
609
610
611    // MakeFraction, MakeProduct and MakeSum take two already simplified trees and create a new simplified tree
/// <summary>
/// Builds a simplified tree for a / b from two already simplified trees.
/// </summary>
/// <remarks>
/// The arguments may be reused in, or modified for, the result; for example, a variable's weight is
/// divided in place. Callers must not rely on them afterwards.
/// </remarks>
private SymbolicExpressionTreeNode MakeFraction(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // fold constants
    return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
  } else if (IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
    // c / b => 1 / (b * (1 / c))
    return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
  } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
    // a / 1 => a (mirrors the "$ * 1.0 => $" rule of MakeProduct)
    return a;
  } else if (IsVariable(a) && IsConstant(b)) {
    // merge constant values into variable weights
    var constB = ((ConstantTreeNode)b).Value;
    ((VariableTreeNode)a).Weight /= constB;
    return a;
  } else if (IsVariable(a) && IsVariable(b) && AreSameVariable(a, b)) {
    // cancel variables: (w1 * x) / (w2 * x) => w1 / w2
    var aVar = a as VariableTreeNode;
    var bVar = b as VariableTreeNode;
    return MakeConstant(aVar.Weight / bVar.Weight);
  } else if (IsAddition(a) && IsConstant(b)) {
    // (a1 + ... + an) / c => a1 / c + ... + an / c
    // Each term gets its own copy of c. Otherwise one ConstantTreeNode could become the child of several
    // division nodes, and constants are folded in place (MakeSum, MakeProduct).
    return a.SubTrees
      .Select(x => GetSimplifiedTree(x))
      .Select(x => MakeFraction(x, (SymbolicExpressionTreeNode)b.Clone()))
      .Aggregate((c, d) => MakeSum(c, d));
  } else if (IsMultiplication(a) && IsConstant(b)) {
    return MakeProduct(a, Invert(b));
  } else if (IsDivision(a) && IsConstant(b)) {
    // (a1 / a2) / c => (a1 / (a2 * c))
    Trace.Assert(a.SubTrees.Count == 2);
    return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
  } else if (IsDivision(a) && IsDivision(b)) {
    // (a1 / a2) / (b1 / b2) => (a1 * b2) / (a2 * b1)
    Trace.Assert(a.SubTrees.Count == 2);
    Trace.Assert(b.SubTrees.Count == 2);
    return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[1]), MakeProduct(a.SubTrees[1], b.SubTrees[0]));
  } else if (IsDivision(a)) {
    // (a1 / a2) / b => (a1 / (a2 * b))
    Trace.Assert(a.SubTrees.Count == 2);
    return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
  } else if (IsDivision(b)) {
    // a / (b1 / b2) => (a * b2) / b1
    Trace.Assert(b.SubTrees.Count == 2);
    return MakeFraction(MakeProduct(a, b.SubTrees[1]), b.SubTrees[0]);
  } else {
    var div = divSymbol.CreateTreeNode();
    div.AddSubTree(a);
    div.AddSubTree(b);
    return div;
  }
}
659
/// <summary>
/// Builds a simplified tree for a + b from two already simplified trees.
/// </summary>
/// <remarks>
/// The arguments are reused in, or mutated into, the result. Folding two constants writes the sum into
/// a's ConstantTreeNode, and subtrees of addition arguments are attached to a new addition node.
/// The branches keep a constant term, if present, as the LAST subtree of an addition node, and the
/// merge branches rely on that.
/// NOTE(review): a folded constant that sums to 0 is kept inside merged additions; it is only dropped
/// when it is the direct argument b.
/// </remarks>
660    private SymbolicExpressionTreeNode MakeSum(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
661      if (IsConstant(a) && IsConstant(b)) {
662        // fold constants
663        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
664        return a;
665      } else if (IsConstant(a)) {
666        // c + x => x + c
667        // b is not constant => make sure constant is on the right
668        return MakeSum(b, a);
669      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
670        // x + 0 => x
671        return a;
672      } else if (IsAddition(a) && IsAddition(b)) {
673        // merge additions
      // all but the last terms of both sums are copied; the two last terms (where a constant would be)
      // are combined so that at most one constant remains and it stays last
674        var add = addSymbol.CreateTreeNode();
675        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
676        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
677        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
678          add.AddSubTree(MakeSum(a.SubTrees.Last(), b.SubTrees.Last()));
679        } else if (IsConstant(a.SubTrees.Last())) {
680          add.AddSubTree(b.SubTrees.Last());
681          add.AddSubTree(a.SubTrees.Last());
682        } else {
683          add.AddSubTree(a.SubTrees.Last());
684          add.AddSubTree(b.SubTrees.Last());
685        }
      // combine variable terms with the same name and lag (the merged variables are moved to the front)
686        MergeVariablesInSum(add);
687        if (add.SubTrees.Count == 1) {
688          return add.SubTrees[0];
689        } else {
690          return add;
691        }
692      } else if (IsAddition(b)) {
      // x + (sum) => (sum) + x so that the addition is the left argument in the branches below
693        return MakeSum(b, a);
694      } else if (IsAddition(a) && IsConstant(b)) {
695        // a is an addition and b is a constant => append b to a and make sure the constants are merged
      // (no variable merging needed here: b is a constant, so no new variable term is added)
696        var add = addSymbol.CreateTreeNode();
697        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
698        if (IsConstant(a.SubTrees.Last()))
699          add.AddSubTree(MakeSum(a.SubTrees.Last(), b));
700        else {
701          add.AddSubTree(a.SubTrees.Last());
702          add.AddSubTree(b);
703        }
704        return add;
705      } else if (IsAddition(a)) {
706        // a is already an addition => append b
      // b is inserted first so a's (possibly constant) last term stays last; MergeVariablesInSum keeps
      // the relative order of non-variable terms
707        var add = addSymbol.CreateTreeNode();
708        add.AddSubTree(b);
709        foreach (var subTree in a.SubTrees) {
710          add.AddSubTree(subTree);
711        }
712        MergeVariablesInSum(add);
713        if (add.SubTrees.Count == 1) {
714          return add.SubTrees[0];
715        } else {
716          return add;
717        }
718      } else {
      // neither argument is an addition: build a + b (a non-zero constant b stays last) and merge the
      // two terms if they are the same variable
719        var add = addSymbol.CreateTreeNode();
720        add.AddSubTree(a);
721        add.AddSubTree(b);
722        MergeVariablesInSum(add);
723        if (add.SubTrees.Count == 1) {
724          return add.SubTrees[0];
725        } else {
726          return add;
727        }
728      }
729    }
730
/// <summary>
/// Combines the variable terms of <paramref name="sum"/> in place. All VariableTreeNodes with the same
/// variable name and lag (plain variables count as lag 0) are replaced by the first of them, whose weight
/// becomes the sum of their weights. Afterwards the merged variables come first, followed by all other
/// subtrees in their original order.
/// </summary>
/// <remarks>
/// Possible improvement: combine sums of products where the products only reference the same variable.
/// </remarks>
private void MergeVariablesInSum(SymbolicExpressionTreeNode sum) {
  var subtrees = new List<SymbolicExpressionTreeNode>(sum.SubTrees);
  while (sum.SubTrees.Count > 0) sum.RemoveSubTree(0);
  // Group by (name, lag) as a composite key. Concatenating both into one string made distinct variables
  // collide: e.g. "x1" with lag 0 and "x" with lag 10 both produced the key "x10".
  var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                        let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
                        group node by new { node.VariableName, Lag = lag } into g
                        select g;
  var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));

  foreach (var variableNodeGroup in groupedVarNodes) {
    // the first node of each group is kept and carries the summed weight
    var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
    var representative = variableNodeGroup.First();
    representative.Weight = weightSum;
    sum.AddSubTree(representative);
  }
  foreach (var unchangedSubtree in unchangedSubTrees)
    sum.AddSubTree(unchangedSubtree);
}
751
752
753    private SymbolicExpressionTreeNode MakeProduct(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
754      if (IsConstant(a) && IsConstant(b)) {
755        // fold constants
756        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
757        return a;
758      } else if (IsConstant(a)) {
759        // a * $ => $ * a
760        return MakeProduct(b, a);
761      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
762        // $ * 1.0 => $
763        return a;
764      } else if (IsConstant(b) && IsVariable(a)) {
765        // multiply constants into variables weights
766        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
767        return a;
768      } else if (IsConstant(b) && IsAddition(a)) {
769        // multiply constants into additions
770        return a.SubTrees.Select(x => MakeProduct(x, b)).Aggregate((c, d) => MakeSum(c, d));
771      } else if (IsDivision(a) && IsDivision(b)) {
772        // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
773        Trace.Assert(a.SubTrees.Count == 2);
774        Trace.Assert(b.SubTrees.Count == 2);
775        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[0]), MakeProduct(a.SubTrees[1], b.SubTrees[1]));
776      } else if (IsDivision(a)) {
777        // (a1 / a2) * b => (a1 * b) / a2
778        Trace.Assert(a.SubTrees.Count == 2);
779        return MakeFraction(MakeProduct(a.SubTrees[0], b), a.SubTrees[1]);
780      } else if (IsDivision(b)) {
781        // a * (b1 / b2) => (b1 * a) / b2
782        Trace.Assert(b.SubTrees.Count == 2);
783        return MakeFraction(MakeProduct(b.SubTrees[0], a), b.SubTrees[1]);
784      } else if (IsMultiplication(a) && IsMultiplication(b)) {
785        // merge multiplications (make sure constants are merged)
786        var mul = mulSymbol.CreateTreeNode();
787        for (int i = 0; i < a.SubTrees.Count; i++) mul.AddSubTree(a.SubTrees[i]);
788        for (int i = 0; i < b.SubTrees.Count; i++) mul.AddSubTree(b.SubTrees[i]);
789        MergeVariablesAndConstantsInProduct(mul);
790        return mul;
791      } else if (IsMultiplication(b)) {
792        return MakeProduct(b, a);
793      } else if (IsMultiplication(a)) {
794        // a is already an multiplication => append b
795        a.AddSubTree(b);
796        MergeVariablesAndConstantsInProduct(a);
797        return a;
798      } else {
799        var mul = mulSymbol.CreateTreeNode();
800        mul.SubTrees.Add(a);
801        mul.SubTrees.Add(b);
802        MergeVariablesAndConstantsInProduct(mul);
803        return mul;
804      }
805    }
806    #endregion
807
808
809    #region helper functions
810
811    private bool AreSameVariable(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
812      var aLaggedVar = a as LaggedVariableTreeNode;
813      var bLaggedVar = b as LaggedVariableTreeNode;
814      if (aLaggedVar != null && bLaggedVar != null) {
815        return aLaggedVar.VariableName == bLaggedVar.VariableName &&
816          aLaggedVar.Lag == bLaggedVar.Lag;
817      }
818      var aVar = a as VariableTreeNode;
819      var bVar = b as VariableTreeNode;
820      if (aVar != null && bVar != null) {
821        return aVar.VariableName == bVar.VariableName;
822      }
823      return false;
824    }
825
826    // helper to combine the constant factors in products and to combine variables (powers of 2, 3...)
827    private void MergeVariablesAndConstantsInProduct(SymbolicExpressionTreeNode prod) {
828      var subtrees = new List<SymbolicExpressionTreeNode>(prod.SubTrees);
829      while (prod.SubTrees.Count > 0) prod.RemoveSubTree(0);
830      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
831                            let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
832                            group node by node.VariableName + lag into g
833                            orderby g.Count()
834                            select g;
835      var constantProduct = (from node in subtrees.OfType<VariableTreeNode>()
836                             select node.Weight)
837                            .Concat(from node in subtrees.OfType<ConstantTreeNode>()
838                                    select node.Value)
839                            .DefaultIfEmpty(1.0)
840                            .Aggregate((c1, c2) => c1 * c2);
841
842      var unchangedSubTrees = from tree in subtrees
843                              where !(tree is VariableTreeNode)
844                              where !(tree is ConstantTreeNode)
845                              select tree;
846
847      foreach (var variableNodeGroup in groupedVarNodes) {
848        var representative = variableNodeGroup.First();
849        representative.Weight = 1.0;
850        if (variableNodeGroup.Count() > 1) {
851          var poly = mulSymbol.CreateTreeNode();
852          for (int p = 0; p < variableNodeGroup.Count(); p++) {
853            poly.AddSubTree((SymbolicExpressionTreeNode)representative.Clone());
854          }
855          prod.AddSubTree(poly);
856        } else {
857          prod.AddSubTree(representative);
858        }
859      }
860
861      foreach (var unchangedSubtree in unchangedSubTrees)
862        prod.AddSubTree(unchangedSubtree);
863
864      if (!constantProduct.IsAlmost(1.0)) {
865        prod.AddSubTree(MakeConstant(constantProduct));
866      }
867    }
868
869
870    /// <summary>
871    /// x => x * -1
872    /// Doesn't create new trees and manipulates x
873    /// </summary>
874    /// <param name="x"></param>
875    /// <returns>-x</returns>
876    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
877      if (IsConstant(x)) {
878        ((ConstantTreeNode)x).Value *= -1;
879      } else if (IsVariable(x)) {
880        var variableTree = (VariableTreeNode)x;
881        variableTree.Weight *= -1.0;
882      } else if (IsAddition(x)) {
883        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)       
884        for (int i = 0; i < x.SubTrees.Count; i++)
885          x.SubTrees[i] = Negate(x.SubTrees[i]);
886      } else if (IsMultiplication(x) || IsDivision(x)) {
887        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
888        x.SubTrees[x.SubTrees.Count - 1] = Negate(x.SubTrees.Last()); // last is maybe a constant, prefer to negate the constant
889      } else {
890        // any other function
891        return MakeProduct(x, MakeConstant(-1));
892      }
893      return x;
894    }
895
896    /// <summary>
897    /// x => 1/x
898    /// Doesn't create new trees and manipulates x
899    /// </summary>
900    /// <param name="x"></param>
901    /// <returns></returns>
902    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
903      if (IsConstant(x)) {
904        return MakeConstant(1.0 / ((ConstantTreeNode)x).Value);
905      } else if (IsDivision(x)) {
906        Trace.Assert(x.SubTrees.Count == 2);
907        return MakeFraction(x.SubTrees[1], x.SubTrees[0]);
908      } else {
909        // any other function
910        return MakeFraction(MakeConstant(1), x);
911      }
912    }
913
914    private SymbolicExpressionTreeNode MakeConstant(double value) {
915      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(constSymbol.CreateTreeNode());
916      constantTreeNode.Value = value;
917      return (SymbolicExpressionTreeNode)constantTreeNode;
918    }
919
920    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
921      var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
922      tree.Weight = weight;
923      tree.VariableName = name;
924      return tree;
925    }
926    #endregion
927  }
928}
Note: See TracBrowser for help on using the repository browser.