Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Async/HeuristicLab.ExtLibs/HeuristicLab.AutoDiff/1.0/AutoDiff-1.0/CompiledDifferentiator.cs @ 13329

Last change on this file since 13329 was 8952, checked in by mkommend, 12 years ago

#1976: Performance improvements in AutoDiff library.

File size: 3.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Diagnostics.Contracts;
6using System.Diagnostics;
7using System.Collections.ObjectModel;
8
9namespace AutoDiff
10{
11    /// <summary>
12    /// Compiles the terms tree to a more efficient form for differentiation.
13    /// </summary>
14    internal partial class CompiledDifferentiator<T> : ICompiledTerm
15        where T : IList<Variable>
16    {
17        private readonly Compiled.TapeElement[] tape;
18
19        /// <summary>
20        /// Initializes a new instance of the <see cref="CompiledDifferentiator"/> class.
21        /// </summary>
22        /// <param name="function">The function.</param>
23        /// <param name="variables">The variables.</param>
24        public CompiledDifferentiator(Term function, T variables)
25        {
26            Contract.Requires(function != null);
27            Contract.Requires(variables != null);
28            Contract.Requires(Contract.ForAll(variables, variable => variable != null));
29            Contract.Ensures(Dimension == variables.Count);
30
31            if (function is Variable)
32                function = new ConstPower(function, 1);
33
34            var tapeList = new List<Compiled.TapeElement>();
35            new Compiler(variables, tapeList).Compile(function);
36            tape = tapeList.ToArray();
37
38            Dimension = variables.Count;
39            Variables = new ReadOnlyCollection<Variable>(variables);
40        }
41
42        public int Dimension { get; private set; }
43
44        public double Evaluate(double[] arg)
45        {
46            EvaluateTape(arg);
47            return tape.Last().Value;
48        }
49
50        public Tuple<double[], double> Differentiate<S>(S arg)
51            where S : IList<double>
52        {
53            ForwardSweep(arg);
54            ReverseSweep();
55
56            var gradient = new double[Dimension];
57            for (int i = 0; i < Dimension; i++)
58               gradient[i] = tape[i].Adjoint;
59            var value = tape.Last().Value;           
60
61            return Tuple.Create(gradient, value);
62        }
63
64        public Tuple<double[], double> Differentiate(params double[] arg)
65        {
66            return Differentiate<double[]>(arg);
67        }
68
69        private void ReverseSweep()
70        {
71            tape.Last().Adjoint = 1;
72           
73            // initialize adjoints
74            for (int i = 0; i < tape.Length - 1; ++i)
75                tape[i].Adjoint = 0;
76
77            // accumulate adjoints
78            for (int i = tape.Length - 1; i >= Dimension; --i)
79            {
80                var inputs = tape[i].Inputs;
81                var adjoint = tape[i].Adjoint;
82               
83                for(int j = 0; j < inputs.Length; ++j)
84                    tape[inputs[j].Index].Adjoint += adjoint * inputs[j].Weight;
85            }
86        }
87
88        private void ForwardSweep<S>(S arg)
89            where S : IList<double>
90        {
91            for (int i = 0; i < Dimension; ++i)
92                tape[i].Value = arg[i];
93
94            var forwardDiffVisitor = new ForwardSweepVisitor(tape);
95            for (int i = Dimension; i < tape.Length; ++i)
96                tape[i].Accept(forwardDiffVisitor);
97        }
98
99        private void EvaluateTape(double[] arg)
100        {
101            for(int i = 0; i < Dimension; ++i)
102                tape[i].Value = arg[i];
103            var evalVisitor = new EvalVisitor(tape);
104            for (int i = Dimension; i < tape.Length; ++i )
105                tape[i].Accept(evalVisitor);
106        }
107
108        private double ValueOf(int index)
109        {
110            return tape[index].Value;
111        }
112
113        public ReadOnlyCollection<Variable> Variables { get; private set; }
114
115
116    }
117}
Note: See TracBrowser for help on using the repository browser.