using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Diagnostics.Contracts;
using System.Diagnostics;
using System.Collections.ObjectModel;
namespace AutoDiff
{
    /// <summary>
    /// Compiles the terms tree to a more efficient form for differentiation.
    /// </summary>
    /// <typeparam name="T">The list type that holds the variables of the compiled function.</typeparam>
    internal partial class CompiledDifferentiator<T> : ICompiledTerm
        where T : IList<Variable>
    {
        // Flat tape filled by the compiler. The sweeps below treat the first
        // Dimension entries as the inputs and the last entry as the function result.
        private readonly Compiled.TapeElement[] tape;

        /// <summary>
        /// Initializes a new instance of the <see cref="CompiledDifferentiator{T}"/> class.
        /// </summary>
        /// <param name="function">The function.</param>
        /// <param name="variables">The variables.</param>
        public CompiledDifferentiator(Term function, T variables)
        {
            Contract.Requires(function != null);
            Contract.Requires(variables != null);
            Contract.Requires(Contract.ForAll(variables, variable => variable != null));
            Contract.Ensures(Dimension == variables.Count);

            // A bare variable is wrapped as function^1 before compilation.
            // NOTE(review): presumably so that the result gets its own tape entry after
            // the variable entries, which the sweeps rely on - confirm against Compiler.
            if (function is Variable)
                function = new ConstPower(function, 1);

            var tapeList = new List<Compiled.TapeElement>();
            new Compiler(variables, tapeList).Compile(function);
            tape = tapeList.ToArray();

            Dimension = variables.Count;

            // ReadOnlyCollection wraps the supplied list; it does not copy it.
            Variables = new ReadOnlyCollection<Variable>(variables);
        }

        /// <summary>
        /// Gets the number of variables; set from <c>variables.Count</c> at construction.
        /// </summary>
        public int Dimension { get; private set; }

        /// <summary>
        /// Gets a read-only view of the variables that were passed to the constructor.
        /// </summary>
        public ReadOnlyCollection<Variable> Variables { get; private set; }

        /// <summary>
        /// Evaluates the compiled function at the given point.
        /// </summary>
        /// <param name="arg">The variable values; the first <see cref="Dimension"/> elements are read.</param>
        /// <returns>The value stored in the last tape element after evaluation.</returns>
        public double Evaluate(double[] arg)
        {
            EvaluateTape(arg);
            return tape[tape.Length - 1].Value;
        }

        /// <summary>
        /// Computes the gradient and the value of the compiled function at the given point.
        /// </summary>
        /// <typeparam name="S">The list type that holds the variable values.</typeparam>
        /// <param name="arg">The variable values; the first <see cref="Dimension"/> elements are read.</param>
        /// <returns>
        /// A tuple whose first item is a new array of length <see cref="Dimension"/> holding the
        /// adjoints of the variable entries, and whose second item is the function value.
        /// </returns>
        public Tuple<double[], double> Differentiate<S>(S arg)
            where S : IList<double>
        {
            ForwardSweep(arg);
            ReverseSweep();

            // The adjoints accumulated on the first Dimension tape entries form the gradient.
            var gradient = new double[Dimension];
            for (int i = 0; i < Dimension; i++)
                gradient[i] = tape[i].Adjoint;

            var value = tape[tape.Length - 1].Value;
            return Tuple.Create(gradient, value);
        }

        /// <summary>
        /// Computes the gradient and the value of the compiled function at the given point.
        /// </summary>
        /// <param name="arg">The variable values; the first <see cref="Dimension"/> elements are read.</param>
        /// <returns>The same result as the generic overload called with a <c>double[]</c>.</returns>
        public Tuple<double[], double> Differentiate(params double[] arg)
        {
            // The type argument must be explicit: with a plain Differentiate(arg), overload
            // resolution prefers this non-generic method and the call would recurse forever.
            return Differentiate<double[]>(arg);
        }

        /// <summary>
        /// Propagates adjoints from the last tape element back to the variable entries.
        /// </summary>
        private void ReverseSweep()
        {
            // seed: the adjoint of the result is 1
            tape[tape.Length - 1].Adjoint = 1;

            // initialize adjoints
            for (int i = 0; i < tape.Length - 1; ++i)
                tape[i].Adjoint = 0;

            // accumulate adjoints, walking back from the result and stopping before the variable entries
            for (int i = tape.Length - 1; i >= Dimension; --i)
            {
                var inputs = tape[i].Inputs;
                var adjoint = tape[i].Adjoint;

                for (int j = 0; j < inputs.Length; ++j)
                    tape[inputs[j].Index].Adjoint += adjoint * inputs[j].Weight;
            }
        }

        /// <summary>
        /// Loads the variable values into the tape and visits every remaining element,
        /// in tape order, with a <c>ForwardSweepVisitor</c>.
        /// </summary>
        /// <typeparam name="S">The list type that holds the variable values.</typeparam>
        /// <param name="arg">The variable values; the first <see cref="Dimension"/> elements are read.</param>
        private void ForwardSweep<S>(S arg)
            where S : IList<double>
        {
            for (int i = 0; i < Dimension; ++i)
                tape[i].Value = arg[i];

            var forwardDiffVisitor = new ForwardSweepVisitor(tape);
            for (int i = Dimension; i < tape.Length; ++i)
                tape[i].Accept(forwardDiffVisitor);
        }

        /// <summary>
        /// Loads the variable values into the tape and visits every remaining element,
        /// in tape order, with an <c>EvalVisitor</c>.
        /// </summary>
        /// <param name="arg">The variable values; the first <see cref="Dimension"/> elements are read.</param>
        private void EvaluateTape(double[] arg)
        {
            for (int i = 0; i < Dimension; ++i)
                tape[i].Value = arg[i];

            var evalVisitor = new EvalVisitor(tape);
            for (int i = Dimension; i < tape.Length; ++i)
                tape[i].Accept(evalVisitor);
        }

        /// <summary>
        /// Gets the current value of the tape element at the given index.
        /// </summary>
        /// <param name="index">The zero-based tape index.</param>
        /// <returns>The <c>Value</c> of that tape element.</returns>
        private double ValueOf(int index)
        {
            return tape[index].Value;
        }
    }
}