1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using System.Text;
|
---|
5 | using System.Diagnostics.Contracts;
|
---|
6 |
|
---|
namespace AutoDiff
{
    /// <summary>
    /// Extension methods that compile, evaluate and differentiate terms.
    /// </summary>
    /// <remarks>
    /// Argument checks and result guarantees are declared with <c>System.Diagnostics.Contracts</c>; whether they are
    /// enforced at run time depends on the build configuration (the <c>CONTRACTS_FULL</c> symbol and the Code Contracts rewriter).
    /// </remarks>
    public static class TermUtils
    {
        /// <summary>
        /// Compiles <paramref name="term"/> into an <c>ICompiledTerm</c> that can compute the term's value and gradient efficiently.
        /// </summary>
        /// <param name="term">The term to compile.</param>
        /// <param name="variables">The variables appearing in <paramref name="term"/>, in the order their values will be supplied.</param>
        /// <returns>A compiled form of <paramref name="term"/> whose inputs are ordered exactly like <paramref name="variables"/>.</returns>
        /// <remarks>
        /// Order matters: the point passed to <c>ICompiledTerm.Evaluate</c> or <c>ICompiledTerm.Differentiate</c> is an array of
        /// numbers whose i'th entry is the value of the i'th element of <paramref name="variables"/>.
        /// This overload simply forwards to the generic <c>Compile&lt;T&gt;</c> overload with <c>T = Variable[]</c>.
        /// </remarks>
        public static ICompiledTerm Compile(this Term term, params Variable[] variables)
        {
            // The explicit type argument is required: without it, overload resolution prefers this
            // non-generic overload again and the call would recurse forever.
            return term.Compile<Variable[]>(variables);
        }

        /// <summary>
        /// Compiles <paramref name="term"/> into an <c>ICompiledTerm</c>, taking the variables from any <c>IList&lt;Variable&gt;</c>.
        /// </summary>
        /// <typeparam name="T">The list type that holds the variables.</typeparam>
        /// <param name="term">The term to compile.</param>
        /// <param name="variables">The variables appearing in <paramref name="term"/>, in the order their values will be supplied.</param>
        /// <returns>A compiled form of <paramref name="term"/> whose inputs are ordered exactly like <paramref name="variables"/>.</returns>
        /// <remarks>
        /// Order matters: the point passed to <c>ICompiledTerm.Evaluate</c> or <c>ICompiledTerm.Differentiate</c> is an array of
        /// numbers whose i'th entry is the value of the i'th element of <paramref name="variables"/>. The postconditions
        /// state that the result's <c>Variables</c> list has the same length and the same elements, position by position.
        /// </remarks>
        public static ICompiledTerm Compile<T>(this Term term, T variables)
            where T : IList<Variable>
        {
            // Code Contracts requires all contract calls to precede any other statement in the method.
            Contract.Requires(variables != null);
            Contract.Requires(term != null);
            Contract.Ensures(Contract.Result<ICompiledTerm>() != null);
            Contract.Ensures(Contract.Result<ICompiledTerm>().Variables.Count == variables.Count);
            Contract.Ensures(Contract.ForAll(0, variables.Count, index => variables[index] == Contract.Result<ICompiledTerm>().Variables[index]));

            var compiled = new CompiledDifferentiator<T>(term, variables);
            return compiled;
        }

        /// <summary>
        /// Compiles <paramref name="term"/> into an <c>IParametricCompiledTerm</c>, splitting its variables into function inputs
        /// (<paramref name="variables"/>) and constant parameters (<paramref name="parameters"/>).
        /// </summary>
        /// <param name="term">The term to compile.</param>
        /// <param name="variables">The variables that act as inputs of the compiled function.</param>
        /// <param name="parameters">The variables that act as constant parameters of the compiled function.</param>
        /// <returns>A compiled form of <paramref name="term"/> that keeps the order of both <paramref name="variables"/>
        /// and <paramref name="parameters"/>.</returns>
        /// <remarks>
        /// Order matters for both arrays: the postconditions state that the result's <c>Variables</c> list matches
        /// <paramref name="variables"/> and its <c>Parameters</c> list matches <paramref name="parameters"/>, element by element.
        /// The i'th value supplied for the inputs therefore corresponds to the i'th element of <paramref name="variables"/>.
        /// </remarks>
        public static IParametricCompiledTerm Compile(this Term term, Variable[] variables, Variable[] parameters)
        {
            // Code Contracts requires all contract calls to precede any other statement in the method.
            Contract.Requires(variables != null);
            Contract.Requires(parameters != null);
            Contract.Requires(term != null);
            Contract.Ensures(Contract.Result<IParametricCompiledTerm>() != null);
            Contract.Ensures(Contract.Result<IParametricCompiledTerm>().Variables.Count == variables.Length);
            Contract.Ensures(Contract.ForAll(0, variables.Length, index => variables[index] == Contract.Result<IParametricCompiledTerm>().Variables[index]));
            Contract.Ensures(Contract.Result<IParametricCompiledTerm>().Parameters.Count == parameters.Length);
            Contract.Ensures(Contract.ForAll(0, parameters.Length, index => parameters[index] == Contract.Result<IParametricCompiledTerm>().Parameters[index]));

            var compiled = new ParametricCompiledTerm(term, variables, parameters);
            return compiled;
        }

        /// <summary>
        /// Computes the value of the function described by <paramref name="term"/> at a single point.
        /// </summary>
        /// <param name="term">The term describing the function.</param>
        /// <param name="variables">The variables appearing in <paramref name="term"/>.</param>
        /// <param name="point">The value of each variable; must have the same length as <paramref name="variables"/>.</param>
        /// <returns>The function's value when the i'th element of <paramref name="variables"/> takes the i'th value of <paramref name="point"/>.</returns>
        /// <remarks>
        /// The term is compiled on every call. When evaluating repeatedly, compile once with <c>Compile</c> and reuse the result.
        /// </remarks>
        public static double Evaluate(this Term term, Variable[] variables, double[] point)
        {
            Contract.Requires(term != null);
            Contract.Requires(variables != null);
            Contract.Requires(point != null);
            Contract.Requires(variables.Length == point.Length);

            ICompiledTerm compiled = term.Compile(variables);
            return compiled.Evaluate(point);
        }

        /// <summary>
        /// Computes the gradient of the function described by <paramref name="term"/> at a single point.
        /// </summary>
        /// <param name="term">The term describing the function.</param>
        /// <param name="variables">The variables appearing in <paramref name="term"/>.</param>
        /// <param name="point">The value of each variable; must have the same length as <paramref name="variables"/>.</param>
        /// <returns>
        /// An array with one entry per element of <paramref name="variables"/>: the i'th entry is the partial derivative
        /// with respect to the i'th variable, evaluated where the i'th variable takes the i'th value of <paramref name="point"/>.
        /// </returns>
        /// <remarks>
        /// The term is compiled on every call. When differentiating repeatedly, compile once with <c>Compile</c> and reuse the result.
        /// </remarks>
        public static double[] Differentiate(this Term term, Variable[] variables, double[] point)
        {
            Contract.Requires(term != null);
            Contract.Requires(variables != null);
            Contract.Requires(point != null);
            Contract.Requires(variables.Length == point.Length);
            Contract.Ensures(Contract.Result<double[]>() != null);
            Contract.Ensures(Contract.Result<double[]>().Length == variables.Length);

            ICompiledTerm compiled = term.Compile(variables);
            var differentiation = compiled.Differentiate(point);

            // Only the first component (the gradient array) is returned to the caller.
            return differentiation.Item1;
        }
    }
}