[8703] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using System.Text;
|
---|
| 5 | using System.Diagnostics.Contracts;
|
---|
| 6 | using System.Diagnostics;
|
---|
| 7 | using System.Collections.ObjectModel;
|
---|
| 8 |
|
---|
| 9 | namespace AutoDiff
|
---|
| 10 | {
|
---|
/// <summary>
/// Compiles the terms tree to a more efficient form for differentiation.
/// </summary>
/// <typeparam name="T">The list type holding the variables of the compiled function.</typeparam>
/// <remarks>
/// Every call to <c>Evaluate</c> or <c>Differentiate</c> overwrites the values and adjoints stored
/// on this instance's tape, so one instance must not be used by several threads concurrently
/// without external synchronization.
/// </remarks>
internal partial class CompiledDifferentiator<T> : ICompiledTerm
    where T : IList<Variable>
{
    // Tape layout relied upon by the sweeps below: entries [0, Dimension) receive the argument
    // values, entries [Dimension, tape.Length) are the compiled operations, and the final entry
    // holds the value of the whole function.
    private readonly Compiled.TapeElement[] tape;

    /// <summary>
    /// Initializes a new instance of the <see cref="CompiledDifferentiator{T}"/> class.
    /// </summary>
    /// <param name="function">The function to compile.</param>
    /// <param name="variables">The variables of the function; their count becomes <see cref="Dimension"/>.</param>
    public CompiledDifferentiator(Term function, T variables)
    {
        Contract.Requires(function != null);
        Contract.Requires(variables != null);
        Contract.Requires(Contract.ForAll(variables, variable => variable != null));
        Contract.Ensures(Dimension == variables.Count);

        // A bare variable is rewritten as x^1. NOTE(review): presumably this gives the function
        // result a tape entry of its own after the variable entries; confirm against Compiler.
        if (function is Variable)
        {
            function = new ConstPower(function, 1);
        }

        var tapeList = new List<Compiled.TapeElement>();
        new Compiler(variables, tapeList).Compile(function);
        tape = tapeList.ToArray();

        Dimension = variables.Count;

        // ReadOnlyCollection only wraps the list it is given. Copy first so later changes to the
        // caller's list cannot make Variables disagree with Dimension and the compiled tape.
        Variables = new ReadOnlyCollection<Variable>(new List<Variable>(variables));
    }

    /// <summary>
    /// Gets the number of variables, i.e. how many leading argument values are read per call.
    /// </summary>
    public int Dimension { get; private set; }

    /// <summary>
    /// Evaluates the compiled function at the given point.
    /// </summary>
    /// <param name="arg">The variable values; only the first <see cref="Dimension"/> entries are read.</param>
    /// <returns>The value of the function.</returns>
    public double Evaluate(double[] arg)
    {
        EvaluateTape(arg);
        return tape[tape.Length - 1].Value;
    }

    /// <summary>
    /// Computes the gradient and the value of the compiled function at the given point using a
    /// forward sweep followed by a reverse (adjoint) sweep.
    /// </summary>
    /// <typeparam name="S">The list type of the argument values.</typeparam>
    /// <param name="arg">The variable values; only the first <see cref="Dimension"/> entries are read.</param>
    /// <returns>A tuple of the gradient (length <see cref="Dimension"/>) and the function value.</returns>
    public Tuple<double[], double> Differentiate<S>(S arg)
        where S : IList<double>
    {
        ForwardSweep(arg);
        ReverseSweep();

        // After the reverse sweep, the adjoint stored on variable entry i is the i-th gradient component.
        var gradient = new double[Dimension];
        for (int i = 0; i < Dimension; ++i)
        {
            gradient[i] = tape[i].Adjoint;
        }

        return Tuple.Create(gradient, tape[tape.Length - 1].Value);
    }

    /// <summary>
    /// Computes the gradient and the value of the compiled function at the given point.
    /// </summary>
    /// <param name="arg">The variable values; only the first <see cref="Dimension"/> entries are read.</param>
    /// <returns>A tuple of the gradient (length <see cref="Dimension"/>) and the function value.</returns>
    public Tuple<double[], double> Differentiate(params double[] arg)
    {
        return Differentiate<double[]>(arg);
    }

    /// <summary>
    /// Reverse pass. Seeds the final entry's adjoint with 1 and clears all other adjoints. Then it
    /// walks the operation entries from last to first. Each entry's adjoint, multiplied by the
    /// weight of each input, is added to the adjoint of that input.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this assumes an entry's inputs always have lower tape indices than the entry
    /// itself, and that the input weights were filled in by <c>ForwardSweep</c>. Confirm against
    /// Compiler and ForwardSweepVisitor.
    /// </remarks>
    private void ReverseSweep()
    {
        int last = tape.Length - 1;

        tape[last].Adjoint = 1;

        // initialize adjoints
        for (int i = 0; i < last; ++i)
        {
            tape[i].Adjoint = 0;
        }

        // accumulate adjoints
        for (int i = last; i >= Dimension; --i)
        {
            var inputs = tape[i].Inputs;
            var adjoint = tape[i].Adjoint;

            for (int j = 0; j < inputs.Length; ++j)
            {
                tape[inputs[j].Index].Adjoint += adjoint * inputs[j].Weight;
            }
        }
    }

    /// <summary>
    /// Forward pass used by differentiation. Loads the argument values into the variable entries,
    /// then has a <c>ForwardSweepVisitor</c> visit every operation entry in tape order.
    /// </summary>
    private void ForwardSweep<S>(S arg)
        where S : IList<double>
    {
        for (int i = 0; i < Dimension; ++i)
        {
            tape[i].Value = arg[i];
        }

        var forwardDiffVisitor = new ForwardSweepVisitor(tape);
        for (int i = Dimension; i < tape.Length; ++i)
        {
            tape[i].Accept(forwardDiffVisitor);
        }
    }

    /// <summary>
    /// Value-only pass used by <c>Evaluate</c>. Loads the argument values into the variable entries,
    /// then has an <c>EvalVisitor</c> visit every operation entry in tape order.
    /// </summary>
    private void EvaluateTape(double[] arg)
    {
        for (int i = 0; i < Dimension; ++i)
        {
            tape[i].Value = arg[i];
        }

        var evalVisitor = new EvalVisitor(tape);
        for (int i = Dimension; i < tape.Length; ++i)
        {
            tape[i].Accept(evalVisitor);
        }
    }

    /// <summary>
    /// Returns the value currently stored on the tape entry at <paramref name="index"/>.
    /// </summary>
    /// <remarks>NOTE(review): not referenced in this part of the partial class; possibly used by another part.</remarks>
    private double ValueOf(int index)
    {
        return tape[index].Value;
    }

    /// <summary>
    /// Gets a read-only snapshot of the variables this function was compiled for.
    /// </summary>
    public ReadOnlyCollection<Variable> Variables { get; private set; }
}
|
---|
| 117 | }
|
---|