1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
5 |
|
---|
6 | namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
|
---|
/// <summary>
/// Evaluates a symbolic expression tree in batches of BATCHSIZE rows and, at the same time,
/// collects the partial derivatives of the output with respect to a set of parameter nodes.
/// Each batch element is a <c>MultivariateDual&lt;AlgebraicDouble&gt;</c> carrying the value and
/// a sparse gradient keyed by parameter index.
/// </summary>
public sealed class VectorAutoDiffEvaluator : InterpreterBase<VectorOfAlgebraic<MultivariateDual<AlgebraicDouble>>> {
  private const int BATCHSIZE = 128;

  // NOTE: [ThreadStatic] is only honored on static fields; on instance fields the runtime
  // ignores it and all threads share the same state. The fields are therefore static so
  // that every thread really gets its own evaluation state.

  // per-thread cache: variable name -> column values of the cached dataset
  [ThreadStatic]
  private static Dictionary<string, double[]> cachedData;

  // dataset for which cachedData was built
  [ThreadStatic]
  private static IDataset dataset;

  // index into rows of the first row of the batch that is currently evaluated
  [ThreadStatic]
  private static int rowIndex;

  // rows of the dataset that are evaluated by the current call of Evaluate
  [ThreadStatic]
  private static int[] rows;

  // maps a parameter node to its column index in the Jacobian
  [ThreadStatic]
  private static Dictionary<ISymbolicExpressionTreeNode, int> node2paramIdx;

  /// <summary>
  /// Rebuilds the per-thread data cache with a copy of all double variables of the dataset.
  /// </summary>
  private static void InitCache(IDataset dataset) {
    VectorAutoDiffEvaluator.dataset = dataset;
    cachedData = new Dictionary<string, double[]>();
    foreach (var v in dataset.DoubleVariables) {
      cachedData[v] = dataset.GetDoubleValues(v).ToArray();
    }
  }

  /// <summary>
  /// Evaluates the tree for the given rows and writes the function values to
  /// <paramref name="fi"/> and the partial derivatives with respect to the
  /// <paramref name="parameterNodes"/> to <paramref name="jac"/>.
  /// </summary>
  /// <param name="tree">The tree to evaluate.</param>
  /// <param name="dataset">The dataset providing the variable values.</param>
  /// <param name="rows">Indices of the dataset rows to evaluate.</param>
  /// <param name="parameterNodes">Nodes for which derivatives are calculated; the position in this array is the column in <paramref name="jac"/>.</param>
  /// <param name="fi">Function output. Must be allocated by the caller (at least rows.Length elements).</param>
  /// <param name="jac">Jacobian matrix. Must be allocated by the caller (at least rows.Length x parameterNodes.Length).</param>
  public void Evaluate(ISymbolicExpressionTree tree, IDataset dataset, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes, double[] fi, double[,] jac) {
    if (cachedData == null || VectorAutoDiffEvaluator.dataset != dataset) {
      InitCache(dataset);
    }

    int nParams = parameterNodes.Length;
    node2paramIdx = new Dictionary<ISymbolicExpressionTreeNode, int>(nParams);
    for (int i = 0; i < nParams; i++) node2paramIdx.Add(parameterNodes[i], i);

    // node2paramIdx and cachedData must be set before compiling because the
    // Initialize*Instruction overrides below read them
    var code = Compile(tree);

    var remainingRows = rows.Length % BATCHSIZE;
    var roundedTotal = rows.Length - remainingRows;

    VectorAutoDiffEvaluator.rows = rows;

    // full batches
    for (rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
      Evaluate(code);
      CopyBatch(code[0].value, rowIndex, BATCHSIZE, nParams, fi, jac);
    }

    // last (partial) batch; rowIndex == roundedTotal after the loop above
    if (remainingRows > 0) {
      Evaluate(code);
      CopyBatch(code[0].value, roundedTotal, remainingRows, nParams, fi, jac);
    }
  }

  /// <summary>
  /// Copies the first <paramref name="count"/> elements of an evaluated batch to the output
  /// vector and the Jacobian, starting at row <paramref name="offset"/>.
  /// Parameters that are missing in the sparse gradient get a derivative of zero.
  /// </summary>
  private static void CopyBatch(VectorOfAlgebraic<MultivariateDual<AlgebraicDouble>> batch, int offset, int count, int nParams, double[] fi, double[,] jac) {
    for (int k = 0; k < count; k++) {
      fi[offset + k] = batch[k].Value.Value;

      // copy gradient to Jacobian
      var g = batch[k].Gradient;
      for (int j = 0; j < nParams; ++j) {
        jac[offset + k, j] = g.Elements.TryGetValue(j, out AlgebraicDouble gj) ? gj.Value : 0.0;
      }
    }
  }

  protected override void InitializeInternalInstruction(ref Instruction instruction, ISymbolicExpressionTreeNode node) {
    instruction.value = new VectorOfAlgebraic<MultivariateDual<AlgebraicDouble>>(BATCHSIZE).Zero; // XXX zero needed?
  }

  protected override void InitializeTerminalInstruction(ref Instruction instruction, ConstantTreeNode constant) {
    var isParameter = node2paramIdx.TryGetValue(constant, out var paramIdx);

    instruction.value = new VectorOfAlgebraic<MultivariateDual<AlgebraicDouble>>(BATCHSIZE);
    for (int k = 0; k < BATCHSIZE; k++) {
      if (isParameter) {
        instruction.value[k] = new MultivariateDual<AlgebraicDouble>(constant.Value, paramIdx, 1.0); // gradient is 1.0 for all elements
      } else {
        instruction.value[k] = new MultivariateDual<AlgebraicDouble>(constant.Value); // zero gradient
      }
    }

    instruction.dblVal = constant.Value; // also store the parameter value in the instruction (not absolutely necessary, will not be used)
  }

  protected override void InitializeTerminalInstruction(ref Instruction instruction, VariableTreeNode variable) {
    if (!cachedData.TryGetValue(variable.VariableName, out var data)) {
      // variable is not in the cache yet: load it and cache the loaded values
      // (instruction.data is only assigned at the end of this method and must not be cached here)
      data = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
      cachedData[variable.VariableName] = data;
    }

    // -1 marks a variable whose weight is not a parameter
    if (!node2paramIdx.TryGetValue(variable, out var paramIdx)) paramIdx = -1;

    instruction.value = new VectorOfAlgebraic<MultivariateDual<AlgebraicDouble>>(BATCHSIZE);
    for (int k = 0; k < BATCHSIZE; k++) {
      // values (and gradient elements) are set in LoadVariable()
      if (paramIdx >= 0) {
        instruction.value[k] = new MultivariateDual<AlgebraicDouble>(0.0, paramIdx, 0.0);
      } else {
        instruction.value[k] = new MultivariateDual<AlgebraicDouble>(0.0);
      }
    }

    instruction.dblVal = variable.Weight;
    instruction.data = new object[] { data, paramIdx };
  }

  protected override void LoadVariable(Instruction a) {
    var payload = (object[])a.data;
    var data = (double[])payload[0];
    var paramIdx = (int)payload[1];

    // the last batch may contain fewer than BATCHSIZE rows
    var end = Math.Min(rows.Length, rowIndex + BATCHSIZE);
    for (int i = rowIndex; i < end; i++) {
      var x = data[rows[i]];
      var element = a.value[i - rowIndex];
      element.Value.Assign(a.dblVal * x);

      if (paramIdx >= 0) {
        // update gradient with variable values: d(w * x) / dw = x
        element.Gradient.Elements[paramIdx].Assign(x);
      }
    }
  }
}
|
---|
160 | } |
---|