Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SymbolicSimplifier.cs @ 3442

Last change on this file since 3442 was 3442, checked in by gkronber, 15 years ago

Implemented views for DataAnalysisProblems and DataAnalysisSolutions. #938 (Data types and operators for regression problems)

File size: 15.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.GeneralSymbols;
28using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
29using System.Diagnostics;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Simplistic simplifier for arithmetic expressions.
  /// Rules:
  ///  * Invocations of defined functions are macro-expanded before simplification.
  ///  * x0 - x1 - .. - xn is rewritten as x0 + (-x1) + .. + (-xn).
  ///  * x0 / x1 / .. / xn is rewritten as x0 * 1/(x1 * .. * xn).
  ///  * Within additions and multiplications a constant is kept as the last argument.
  ///  * f(c1, c2) => c3 (constant folding for +, *, /).
  ///  * Constant factors / divisors of a variable are folded into its weight (w*x * c => (w*c)*x).
  ///  * Variables with the same name inside one addition are merged by summing their weights.
  ///  * e.g. c1 / (c2 * x) => c1 / (c2*x), i.e. the constant c2 ends up in the weight of x.
  /// </summary>
  public class SymbolicSimplifier {
    /// <summary>
    /// Returns a simplified version of <paramref name="originalTree"/>.
    /// The simplification works on a clone of the tree root.
    /// </summary>
    /// <param name="originalTree">The tree to simplify.</param>
    /// <returns>A new tree built from the simplified, macro-expanded first subtree of the original root.</returns>
    public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
      // clone first: MacroExpand rewrites nodes in place
      var clone = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
      // macro expand (initially no argument trees)
      var macroExpandedTree = MacroExpand(clone, clone.SubTrees[0], new List<SymbolicExpressionTreeNode>());
      return new SymbolicExpressionTree(GetSimplifiedTree(macroExpandedTree));
    }

    /// <summary>
    /// Recursively replaces function invocations by (a copy of) the body of the invoked function
    /// and argument nodes by copies of the corresponding argument trees.
    /// <paramref name="node"/> is modified in place.
    /// </summary>
    /// <param name="root">Root node whose direct children contain the function definitions.</param>
    /// <param name="node">The node to expand.</param>
    /// <param name="argumentTrees">Already expanded trees used as arguments of the current invocation.</param>
    /// <returns>The expanded node.</returns>
    private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
      var subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
      while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
      if (node.Symbol is InvokeFunction) {
        var invokeSym = (InvokeFunction)node.Symbol;
        var defunBody = FindFunctionDefinition(root, invokeSym.FunctionName);
        var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
        foreach (var subtree in subtrees) {
          macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
        }
        // Expand a copy of the function body. Expansion rewrites nodes in place, so expanding the
        // shared definition directly would bake the arguments of the first invocation into it;
        // every later invocation of the same function would then reuse those arguments and the
        // very same node instances.
        var bodyCopy = (SymbolicExpressionTreeNode)defunBody.Clone();
        return MacroExpand(root, bodyCopy, macroExpandedArguments);
      } else if (node.Symbol is Argument) {
        var argSym = (Argument)node.Symbol;
        // return the correct argument sub-tree (already macro-expanded)
        return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
      } else {
        // recursive application
        foreach (var subtree in subtrees) {
          node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
        }
        return node;
      }
    }

    /// <summary>
    /// Looks up the body (first subtree) of the function definition named <paramref name="functionName"/>
    /// among the direct children of <paramref name="root"/>.
    /// </summary>
    /// <exception cref="ArgumentException">No definition with the given name exists.</exception>
    private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
      foreach (var subtree in root.SubTrees.OfType<DefunTreeNode>()) {
        if (subtree.FunctionName == functionName) return subtree.SubTrees[0];
      }

      throw new ArgumentException("Definition of function " + functionName + " not found.");
    }

    /// <summary>
    /// Creates a new simplified tree for <paramref name="original"/>.
    /// Additions, subtractions, multiplications and divisions are rebuilt through the Make* helpers;
    /// all other functions are copied with simplified subtrees.
    /// </summary>
    /// <param name="original">The node to simplify; its structure is left as it was.</param>
    /// <returns>A newly created simplified tree.</returns>
    public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
      if (IsConstant(original) || IsVariable(original)) {
        return (SymbolicExpressionTreeNode)original.Clone();
      } else if (IsAddition(original)) {
        if (original.SubTrees.Count == 1) {
          return GetSimplifiedTree(original.SubTrees[0]);
        } else {
          // simplify expressions x0..xn
          // make addition (x0 + .. + xn)
          Trace.Assert(original.SubTrees.Count > 1);
          return original.SubTrees
            .Select(x => GetSimplifiedTree(x))
            .Aggregate((a, b) => MakeAddition(a, b));
        }
      } else if (IsSubtraction(original)) {
        if (original.SubTrees.Count == 1) {
          return Negate(GetSimplifiedTree(original.SubTrees[0]));
        } else {
          // simplify expressions x0..xn
          // make addition (x0 + -x1 + .. + -xn)
          Trace.Assert(original.SubTrees.Count > 1);
          // Materialize once. The sequence is consumed twice below (Take/Skip), and a lazy
          // projection would simplify x0 again, repeatedly for nested subtractions.
          var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
          return simplifiedTrees.Take(1)
            .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
            .Aggregate((a, b) => MakeAddition(a, b));
        }
      } else if (IsMultiplication(original)) {
        if (original.SubTrees.Count == 1) {
          return GetSimplifiedTree(original.SubTrees[0]);
        } else {
          Trace.Assert(original.SubTrees.Count > 1);
          return original.SubTrees
            .Select(x => GetSimplifiedTree(x))
            .Aggregate((a, b) => MakeMultiplication(a, b));
        }
      } else if (IsDivision(original)) {
        if (original.SubTrees.Count == 1) {
          return Invert(GetSimplifiedTree(original.SubTrees[0]));
        } else {
          // simplify expressions x0..xn
          // make multiplication (x0 * 1/(x1 * x2 * .. * xn))
          Trace.Assert(original.SubTrees.Count > 1);
          // materialize once so that x0 is not simplified twice (see subtraction above)
          var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
          var denominator = simplifiedTrees.Skip(1).Aggregate((a, b) => MakeMultiplication(a, b));
          return MakeMultiplication(simplifiedTrees[0], Invert(denominator));
        }
      } else {
        // can't simplify this function but simplify all subtrees
        // TODO evaluate the function if original is a constant expression
        // Detach the children temporarily so that Clone() copies only this node,
        // then re-attach them so that original is unchanged afterwards.
        List<SymbolicExpressionTreeNode> subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
        while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
        var clone = (SymbolicExpressionTreeNode)original.Clone();
        foreach (var subTree in subTrees) {
          clone.AddSubTree(GetSimplifiedTree(subTree));
          original.AddSubTree(subTree);
        }
        return clone;
      }
    }

    /// <summary>
    /// x => x * -1
    /// Modifies x in place where possible (constants, variable weights, the terms of an addition,
    /// the last argument of a multiplication/division); otherwise x is wrapped in a multiplication with -1.
    /// </summary>
    /// <param name="x">The tree to negate.</param>
    /// <returns>-x</returns>
    private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        ((ConstantTreeNode)x).Value *= -1;
      } else if (IsVariable(x)) {
        var variableTree = (VariableTreeNode)x;
        variableTree.Weight *= -1.0;
      } else if (IsAddition(x)) {
        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
        foreach (var subTree in x.SubTrees) {
          Negate(subTree);
        }
      } else if (IsMultiplication(x) || IsDivision(x)) {
        // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
        Negate(x.SubTrees.Last()); // last is maybe a constant, prefer to negate the constant
      } else {
        // any other function
        return MakeMultiplication(x, MakeConstant(-1));
      }
      return x;
    }

    /// <summary>
    /// x => 1/x
    /// Constants are inverted in place; any other tree is wrapped in a division 1 / x.
    /// </summary>
    /// <param name="x">The tree to invert.</param>
    /// <returns>1/x</returns>
    private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
      if (IsConstant(x)) {
        ((ConstantTreeNode)x).Value = 1.0 / ((ConstantTreeNode)x).Value;
      } else {
        // any other function
        return MakeDivision(MakeConstant(1), x);
      }
      return x;
    }

    /// <summary>
    /// a / b: folds two constants into a new constant, divides the weight of a variable
    /// <paramref name="a"/> by a constant <paramref name="b"/> (modifying a), and otherwise
    /// creates a division node with a and b as children.
    /// </summary>
    private SymbolicExpressionTreeNode MakeDivision(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
      } else if (IsVariable(a) && IsConstant(b)) {
        var constB = ((ConstantTreeNode)b).Value;
        ((VariableTreeNode)a).Weight /= constB;
        return a;
      } else {
        var div = (new Division()).CreateTreeNode();
        div.SubTrees.Add(a);
        div.SubTrees.Add(b);
        return div;
      }
    }

    /// <summary>
    /// a + b: folds constants, drops an addend that is (almost) zero, flattens nested additions,
    /// keeps a constant as the last argument and merges variables with equal names.
    /// May modify and reuse the nodes of <paramref name="a"/> and <paramref name="b"/>.
    /// </summary>
    private SymbolicExpressionTreeNode MakeAddition(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        // merge constants
        ((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        // c + x => x + c
        // b is not constant => make sure constant is on the right
        return MakeAddition(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
        // x + 0 => x
        return a;
      } else if (IsAddition(a) && IsAddition(b)) {
        // merge additions; only the last arguments may be constants
        var add = (new Addition()).CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
        if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
          add.AddSubTree(MakeAddition(a.SubTrees.Last(), b.SubTrees.Last()));
        } else if (IsConstant(a.SubTrees.Last())) {
          add.AddSubTree(b.SubTrees.Last());
          add.AddSubTree(a.SubTrees.Last());
        } else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b.SubTrees.Last());
        }
        MergeVariables(add);
        return add;
      } else if (IsAddition(b)) {
        // make sure the addition is the first argument
        return MakeAddition(b, a);
      } else if (IsAddition(a) && IsConstant(b)) {
        // (x0 + .. + xn) + c: fold c into a trailing constant of the addition if there is one
        var add = (new Addition()).CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
        if (IsConstant(a.SubTrees.Last()))
          add.AddSubTree(MakeAddition(a.SubTrees.Last(), b));
        else {
          add.AddSubTree(a.SubTrees.Last());
          add.AddSubTree(b);
        }
        return add;
      } else if (IsAddition(a)) {
        // (x0 + .. + xn) + y => y + x0 + .. + xn (a trailing constant of a stays last)
        var add = (new Addition()).CreateTreeNode();
        add.AddSubTree(b);
        foreach (var subTree in a.SubTrees) {
          add.AddSubTree(subTree);
        }
        MergeVariables(add);
        return add;
      } else {
        var add = (new Addition()).CreateTreeNode();
        add.SubTrees.Add(a);
        add.SubTrees.Add(b);
        MergeVariables(add);
        return add;
      }
    }

    /// <summary>
    /// Rearranges the children of an addition node: variables with the same name are merged into
    /// a single variable node whose weight is the sum of their weights. The merged variables come
    /// first, followed by all other children in their original relative order.
    /// </summary>
    /// <param name="add">The addition node; modified in place.</param>
    private void MergeVariables(SymbolicExpressionTreeNode add) {
      var subtrees = new List<SymbolicExpressionTreeNode>(add.SubTrees);
      while (add.SubTrees.Count > 0) add.RemoveSubTree(0);
      var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                            group node by node.VariableName into g
                            select g;
      var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));

      foreach (var variableNodeGroup in groupedVarNodes) {
        var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
        var representative = variableNodeGroup.First();
        representative.Weight = weightSum;
        add.AddSubTree(representative);
      }
      foreach (var unchangedSubtree in unchangedSubTrees)
        add.AddSubTree(unchangedSubtree);
    }

    /// <summary>
    /// a * b: folds constants, drops a factor that is (almost) one, folds constant factors into
    /// variable weights, distributes a constant over an addition, pulls multiplications into the
    /// numerator of a division and flattens nested multiplications, keeping a constant last.
    /// May modify and reuse the nodes of <paramref name="a"/> and <paramref name="b"/>.
    /// </summary>
    private SymbolicExpressionTreeNode MakeMultiplication(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
      if (IsConstant(a) && IsConstant(b)) {
        ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(a)) {
        // c * x => x * c
        return MakeMultiplication(b, a);
      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
        // x * 1 => x
        return a;
      } else if (IsConstant(b) && IsVariable(a)) {
        ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
        return a;
      } else if (IsConstant(b) && IsAddition(a)) {
        // (x0 + .. + xn) * c => x0 * c + .. + xn * c
        // Every term gets its own copy of the constant. Sharing b would put the same node in
        // several places of the tree, so a later in-place change (e.g. Negate of the sum) would
        // hit the constant once per term instead of once.
        return a.SubTrees
          .Select(x => MakeMultiplication(x, (SymbolicExpressionTreeNode)b.Clone()))
          .Aggregate((c, d) => MakeAddition(c, d));
      } else if (IsDivision(a)) {
        // (a0 / a1) * b => (a0 * b) / a1
        Trace.Assert(a.SubTrees.Count == 2);
        return MakeDivision(MakeMultiplication(a.SubTrees[0], b), a.SubTrees[1]);
      } else if (IsDivision(b)) {
        // a * (b0 / b1) => (b0 * a) / b1
        Trace.Assert(b.SubTrees.Count == 2);
        return MakeDivision(MakeMultiplication(b.SubTrees[0], a), b.SubTrees[1]);
      } else if (IsMultiplication(a) && IsMultiplication(b)) {
        var mul = (new Multiplication()).CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
        for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
        mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b.SubTrees.Last()));
        return mul;
      } else if (IsMultiplication(a)) {
        var mul = (new Multiplication()).CreateTreeNode();
        for (int i = 0; i < a.SubTrees.Count - 1; i++) mul.AddSubTree(a.SubTrees[i]);
        mul.AddSubTree(MakeMultiplication(a.SubTrees.Last(), b));
        return mul;
      } else if (IsMultiplication(b)) {
        var mul = (new Multiplication()).CreateTreeNode();
        for (int i = 0; i < b.SubTrees.Count - 1; i++) mul.AddSubTree(b.SubTrees[i]);
        mul.AddSubTree(MakeMultiplication(b.SubTrees.Last(), a));
        return mul;
      } else {
        var mul = (new Multiplication()).CreateTreeNode();
        mul.SubTrees.Add(a);
        mul.SubTrees.Add(b);
        return mul;
      }
    }

    #region is symbol ?
    private bool IsDivision(SymbolicExpressionTreeNode original) {
      return original.Symbol is Division;
    }

    private bool IsMultiplication(SymbolicExpressionTreeNode original) {
      return original.Symbol is Multiplication;
    }

    private bool IsSubtraction(SymbolicExpressionTreeNode original) {
      return original.Symbol is Subtraction;
    }

    private bool IsAddition(SymbolicExpressionTreeNode original) {
      return original.Symbol is Addition;
    }

    private bool IsVariable(SymbolicExpressionTreeNode original) {
      return original.Symbol is Variable;
    }

    private bool IsConstant(SymbolicExpressionTreeNode original) {
      return original.Symbol is Constant;
    }
    #endregion

    /// <summary>
    /// Creates a new constant node with the given value.
    /// </summary>
    private SymbolicExpressionTreeNode MakeConstant(double value) {
      ConstantTreeNode constantTreeNode = (ConstantTreeNode)(new Constant().CreateTreeNode());
      constantTreeNode.Value = value;
      return (SymbolicExpressionTreeNode)constantTreeNode;
    }

    /// <summary>
    /// Creates a new variable node with the given weight and variable name.
    /// </summary>
    private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
      var tree = (VariableTreeNode)(new Variable().CreateTreeNode());
      tree.Weight = weight;
      tree.VariableName = name;
      return tree;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.