Free cookie consent management tool by TermsFeed Policy Generator

source: branches/RemoveBackwardsCompatibility/HeuristicLab.ExtLibs/HeuristicLab.AutoDiff/1.0/AutoDiff-1.0/CompiledDifferentiator.Compiler.cs @ 13820

Last change on this file since 13820 was 8703, checked in by gkronber, 12 years ago

#1960 added HL wrapper plugin for AutoDiff

File size: 8.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5
6using CompileResult = AutoDiff.Compiled.TapeElement;
7
8namespace AutoDiff
9{
10    partial class CompiledDifferentiator<T>
11    {
12        private class Compiler : ITermVisitor<int> // int --> the index of the compiled element in the tape
13        {
14            private readonly List<Compiled.TapeElement> tape;
15            private readonly Dictionary<Term, int> indexOf;
16
17            public Compiler(T variables, List<Compiled.TapeElement> tape)
18            {
19                this.tape = tape;
20                indexOf = new Dictionary<Term, int>();
21                foreach (var i in Enumerable.Range(0, variables.Count))
22                {
23                    indexOf[variables[i]] = i;
24                    tape.Add(new Compiled.Variable());
25                }
26            }
27
28            public void Compile(Term term)
29            {
30                term.Accept(this);
31            }
32
33            public int Visit(Constant constant)
34            {
35                return Compile(constant, () => new Compiled.Constant(constant.Value) { Inputs = new Compiled.InputEdge[0] });
36            }
37
38            public int Visit(Zero zero)
39            {
40                return Compile(zero, () => new Compiled.Constant(0) { Inputs = new Compiled.InputEdge[0] });
41            }
42
43            public int Visit(ConstPower intPower)
44            {
45                return Compile(intPower, () =>
46                    {
47                        var baseIndex = intPower.Base.Accept(this);
48                        var element = new Compiled.ConstPower
49                        {
50                            Base = baseIndex,
51                            Exponent = intPower.Exponent,
52                            Inputs = new Compiled.InputEdge[]
53                            {
54                                new Compiled.InputEdge { Index = baseIndex },
55                            },
56                        };
57
58                        return element;
59                    });
60            }
61
62            public int Visit(TermPower power)
63            {
64                return Compile(power, () =>
65                {
66                    var baseIndex = power.Base.Accept(this);
67                    var expIndex = power.Exponent.Accept(this);
68                    var element = new Compiled.TermPower
69                    {
70                        Base = baseIndex,
71                        Exponent = expIndex,
72                        Inputs = new Compiled.InputEdge[]
73                        {
74                            new Compiled.InputEdge { Index = baseIndex },
75                            new Compiled.InputEdge { Index = expIndex },
76                        },
77                    };
78
79                    return element;
80                });
81            }
82
83            public int Visit(Product product)
84            {
85                return Compile(product, () =>
86                    {
87                        var leftIndex = product.Left.Accept(this);
88                        var rightIndex = product.Right.Accept(this);
89                        var element = new Compiled.Product
90                        {
91                            Left = leftIndex,
92                            Right = rightIndex,
93                            Inputs = new Compiled.InputEdge[]
94                            {
95                                new Compiled.InputEdge { Index = leftIndex },
96                                new Compiled.InputEdge { Index = rightIndex },
97                            }
98                        };
99
100                        return element;
101                    });
102            }
103
104            public int Visit(Sum sum)
105            {
106                return Compile(sum, () =>
107                    {
108                        var indicesQuery = from term in sum.Terms
109                                           select term.Accept(this);
110                        var indices = indicesQuery.ToArray();
111                        var element = new Compiled.Sum
112                        {
113                            Terms = indices,
114                            Inputs = indices.Select(x => new Compiled.InputEdge { Index = x }).ToArray(),
115                        };
116
117                        return element;
118                    });
119            }
120
121            public int Visit(Variable variable)
122            {
123                return indexOf[variable];
124            }
125
126            public int Visit(Log log)
127            {
128                return Compile(log, () =>
129                    {
130                        var argIndex = log.Arg.Accept(this);
131                        var element = new Compiled.Log
132                        {
133                            Arg = argIndex,
134                            Inputs = new Compiled.InputEdge[]
135                            {
136                                new Compiled.InputEdge { Index = argIndex },
137                            },
138                        };
139
140                        return element;
141                    });
142            }
143
144            public int Visit(Exp exp)
145            {
146                return Compile(exp, () =>
147                    {
148                        var argIndex = exp.Arg.Accept(this);
149                        var element = new Compiled.Exp
150                        {
151                            Arg = argIndex,
152                            Inputs = new Compiled.InputEdge[]
153                            {
154                                new Compiled.InputEdge { Index = argIndex },
155                            },
156                        };
157
158                        return element;
159                    });
160            }
161
162            public int Visit(UnaryFunc func)
163            {
164                return Compile(func, () =>
165                    {
166                        var argIndex = func.Argument.Accept(this);
167                        var element = new Compiled.UnaryFunc(func.Eval, func.Diff)
168                        {
169                            Arg = argIndex,
170                            Inputs = new Compiled.InputEdge[]
171                            {
172                                new Compiled.InputEdge { Index = argIndex },
173                            },
174                        };
175
176                        return element;
177                    });
178            }
179
180            public int Visit(BinaryFunc func)
181            {
182                return Compile(func, () =>
183                    {
184                        var leftIndex = func.Left.Accept(this);
185                        var rightIndex = func.Right.Accept(this);
186
187                        var element = new Compiled.BinaryFunc
188                        {
189                            Eval = func.Eval,
190                            Diff = func.Diff,
191                            Left = leftIndex,
192                            Right = rightIndex,
193                            Inputs = new Compiled.InputEdge[]
194                            {
195                                new Compiled.InputEdge { Index = leftIndex },
196                                new Compiled.InputEdge { Index = rightIndex },
197                            }
198                        };
199
200                        return element;
201                    });
202            }
203
204            public int Visit(NaryFunc func)
205            {
206                return Compile(func, () =>
207                {
208                    var indicesQuery = from term in func.Terms
209                                       select term.Accept(this);
210                    var indices = indicesQuery.ToArray();
211
212                    var element = new Compiled.NaryFunc
213                    {
214                        Eval = func.Eval,
215                        Diff = func.Diff,
216                        Terms = indices,
217                        Inputs = indices.Select(x => new Compiled.InputEdge { Index = x }).ToArray(),
218                    };
219
220                    return element;
221                });
222            }
223
224
225            private int Compile(Term term, Func<CompileResult> compiler)
226            {
227                int index;
228                if (!indexOf.TryGetValue(term, out index))
229                {
230                    var compileResult = compiler();
231                    tape.Add(compileResult);
232
233                    index = tape.Count - 1;
234                    indexOf.Add(term, index);
235                }
236
237                return index;
238            }
239        }
240    }
241}
Note: See TracBrowser for help on using the repository browser.