Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.AutoDiff/1.0/AutoDiff-1.0/CompiledDifferentiator.cs @ 8731

Last change on this file since 8731 was 8703, checked in by gkronber, 12 years ago

#1960 added HL wrapper plugin for AutoDiff

File size: 3.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Diagnostics.Contracts;
6using System.Diagnostics;
7using System.Collections.ObjectModel;
8
9namespace AutoDiff
10{
11    /// <summary>
12    /// Compiles the terms tree to a more efficient form for differentiation.
13    /// </summary>
14    internal partial class CompiledDifferentiator<T> : ICompiledTerm
15        where T : IList<Variable>
16    {
17        private readonly Compiled.TapeElement[] tape;
18
19        /// <summary>
20        /// Initializes a new instance of the <see cref="CompiledDifferentiator"/> class.
21        /// </summary>
22        /// <param name="function">The function.</param>
23        /// <param name="variables">The variables.</param>
24        public CompiledDifferentiator(Term function, T variables)
25        {
26            Contract.Requires(function != null);
27            Contract.Requires(variables != null);
28            Contract.Requires(Contract.ForAll(variables, variable => variable != null));
29            Contract.Ensures(Dimension == variables.Count);
30
31            if (function is Variable)
32                function = new ConstPower(function, 1);
33
34            var tapeList = new List<Compiled.TapeElement>();
35            new Compiler(variables, tapeList).Compile(function);
36            tape = tapeList.ToArray();
37
38            Dimension = variables.Count;
39            Variables = new ReadOnlyCollection<Variable>(variables);
40        }
41
42        public int Dimension { get; private set; }
43
44        public double Evaluate(double[] arg)
45        {
46            EvaluateTape(arg);
47            return tape.Last().Value;
48        }
49
50        public Tuple<double[], double> Differentiate<S>(S arg)
51            where S : IList<double>
52        {
53            ForwardSweep(arg);
54            ReverseSweep();
55
56            var gradient = tape.Take(Dimension).Select(elem => elem.Adjoint).ToArray();
57            var value = tape.Last().Value;
58
59            return Tuple.Create(gradient, value);
60        }
61
62        public Tuple<double[], double> Differentiate(params double[] arg)
63        {
64            return Differentiate<double[]>(arg);
65        }
66
67        private void ReverseSweep()
68        {
69            tape.Last().Adjoint = 1;
70           
71            // initialize adjoints
72            for (int i = 0; i < tape.Length - 1; ++i)
73                tape[i].Adjoint = 0;
74
75            // accumulate adjoints
76            for (int i = tape.Length - 1; i >= Dimension; --i)
77            {
78                var inputs = tape[i].Inputs;
79                var adjoint = tape[i].Adjoint;
80               
81                for(int j = 0; j < inputs.Length; ++j)
82                    tape[inputs[j].Index].Adjoint += adjoint * inputs[j].Weight;
83            }
84        }
85
86        private void ForwardSweep<S>(S arg)
87            where S : IList<double>
88        {
89            for (int i = 0; i < Dimension; ++i)
90                tape[i].Value = arg[i];
91
92            var forwardDiffVisitor = new ForwardSweepVisitor(tape);
93            for (int i = Dimension; i < tape.Length; ++i)
94                tape[i].Accept(forwardDiffVisitor);
95        }
96
97        private void EvaluateTape(double[] arg)
98        {
99            for(int i = 0; i < Dimension; ++i)
100                tape[i].Value = arg[i];
101            var evalVisitor = new EvalVisitor(tape);
102            for (int i = Dimension; i < tape.Length; ++i )
103                tape[i].Accept(evalVisitor);
104        }
105
106        private double ValueOf(int index)
107        {
108            return tape[index].Value;
109        }
110
111        public ReadOnlyCollection<Variable> Variables { get; private set; }
112
113
114    }
115}
Note: See TracBrowser for help on using the repository browser.