source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/DerivativeCalculator.cs @ 16213

Last change on this file since 16213 was 16213, checked in by gkronber, 12 months ago

#2948 changed unit test cases to assert results of derivative calculations. Fixed bug in deriving sqrt(x)

File size: 8.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
25
26namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
27  public static class DerivativeCalculator {
28    public static ISymbolicExpressionTree Derive(ISymbolicExpressionTree tree, string variableName) {
29      var mainBranch = tree.Root.GetSubtree(0).GetSubtree(0);
30      var root = new ProgramRootSymbol().CreateTreeNode();
31      root.AddSubtree(new StartSymbol().CreateTreeNode());
32      var dTree = TreeSimplifier.GetSimplifiedTree(Derive(mainBranch, variableName));
33      // var dTree = Derive(mainBranch, variableName);
34      root.GetSubtree(0).AddSubtree(dTree);
35      return new SymbolicExpressionTree(root);
36    }
37
38    private static Constant constantSy = new Constant();
39    private static Addition addSy = new Addition();
40    private static Subtraction subSy = new Subtraction();
41    private static Multiplication mulSy = new Multiplication();
42    private static Division divSy = new Division();
43
44    public static ISymbolicExpressionTreeNode Derive(ISymbolicExpressionTreeNode branch, string variableName) {
45      if (branch.Symbol is Constant) {
46        return CreateConstant(0.0);
47      }
48      if (branch.Symbol is Variable) {
49        var varNode = branch as VariableTreeNode;
50        if (varNode.VariableName == variableName) {
51          return CreateConstant(varNode.Weight);
52        } else {
53          return CreateConstant(0.0);
54        }
55      }
56      if (branch.Symbol is Addition) {
57        var sum = addSy.CreateTreeNode();
58        foreach (var subTree in branch.Subtrees) {
59          sum.AddSubtree(Derive(subTree, variableName));
60        }
61        return sum;
62      }
63      if (branch.Symbol is Subtraction) {
64        var sum = subSy.CreateTreeNode();
65        foreach (var subTree in branch.Subtrees) {
66          sum.AddSubtree(Derive(subTree, variableName));
67        }
68        return sum;
69      }
70      if (branch.Symbol is Multiplication) {
71        // (f * g)' = f'*g + f*g'
72        // for multiple factors: (f * g * h)' = ((f*g) * h)' = (f*g)' * h + (f*g) * h'
73
74        if (branch.SubtreeCount >= 2) {
75          var f = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
76          var g = (ISymbolicExpressionTreeNode)branch.GetSubtree(1).Clone();
77          var fprime = Derive(f, variableName);
78          var gprime = Derive(g, variableName);
79          var fgPrime = Sum(Product(f, gprime), Product(fprime, g));
80          for (int i = 2; i < branch.SubtreeCount; i++) {
81            var fg = Product((ISymbolicExpressionTreeNode)f.Clone(), (ISymbolicExpressionTreeNode)g.Clone());
82            var h = (ISymbolicExpressionTreeNode)branch.GetSubtree(i).Clone();
83            var hPrime = Derive(h, variableName);
84            fgPrime = Sum(Product(fgPrime, h), Product(fg, hPrime));
85          }
86          return fgPrime;
87        } else throw new ArgumentException();
88      }
89      if (branch.Symbol is Division) {
90        // (f/g)' = (f'g - g'f) / g²
91        if (branch.SubtreeCount == 1) {
92          var g = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
93          var gPrime = Product(CreateConstant(-1.0), Derive(g, variableName));
94          var sqrNode = new Square().CreateTreeNode();
95          sqrNode.AddSubtree(g);
96          return Div(gPrime, sqrNode);
97        } else if (branch.SubtreeCount == 2) {
98          var f = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
99          var g = (ISymbolicExpressionTreeNode)branch.GetSubtree(1).Clone();
100          var fprime = Derive(f, variableName);
101          var gprime = Derive(g, variableName);
102          var sqrNode = new Square().CreateTreeNode();
103          sqrNode.AddSubtree((ISymbolicExpressionTreeNode)branch.GetSubtree(1).Clone());
104          return Div(Subtract(Product(fprime, g), Product(f, gprime)), sqrNode);
105        } else throw new NotSupportedException();
106      }
107      if (branch.Symbol is Logarithm) {
108        var f = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
109        return Product(Div(CreateConstant(1.0), f), Derive(f, variableName));
110      }
111      if (branch.Symbol is Exponential) {
112        var f = (ISymbolicExpressionTreeNode)branch.Clone();
113        return Product(f, Derive(branch.GetSubtree(0), variableName));
114      }
115      if(branch.Symbol is Square) {
116        var f = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
117        return Product(Product(CreateConstant(2.0), f), Derive(f, variableName));
118      }     
119      if(branch.Symbol is SquareRoot) {
120        var f = (ISymbolicExpressionTreeNode)branch.Clone();
121        var u = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
122        return Product(Div(CreateConstant(1.0), Product(CreateConstant(2.0), f)), Derive(u, variableName));
123      }
124      if (branch.Symbol is Sine) {
125        var u = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
126        var cos = (new Cosine()).CreateTreeNode();
127        cos.AddSubtree(u);
128        return Product(cos, Derive(u, variableName));
129      }
130      if (branch.Symbol is Cosine) {
131        var u = (ISymbolicExpressionTreeNode)branch.GetSubtree(0).Clone();
132        var sin = (new Sine()).CreateTreeNode();
133        sin.AddSubtree(u);
134        return Product(CreateConstant(-1.0), Product(sin, Derive(u, variableName)));
135      }
136      throw new NotSupportedException($"Symbol {branch.Symbol} is not supported.");
137    }
138
139
140    private static ISymbolicExpressionTreeNode Product(ISymbolicExpressionTreeNode f, ISymbolicExpressionTreeNode g) {
141      var product = mulSy.CreateTreeNode();
142      product.AddSubtree(f);
143      product.AddSubtree(g);
144      return product;
145    }
146    private static ISymbolicExpressionTreeNode Div(ISymbolicExpressionTreeNode f, ISymbolicExpressionTreeNode g) {
147      var div = divSy.CreateTreeNode();
148      div.AddSubtree(f);
149      div.AddSubtree(g);
150      return div;
151    }
152
153    private static ISymbolicExpressionTreeNode Sum(ISymbolicExpressionTreeNode f, ISymbolicExpressionTreeNode g) {
154      var sum = addSy.CreateTreeNode();
155      sum.AddSubtree(f);
156      sum.AddSubtree(g);
157      return sum;
158    }
159    private static ISymbolicExpressionTreeNode Subtract(ISymbolicExpressionTreeNode f, ISymbolicExpressionTreeNode g) {
160      var sum = subSy.CreateTreeNode();
161      sum.AddSubtree(f);
162      sum.AddSubtree(g);
163      return sum;
164    }
165                         
166    private static ISymbolicExpressionTreeNode CreateConstant(double v) {
167      var constNode = (ConstantTreeNode)constantSy.CreateTreeNode();
168      constNode.Value = v;
169      return constNode;
170    }
171
172    public static bool IsCompatible(ISymbolicExpressionTree tree) {
173      var containsUnknownSymbol = (
174        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
175        where
176          !(n.Symbol is Variable) &&
177          !(n.Symbol is Constant) &&
178          !(n.Symbol is Addition) &&
179          !(n.Symbol is Subtraction) &&
180          !(n.Symbol is Multiplication) &&
181          !(n.Symbol is Division) &&
182          !(n.Symbol is Logarithm) &&
183          !(n.Symbol is Exponential) &&
184          !(n.Symbol is Square) &&
185          !(n.Symbol is SquareRoot) &&
186          !(n.Symbol is Sine) &&
187          !(n.Symbol is Cosine) &&
188          !(n.Symbol is StartSymbol)
189        select n).Any();
190      return !containsUnknownSymbol;
191    }
192  }
193}
Note: See TracBrowser for help on using the repository browser.