Free cookie consent management tool by TermsFeed Policy Generator

source: branches/SymbolicSimplifier/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 5298

Last change on this file since 5298 was 4878, checked in by gkronber, 14 years ago

Added constant folding to the symbolic simplifier for all data analysis symbols. #1227

File size: 31.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Simplistic simplifier for arithmetic expressions.
  /// </summary>
  /// <remarks>
  /// <see cref="Simplify"/> first macro-expands invocations of automatically defined functions and
  /// then rebuilds the expression bottom-up, folding constants and applying a few algebraic
  /// normalizations (constants are moved to the last position of sums and products, nested sums and
  /// products are flattened, variables with the same name are combined).
  /// Folded boolean conditions are encoded as 1.0 (true) and -1.0 (false); values greater than 0.0
  /// are treated as true.
  /// </remarks>
  public class SymbolicSimplifier {
    // symbol instances used as factories for the tree nodes created during simplification
    private readonly Addition addSymbol = new Addition();
    private readonly Multiplication mulSymbol = new Multiplication();
    private readonly Division divSymbol = new Division();
    private readonly Constant constSymbol = new Constant();
    private readonly Variable varSymbol = new Variable();
    private readonly Logarithm logSymbol = new Logarithm();
    private readonly Exponential expSymbol = new Exponential();
    private readonly Sine sineSymbol = new Sine();
    private readonly Cosine cosineSymbol = new Cosine();
    private readonly Tangent tanSymbol = new Tangent();
    private readonly IfThenElse ifThenElseSymbol = new IfThenElse();
    private readonly And andSymbol = new And();
    private readonly Or orSymbol = new Or();
    private readonly Not notSymbol = new Not();
    private readonly GreaterThan gtSymbol = new GreaterThan();
    private readonly LessThan ltSymbol = new LessThan();

    /// <summary>
    /// Creates a simplified version of <paramref name="originalTree"/>.
    /// All work is done on a clone of the tree's root.
    /// </summary>
    /// <param name="originalTree">
    /// The tree to simplify. The first subtree of its root is expanded and simplified; defun subtrees
    /// of the root provide the definitions for function invocations.
    /// </param>
    /// <returns>A new tree with a program-root node whose single subtree is the simplified expression.</returns>
    public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
      var clone = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
      // macro expand (initially no argument trees)
      var macroExpandedTree = MacroExpand(clone, clone.SubTrees[0], new List<SymbolicExpressionTreeNode>());
      SymbolicExpressionTreeNode rootNode = (new ProgramRootSymbol()).CreateTreeNode();
      rootNode.AddSubTree(GetSimplifiedTree(macroExpandedTree));
      return new SymbolicExpressionTree(rootNode);
    }

    // Replaces invocations of defined functions by expanded copies of the function bodies and
    // argument references by (copies of) the corresponding argument trees.
    // The argumentTrees list contains already expanded trees used as arguments for invocations.
    // Note: node itself is modified (its subtrees are replaced by their expanded versions).
    private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
      List<SymbolicExpressionTreeNode> subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
      while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
      if (node.Symbol is InvokeFunction) {
        var invokeSym = (InvokeFunction)node.Symbol;
        // Expand a private copy of the function body: expansion modifies the visited nodes and
        // replaces argument nodes, so expanding the shared definition directly would bake the
        // arguments of the first invocation into every later invocation of the same function
        // (and return the same node instance for each invocation).
        var defunBody = (SymbolicExpressionTreeNode)FindFunctionDefinition(root, invokeSym.FunctionName).Clone();
        var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
        foreach (var subtree in subtrees) {
          macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
        }
        return MacroExpand(root, defunBody, macroExpandedArguments);
      } else if (node.Symbol is Argument) {
        var argSym = (Argument)node.Symbol;
        // return the correct argument sub-tree (already macro-expanded)
        return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
      } else {
        // recursive application
        foreach (var subtree in subtrees) {
          node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
        }
        return node;
      }
    }

    // Returns the body (first subtree) of the defun node named functionName among the root's subtrees.
    // Throws ArgumentException if no such definition exists.
    private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
      foreach (var subtree in root.SubTrees.OfType<DefunTreeNode>()) {
        if (subtree.FunctionName == functionName) return subtree.SubTrees[0];
      }

      throw new ArgumentException("Definition of function " + functionName + " not found.");
    }


    #region symbol predicates
    // arithmetic
    private bool IsDivision(SymbolicExpressionTreeNode node) {
      return node.Symbol is Division;
    }

    private bool IsMultiplication(SymbolicExpressionTreeNode node) {
      return node.Symbol is Multiplication;
    }

    private bool IsSubtraction(SymbolicExpressionTreeNode node) {
      return node.Symbol is Subtraction;
    }

    private bool IsAddition(SymbolicExpressionTreeNode node) {
      return node.Symbol is Addition;
    }

    private bool IsAverage(SymbolicExpressionTreeNode node) {
      return node.Symbol is Average;
    }
    // exponential
    private bool IsLog(SymbolicExpressionTreeNode node) {
      return node.Symbol is Logarithm;
    }
    private bool IsExp(SymbolicExpressionTreeNode node) {
      return node.Symbol is Exponential;
    }
    // trigonometric
    private bool IsSine(SymbolicExpressionTreeNode node) {
      return node.Symbol is Sine;
    }
    private bool IsCosine(SymbolicExpressionTreeNode node) {
      return node.Symbol is Cosine;
    }
    private bool IsTangent(SymbolicExpressionTreeNode node) {
      return node.Symbol is Tangent;
    }
    // boolean
    private bool IsIfThenElse(SymbolicExpressionTreeNode node) {
      return node.Symbol is IfThenElse;
    }
    private bool IsAnd(SymbolicExpressionTreeNode node) {
      return node.Symbol is And;
    }
    private bool IsOr(SymbolicExpressionTreeNode node) {
      return node.Symbol is Or;
    }
    private bool IsNot(SymbolicExpressionTreeNode node) {
      return node.Symbol is Not;
    }
    // comparison
    private bool IsGreaterThan(SymbolicExpressionTreeNode node) {
      return node.Symbol is GreaterThan;
    }
    private bool IsLessThan(SymbolicExpressionTreeNode node) {
      return node.Symbol is LessThan;
    }

    // terminals
    private bool IsVariable(SymbolicExpressionTreeNode node) {
      return node.Symbol is Variable;
    }

    private bool IsConstant(SymbolicExpressionTreeNode node) {
      return node.Symbol is Constant;
    }

    #endregion

    /// <summary>
    /// Creates a new simplified tree.
    /// Constants and variables are cloned; every other node is rebuilt from its simplified subtrees
    /// by the symbol-specific simplification routine (or copied structurally for unknown symbols).
    /// </summary>
    /// <param name="original">The (macro-expanded) tree to simplify.</param>
    /// <returns>The root of the newly created simplified tree.</returns>
    public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
      if (IsConstant(original) || IsVariable(original)) {
        return (SymbolicExpressionTreeNode)original.Clone();
      } else if (IsAddition(original)) {
        return SimplifyAddition(original);
      } else if (IsSubtraction(original)) {
        return SimplifySubtraction(original);
      } else if (IsMultiplication(original)) {
        return SimplifyMultiplication(original);
      } else if (IsDivision(original)) {
        return SimplifyDivision(original);
      } else if (IsAverage(original)) {
        return SimplifyAverage(original);
      } else if (IsLog(original)) {
        return SimplifyLog(original);
      } else if (IsExp(original)) {
        return SimplifyExp(original);
      } else if (IsSine(original)) {
        return SimplifySine(original);
      } else if (IsCosine(original)) {
        return SimplifyCosine(original);
      } else if (IsTangent(original)) {
        return SimplifyTangent(original);
      } else if (IsIfThenElse(original)) {
        return SimplifyIfThenElse(original);
      } else if (IsGreaterThan(original)) {
        return SimplifyGreaterThan(original);
      } else if (IsLessThan(original)) {
        return SimplifyLessThan(original);
      } else if (IsAnd(original)) {
        return SimplifyAnd(original);
      } else if (IsOr(original)) {
        return SimplifyOr(original);
      } else if (IsNot(original)) {
        return SimplifyNot(original);
      } else {
        return SimplifyAny(original);
      }
    }


    #region specific simplification routines
    private SymbolicExpressionTreeNode SimplifyAny(SymbolicExpressionTreeNode original) {
      // can't simplify this function but simplify all subtrees
      // the subtrees are detached temporarily so that Clone() copies only the node itself;
      // they are re-attached to original afterwards
      List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
      while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
      var clone = (SymbolicExpressionTreeNode)original.Clone();
      List<SymbolicExpressionTreeNode> simplifiedSubTrees = new List<SymbolicExpressionTreeNode>();
      foreach (var subTree in subTrees) {
        simplifiedSubTrees.Add(GetSimplifiedTree(subTree));
        original.AddSubTree(subTree);
      }
      foreach (var simplifiedSubtree in simplifiedSubTrees) {
        clone.AddSubTree(simplifiedSubtree);
      }
      return clone;
    }

    private SymbolicExpressionTreeNode SimplifyAverage(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return GetSimplifiedTree(original.SubTrees[0]);
      } else {
        // simplify expressions x0..xn
        // make sum(x0..xn) / n
        Trace.Assert(original.SubTrees.Count > 1);
        var sum = original.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Aggregate((a, b) => MakeSum(a, b));
        return MakeFraction(sum, MakeConstant(original.SubTrees.Count));
      }
    }

    private SymbolicExpressionTreeNode SimplifyDivision(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return Invert(GetSimplifiedTree(original.SubTrees[0]));
      } else {
        // simplify expressions x0..xn
        // make multiplication (x0 * 1/(x1 * x1 * .. * xn))
        Trace.Assert(original.SubTrees.Count > 1);
        // materialize once: enumerating the lazy sequence twice (first element + Skip(1)) would
        // simplify the first subtree twice, which grows exponentially for nested divisions
        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
        return
          MakeProduct(simplifiedTrees[0], Invert(simplifiedTrees.Skip(1).Aggregate((a, b) => MakeProduct(a, b))));
      }
    }

    private SymbolicExpressionTreeNode SimplifyMultiplication(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return GetSimplifiedTree(original.SubTrees[0]);
      } else {
        Trace.Assert(original.SubTrees.Count > 1);
        return original.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Aggregate((a, b) => MakeProduct(a, b));
      }
    }

    private SymbolicExpressionTreeNode SimplifySubtraction(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return Negate(GetSimplifiedTree(original.SubTrees[0]));
      } else {
        // simplify expressions x0..xn
        // make addition (x0,-x1..-xn)
        Trace.Assert(original.SubTrees.Count > 1);
        // materialize once: Take(1) and Skip(1) would otherwise both simplify the first subtree
        var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
        return simplifiedTrees.Take(1)
          .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
          .Aggregate((a, b) => MakeSum(a, b));
      }
    }

    private SymbolicExpressionTreeNode SimplifyAddition(SymbolicExpressionTreeNode original) {
      if (original.SubTrees.Count == 1) {
        return GetSimplifiedTree(original.SubTrees[0]);
      } else {
        // simplify expression x0..xn
        // make addition (x0..xn)
        Trace.Assert(original.SubTrees.Count > 1);
        return original.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Aggregate((a, b) => MakeSum(a, b));
      }
    }

    private SymbolicExpressionTreeNode SimplifyNot(SymbolicExpressionTreeNode original) {
      return MakeNot(GetSimplifiedTree(original.SubTrees[0]));
    }
    private SymbolicExpressionTreeNode SimplifyOr(SymbolicExpressionTreeNode original) {
      return original.SubTrees
        .Select(x => GetSimplifiedTree(x))
        .Aggregate((a, b) => MakeOr(a, b));
    }
    private SymbolicExpressionTreeNode SimplifyAnd(SymbolicExpressionTreeNode original) {
      return original.SubTrees
        .Select(x => GetSimplifiedTree(x))
        .Aggregate((a, b) => MakeAnd(a, b));
    }
    private SymbolicExpressionTreeNode SimplifyLessThan(SymbolicExpressionTreeNode original) {
      return MakeLessThan(GetSimplifiedTree(original.SubTrees[0]), GetSimplifiedTree(original.SubTrees[1]));
    }
    private SymbolicExpressionTreeNode SimplifyGreaterThan(SymbolicExpressionTreeNode original) {
      return MakeGreaterThan(GetSimplifiedTree(original.SubTrees[0]), GetSimplifiedTree(original.SubTrees[1]));
    }
    private SymbolicExpressionTreeNode SimplifyIfThenElse(SymbolicExpressionTreeNode original) {
      return MakeIfThenElse(GetSimplifiedTree(original.SubTrees[0]), GetSimplifiedTree(original.SubTrees[1]), GetSimplifiedTree(original.SubTrees[2]));
    }
    private SymbolicExpressionTreeNode SimplifyTangent(SymbolicExpressionTreeNode original) {
      return MakeTangent(GetSimplifiedTree(original.SubTrees[0]));
    }
    private SymbolicExpressionTreeNode SimplifyCosine(SymbolicExpressionTreeNode original) {
      return MakeCosine(GetSimplifiedTree(original.SubTrees[0]));
    }
    private SymbolicExpressionTreeNode SimplifySine(SymbolicExpressionTreeNode original) {
      return MakeSine(GetSimplifiedTree(original.SubTrees[0]));
    }
    private SymbolicExpressionTreeNode SimplifyExp(SymbolicExpressionTreeNode original) {
      return MakeExp(GetSimplifiedTree(original.SubTrees[0]));
    }

    private SymbolicExpressionTreeNode SimplifyLog(SymbolicExpressionTreeNode original) {
      return MakeLog(GetSimplifiedTree(original.SubTrees[0]));
    }

    #endregion



    #region low level tree restructuring
    // not(t) is represented as t * -1, which flips the sign that is used as truth value (> 0.0 is true).
    // NOTE(review): unlike the other boolean folds this does not yield exactly 1.0/-1.0, and not(0.0)
    // stays 0.0 (false) - confirm this matches the interpreter's semantics for Not.
    private SymbolicExpressionTreeNode MakeNot(SymbolicExpressionTreeNode t) {
      return MakeProduct(t, MakeConstant(-1.0));
    }

    private SymbolicExpressionTreeNode MakeOr(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        var constA = (ConstantTreeNode)a;
        var constB = (ConstantTreeNode)b;
        if (constA.Value > 0.0 || constB.Value > 0.0) {
          return MakeConstant(1.0);
        } else {
          return MakeConstant(-1.0);
        }
      } else if (IsConstant(a)) {
        return MakeOr(b, a);
      } else if (IsConstant(b)) {
        var constT = (ConstantTreeNode)b;
        if (constT.Value > 0.0) {
          // boolean expression is necessarily true
          return MakeConstant(1.0);
        } else {
          // the constant value has no effect on the result of the boolean condition so we can drop the constant term
          var orNode = orSymbol.CreateTreeNode();
          orNode.AddSubTree(a);
          return orNode;
        }
      } else {
        var orNode = orSymbol.CreateTreeNode();
        orNode.AddSubTree(a);
        orNode.AddSubTree(b);
        return orNode;
      }
    }
    private SymbolicExpressionTreeNode MakeAnd(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        var constA = (ConstantTreeNode)a;
        var constB = (ConstantTreeNode)b;
        if (constA.Value > 0.0 && constB.Value > 0.0) {
          return MakeConstant(1.0);
        } else {
          return MakeConstant(-1.0);
        }
      } else if (IsConstant(a)) {
        return MakeAnd(b, a);
      } else if (IsConstant(b)) {
        var constB = (ConstantTreeNode)b;
        if (constB.Value > 0.0) {
          // the constant value has no effect on the result of the boolean condition so we can drop the constant term
          var andNode = andSymbol.CreateTreeNode();
          andNode.AddSubTree(a);
          return andNode;
        } else {
          // boolean expression is necessarily false
          return MakeConstant(-1.0);
        }
      } else {
        var andNode = andSymbol.CreateTreeNode();
        andNode.AddSubTree(a);
        andNode.AddSubTree(b);
        return andNode;
      }
    }
    private SymbolicExpressionTreeNode MakeLessThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
      if (IsConstant(leftSide) && IsConstant(rightSide)) {
        var lsConst = (ConstantTreeNode)leftSide;
        var rsConst = (ConstantTreeNode)rightSide;
        if (lsConst.Value < rsConst.Value) return MakeConstant(1.0);
        else return MakeConstant(-1.0);
      } else {
        var ltNode = ltSymbol.CreateTreeNode();
        ltNode.AddSubTree(leftSide);
        ltNode.AddSubTree(rightSide);
        return ltNode;
      }
    }
    private SymbolicExpressionTreeNode MakeGreaterThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
      if (IsConstant(leftSide) && IsConstant(rightSide)) {
        var lsConst = (ConstantTreeNode)leftSide;
        var rsConst = (ConstantTreeNode)rightSide;
        if (lsConst.Value > rsConst.Value) return MakeConstant(1.0);
        else return MakeConstant(-1.0);
      } else {
        var gtNode = gtSymbol.CreateTreeNode();
        gtNode.AddSubTree(leftSide);
        gtNode.AddSubTree(rightSide);
        return gtNode;
      }
    }
    private SymbolicExpressionTreeNode MakeIfThenElse(SymbolicExpressionTreeNode condition, SymbolicExpressionTreeNode trueBranch, SymbolicExpressionTreeNode falseBranch) {
      if (IsConstant(condition)) {
        // constant condition selects one branch statically (> 0.0 is true)
        var constT = (ConstantTreeNode)condition;
        if (constT.Value > 0.0) return trueBranch;
        else return falseBranch;
      } else {
        var ifNode = ifThenElseSymbol.CreateTreeNode();
        ifNode.AddSubTree(condition);
        ifNode.AddSubTree(trueBranch);
        ifNode.AddSubTree(falseBranch);
        return ifNode;
      }
    }
    private SymbolicExpressionTreeNode MakeSine(SymbolicExpressionTreeNode node) {
      // todo implement more transformation rules
      if (IsConstant(node)) {
        var constT = (ConstantTreeNode)node;
        return MakeConstant(Math.Sin(constT.Value));
      } else {
        var sineNode = sineSymbol.CreateTreeNode();
        sineNode.AddSubTree(node);
        return sineNode;
      }
    }
    private SymbolicExpressionTreeNode MakeTangent(SymbolicExpressionTreeNode node) {
      // todo implement more transformation rules
      if (IsConstant(node)) {
        var constT = (ConstantTreeNode)node;
        return MakeConstant(Math.Tan(constT.Value));
      } else {
        var tanNode = tanSymbol.CreateTreeNode();
        tanNode.AddSubTree(node);
        return tanNode;
      }
    }
    private SymbolicExpressionTreeNode MakeCosine(SymbolicExpressionTreeNode node) {
      // todo implement more transformation rules
      if (IsConstant(node)) {
        var constT = (ConstantTreeNode)node;
        return MakeConstant(Math.Cos(constT.Value));
      } else {
        var cosNode = cosineSymbol.CreateTreeNode();
        cosNode.AddSubTree(node);
        return cosNode;
      }
    }
    private SymbolicExpressionTreeNode MakeExp(SymbolicExpressionTreeNode node) {
      // todo implement more transformation rules
      if (IsConstant(node)) {
        var constT = (ConstantTreeNode)node;
        return MakeConstant(Math.Exp(constT.Value));
      } else {
        var expNode = expSymbol.CreateTreeNode();
        expNode.AddSubTree(node);
        return expNode;
      }
    }
    private SymbolicExpressionTreeNode MakeLog(SymbolicExpressionTreeNode node) {
      // todo implement more transformation rules
      if (IsConstant(node)) {
        var constT = (ConstantTreeNode)node;
        return MakeConstant(Math.Log(constT.Value));
      } else {
        var logNode = logSymbol.CreateTreeNode();
        logNode.AddSubTree(node);
        return logNode;
      }
    }


    // MakeFraction, MakeProduct and MakeSum take two already simplified trees and create a new simplified tree.
    // They may modify and reuse the nodes of their arguments (e.g. when folding constants), so the
    // arguments must not be used elsewhere afterwards.

    private SymbolicExpressionTreeNode MakeFraction(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // fold constants
        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
      } else if (IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
        // c / b => 1 / (b * 1/c)
        return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
        // a / 1.0 => a
        return a;
      } else if (IsVariable(a) && IsConstant(b)) {
        // merge constant values into variable weights
        var constB = ((ConstantTreeNode)b).Value;
        ((VariableTreeNode)a).Weight /= constB;
        return a;
      } else if (IsAddition(a) && IsConstant(b)) {
        // (a0 + .. + an) / c => a0/c + .. + an/c
        // every summand gets its own divisor node; sharing b between several parents would let
        // later in-place changes (e.g. Negate) affect all of them at once
        var divisor = ((ConstantTreeNode)b).Value;
        return a.SubTrees
          .Select(x => GetSimplifiedTree(x))
          .Select(x => MakeFraction(x, MakeConstant(divisor)))
          .Aggregate((c, d) => MakeSum(c, d));
      } else if (IsMultiplication(a) && IsConstant(b)) {
        return MakeProduct(a, Invert(b));
      } else if (IsDivision(a) && IsConstant(b)) {
        // (a1 / a2) / c => (a1 / (a2 * c))
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
      } else if (IsDivision(a) && IsDivision(b)) {
        // (a1 / a2) / (b1 / b2) => (a1 * b2) / (a2 * b1)
        Trace.Assert(a.SubTrees.Count == 2);
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[1]), MakeProduct(a.SubTrees[1], b.SubTrees[0]));
      } else if (IsDivision(a)) {
        // (a1 / a2) / b => (a1 / (a2 * b))
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
      } else if (IsDivision(b)) {
        // a / (b1 / b2) => (a * b2) / b1
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a, b.SubTrees[1]), b.SubTrees[0]);
      } else {
        var div = divSymbol.CreateTreeNode();
        div.AddSubTree(a);
        div.AddSubTree(b);
        return div;
      }
    }

    private SymbolicExpressionTreeNode MakeSum(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // fold constants
        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        // c + x => x + c
        // b is not constant => make sure constant is on the right
        return MakeSum(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
        // x + 0 => x
        return a;
      } else if (IsAddition(a) && IsAddition(b)) {
        // merge additions
        var add = addSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
        // the last subtrees are handled separately to keep a constant term in the last position
        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
          add.AddSubTree(MakeSum(a.SubTrees.Last(), b.SubTrees.Last()));
        } else if (IsConstant(a.SubTrees.Last())) {
          add.AddSubTree(b.SubTrees.Last());
          add.AddSubTree(a.SubTrees.Last());
        } else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b.SubTrees.Last());
        }
        MergeVariablesInSum(add);
        return add;
      } else if (IsAddition(b)) {
        return MakeSum(b, a);
      } else if (IsAddition(a) && IsConstant(b)) {
        // a is an addition and b is a constant => append b to a and make sure the constants are merged
        var add = addSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        if (IsConstant(a.SubTrees.Last()))
          add.AddSubTree(MakeSum(a.SubTrees.Last(), b));
        else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b);
        }
        return add;
      } else if (IsAddition(a)) {
        // a is already an addition => append b (b goes first so a trailing constant of a stays last)
        var add = addSymbol.CreateTreeNode();
        add.AddSubTree(b);
        foreach (var subTree in a.SubTrees) {
          add.AddSubTree(subTree);
        }
        MergeVariablesInSum(add);
        return add;
      } else {
        var add = addSymbol.CreateTreeNode();
        add.AddSubTree(a);
        add.AddSubTree(b);
        MergeVariablesInSum(add);
        return add;
      }
    }

    // makes sure variable symbols in sums are combined:
    // variable nodes with the same name are replaced by their first node carrying the summed weight;
    // the combined variables come first, followed by all other subtrees in their original order
    // possible improvment: combine sums of products where the products only reference the same variable
    private void MergeVariablesInSum(SymbolicExpressionTreeNode sum) {
      var subtrees = new List<SymbolicExpressionTreeNode>(sum.SubTrees);
      while (sum.SubTrees.Count > 0) sum.RemoveSubTree(0);
      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                            group node by node.VariableName into g
                            select g;
      var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));

      foreach (var variableNodeGroup in groupedVarNodes) {
        var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
        var representative = variableNodeGroup.First();
        representative.Weight = weightSum;
        sum.AddSubTree(representative);
      }
      foreach (var unchangedSubtree in unchangedSubTrees)
        sum.AddSubTree(unchangedSubtree);
    }


    private SymbolicExpressionTreeNode MakeProduct(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // fold constants
        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        // a * $ => $ * a
        return MakeProduct(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
        // $ * 1.0 => $
        return a;
      } else if (IsConstant(b) && IsVariable(a)) {
        // multiply constants into variables weights
        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(b) && IsAddition(a)) {
        // multiply constants into additions; every summand gets its own factor node so that no
        // constant node ends up shared between several parents
        var factor = ((ConstantTreeNode)b).Value;
        return a.SubTrees.Select(x => MakeProduct(x, MakeConstant(factor))).Aggregate((c, d) => MakeSum(c, d));
      } else if (IsDivision(a) && IsDivision(b)) {
        // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
        Trace.Assert(a.SubTrees.Count == 2);
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[0]), MakeProduct(a.SubTrees[1], b.SubTrees[1]));
      } else if (IsDivision(a)) {
        // (a1 / a2) * b => (a1 * b) / a2
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(a.SubTrees[0], b), a.SubTrees[1]);
      } else if (IsDivision(b)) {
        // a * (b1 / b2) => (b1 * a) / b2
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeFraction(MakeProduct(b.SubTrees[0], a), b.SubTrees[1]);
      } else if (IsMultiplication(a) && IsMultiplication(b)) {
        // merge multiplications (make sure constants are merged)
        var mul = mulSymbol.CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count; i++) mul.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count; i++) mul.AddSubTree(b.SubTrees[i]);
        MergeVariablesAndConstantsInProduct(mul);
        return mul;
      } else if (IsMultiplication(b)) {
        return MakeProduct(b, a);
      } else if (IsMultiplication(a)) {
        // a is already an multiplication => append b
        a.AddSubTree(b);
        MergeVariablesAndConstantsInProduct(a);
        return a;
      } else {
        var mul = mulSymbol.CreateTreeNode();
        mul.AddSubTree(a);
        mul.AddSubTree(b);
        MergeVariablesAndConstantsInProduct(mul);
        return mul;
      }
    }
    #endregion

    // helper to combine the constant factors in products and to combine variables (powers of 2, 3...):
    // all variable weights and constant values are multiplied into a single constant (appended last
    // unless it is almost 1.0), variable weights are reset to 1.0, and a variable occurring k > 1 times
    // is replaced by a nested product of k copies
    private void MergeVariablesAndConstantsInProduct(SymbolicExpressionTreeNode prod) {
      var subtrees = new List<SymbolicExpressionTreeNode>(prod.SubTrees);
      while (prod.SubTrees.Count > 0) prod.RemoveSubTree(0);
      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                            group node by node.VariableName into g
                            orderby g.Count()
                            select g;
      // evaluated eagerly (Aggregate) before the weights are reset below
      var constantProduct = (from node in subtrees.OfType<VariableTreeNode>()
                             select node.Weight)
                            .Concat(from node in subtrees.OfType<ConstantTreeNode>()
                                    select node.Value)
                            .DefaultIfEmpty(1.0)
                            .Aggregate((c1, c2) => c1 * c2);

      var unchangedSubTrees = from tree in subtrees
                              where !(tree is VariableTreeNode)
                              where !(tree is ConstantTreeNode)
                              select tree;

      foreach (var variableNodeGroup in groupedVarNodes) {
        var representative = variableNodeGroup.First();
        representative.Weight = 1.0;
        if (variableNodeGroup.Count() > 1) {
          var poly = mulSymbol.CreateTreeNode();
          for (int p = 0; p < variableNodeGroup.Count(); p++) {
            poly.AddSubTree((SymbolicExpressionTreeNode)representative.Clone());
          }
          prod.AddSubTree(poly);
        } else {
          prod.AddSubTree(representative);
        }
      }

      foreach (var unchangedSubtree in unchangedSubTrees)
        prod.AddSubTree(unchangedSubtree);

      if (!constantProduct.IsAlmost(1.0)) {
        prod.AddSubTree(MakeConstant(constantProduct));
      }
    }


    #region helper functions
    /// <summary>
    /// x => x * -1
    /// Constants, variables, additions, multiplications and divisions are negated in place
    /// (x is modified and returned); for any other node a new product x * -1 is created.
    /// </summary>
    /// <param name="x">An already simplified tree.</param>
    /// <returns>-x</returns>
    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        ((ConstantTreeNode)x).Value *= -1;
      } else if (IsVariable(x)) {
        var variableTree = (VariableTreeNode)x;
        variableTree.Weight *= -1.0;
      } else if (IsAddition(x)) {
        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
        // Negate may return a new node (e.g. for sin(..)), so the subtrees are replaced by the
        // returned values instead of relying on in-place modification
        var subTrees = new List<SymbolicExpressionTreeNode>(x.SubTrees);
        while (x.SubTrees.Count > 0) x.RemoveSubTree(0);
        foreach (var subTree in subTrees) {
          x.AddSubTree(Negate(subTree));
        }
      } else if (IsMultiplication(x) || IsDivision(x)) {
        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
        // last is maybe a constant, prefer to negate the constant
        // (for a division this negates the denominator: -(a / b) = a / -b)
        int lastIndex = x.SubTrees.Count - 1;
        var negatedLast = Negate(x.SubTrees[lastIndex]);
        // write back, since Negate may have returned a new node
        x.RemoveSubTree(lastIndex);
        x.AddSubTree(negatedLast);
      } else {
        // any other function
        return MakeProduct(x, MakeConstant(-1));
      }
      return x;
    }

    /// <summary>
    /// x => 1/x
    /// Constants are inverted into a new constant node, divisions are turned upside down
    /// (reusing their subtrees), and any other tree becomes the fraction 1 / x.
    /// </summary>
    /// <param name="x">An already simplified tree.</param>
    /// <returns>1/x</returns>
    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        return MakeConstant(1.0 / ((ConstantTreeNode)x).Value);
      } else if (IsDivision(x)) {
        Trace.Assert(x.SubTrees.Count == 2);
        return MakeFraction(x.SubTrees[1], x.SubTrees[0]);
      } else {
        // any other function
        return MakeFraction(MakeConstant(1), x);
      }
    }

    // creates a new constant node holding value
    private SymbolicExpressionTreeNode MakeConstant(double value) {
      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(constSymbol.CreateTreeNode());
      constantTreeNode.Value = value;
      return (SymbolicExpressionTreeNode)constantTreeNode;
    }

    // creates a new variable node for the variable name with the given weight
    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
      var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
      tree.Weight = weight;
      tree.VariableName = name;
      return tree;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.