[17295] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 5 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Evaluates a symbolic expression tree over a set of dataset rows with forward-mode automatic
  /// differentiation, processing the rows in batches of <see cref="BATCHSIZE"/>. It produces the function
  /// values and the Jacobian with respect to the values of a given set of parameter nodes
  /// (constant values and variable weights).
  /// </summary>
  /// <remarks>
  /// Per-evaluation state (cache, current rows, batch offset, parameter index map) is kept in ordinary
  /// instance fields. An instance must therefore not be used by several threads at the same time.
  /// </remarks>
  public sealed class VectorAutoDiffEvaluator : InterpreterBase<MultivariateDual<AlgebraicDoubleVector>> {
    private const int BATCHSIZE = 128;

    // Copies of the double-valued columns of `dataset`, keyed by variable name.
    private Dictionary<string, double[]> cachedData;

    // Dataset that `cachedData` was built from; passing a different instance rebuilds the cache.
    private IDataset dataset;

    // Offset into `rows` of the first row of the batch currently being evaluated.
    private int rowIndex;

    // Row indices passed to the current call of Evaluate.
    private int[] rows;

    // Maps each parameter node to its column index in the Jacobian.
    private Dictionary<ISymbolicExpressionTreeNode, int> node2paramIdx;

    /// <summary>
    /// Remembers <paramref name="dataset"/> and copies all of its double-valued columns into <see cref="cachedData"/>.
    /// </summary>
    private void InitCache(IDataset dataset) {
      this.dataset = dataset;
      cachedData = new Dictionary<string, double[]>();
      foreach (var v in dataset.DoubleVariables) {
        cachedData[v] = dataset.GetDoubleValues(v).ToArray();
      }
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> for the given rows. It also computes the partial derivatives of the output
    /// with respect to the values of <paramref name="parameterNodes"/> (the value of a constant node, the weight
    /// of a variable node).
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">Dataset that provides the variable values.</param>
    /// <param name="rows">Row indices into <paramref name="dataset"/>; output row i corresponds to rows[i].</param>
    /// <param name="parameterNodes">Nodes treated as parameters; parameterNodes[j] corresponds to column j of <paramref name="jac"/>.</param>
    /// <param name="fi">Function output. Must be allocated by the caller with at least rows.Length elements.</param>
    /// <param name="jac">Jacobian matrix. Must be allocated by the caller with at least rows.Length rows and parameterNodes.Length columns.</param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="fi"/> or <paramref name="jac"/> is too small for the requested rows/parameters.</exception>
    public void Evaluate(ISymbolicExpressionTree tree, IDataset dataset, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes, double[] fi, double[,] jac) {
      if (tree == null) throw new ArgumentNullException(nameof(tree));
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));
      if (rows == null) throw new ArgumentNullException(nameof(rows));
      if (parameterNodes == null) throw new ArgumentNullException(nameof(parameterNodes));
      if (fi == null) throw new ArgumentNullException(nameof(fi));
      if (jac == null) throw new ArgumentNullException(nameof(jac));

      int nParams = parameterNodes.Length;
      if (fi.Length < rows.Length)
        throw new ArgumentException("fi must have at least rows.Length elements.", nameof(fi));
      if (rows.Length > 0 && (jac.GetLength(0) < rows.Length || jac.GetLength(1) < nParams))
        throw new ArgumentException("jac must have at least rows.Length rows and parameterNodes.Length columns.", nameof(jac));

      if (cachedData == null || this.dataset != dataset) {
        InitCache(dataset);
      }

      node2paramIdx = new Dictionary<ISymbolicExpressionTreeNode, int>(nParams);
      for (int i = 0; i < nParams; i++) node2paramIdx.Add(parameterNodes[i], i);

      // Compile only after node2paramIdx is filled: the Initialize*Instruction overrides read it
      // to decide which terminals carry a gradient.
      var code = Compile(tree);

      this.rows = rows;

      // Full batches first; the last iteration handles the partial batch of remaining rows, if any.
      // rowIndex is a field because LoadVariable reads it while Evaluate(code) runs.
      for (rowIndex = 0; rowIndex < rows.Length; rowIndex += BATCHSIZE) {
        int batchRows = Math.Min(BATCHSIZE, rows.Length - rowIndex);
        Evaluate(code);
        CopyBatchResults(code[0].value, nParams, rowIndex, batchRows, fi, jac);
      }
    }

    /// <summary>
    /// Copies <paramref name="count"/> entries of the root result into <paramref name="fi"/> starting at
    /// <paramref name="offset"/>. It also transposes each per-parameter gradient vector into the matching column
    /// of <paramref name="jac"/>; parameters without a gradient entry get zeros.
    /// </summary>
    private static void CopyBatchResults(MultivariateDual<AlgebraicDoubleVector> result, int nParams, int offset, int count, double[] fi, double[,] jac) {
      result.Value.CopyTo(fi, offset, count);

      var gradient = result.Gradient;
      for (int j = 0; j < nParams; ++j) {
        if (gradient.Elements.TryGetValue(j, out AlgebraicDoubleVector v)) {
          v.CopyColumnTo(jac, j, offset, count);
        } else {
          // The output does not depend on parameter j in this batch.
          for (int r = 0; r < count; r++) jac[offset + r, j] = 0.0;
        }
      }
    }

    /// <summary>
    /// Allocates a zero-initialized batch-sized value buffer for an internal (function) node.
    /// </summary>
    protected override void InitializeInternalInstruction(ref Instruction instruction, ISymbolicExpressionTreeNode node) {
      var zero = new AlgebraicDoubleVector(BATCHSIZE);
      instruction.value = new MultivariateDual<AlgebraicDoubleVector>(zero);
    }

    /// <summary>
    /// Sets up a constant node. Its value is broadcast over the batch. If the node is a parameter,
    /// its gradient entry is all ones (d c / d c = 1).
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, ConstantTreeNode constant) {
      var value = new AlgebraicDoubleVector(BATCHSIZE);
      if (node2paramIdx.TryGetValue(constant, out var paramIdx)) {
        var ones = new double[BATCHSIZE];
        for (int i = 0; i < BATCHSIZE; i++) ones[i] = 1.0;
        // only a single column for the gradient
        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(value, paramIdx, new AlgebraicDoubleVector(ones));
      } else {
        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(value);
      }

      instruction.dblVal = constant.Value;
      instruction.value.Value.AssignConstant(instruction.dblVal);
    }

    /// <summary>
    /// Sets up a variable node. It resolves (and caches) the variable's column and allocates batch buffers
    /// (plus a gradient buffer when the node is a parameter). It stores { column data, parameter index or -1 }
    /// in instruction.data for <see cref="LoadVariable"/>.
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, VariableTreeNode variable) {
      if (!cachedData.TryGetValue(variable.VariableName, out double[] data)) {
        data = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
        // Cache the values just read so later nodes referencing the same variable reuse them.
        cachedData[variable.VariableName] = data;
      }

      var f = new AlgebraicDoubleVector(BATCHSIZE);
      int paramIdx;
      if (node2paramIdx.TryGetValue(variable, out paramIdx)) {
        // The gradient buffer is filled per batch in LoadVariable.
        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(f, paramIdx, new AlgebraicDoubleVector(BATCHSIZE));
      } else {
        paramIdx = -1; // marks "not a parameter" for LoadVariable
        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(f);
      }

      instruction.dblVal = variable.Weight;
      instruction.data = new object[] { data, paramIdx };
    }

    /// <summary>
    /// Loads the current batch of variable values, scaled by the variable weight. For parameter variables,
    /// it also writes the gradient w.r.t. the weight, which is the unscaled variable value.
    /// </summary>
    protected override void LoadVariable(Instruction a) {
      var state = (object[])a.data;
      var data = (double[])state[0];
      var paramIdx = (int)state[1];

      // Number of valid rows in the current batch; the final batch may be partial.
      int count = Math.Min(BATCHSIZE, rows.Length - rowIndex);

      for (int i = 0; i < count; i++) a.value.Value[i] = data[rows[rowIndex + i]];
      a.value.Scale(a.dblVal);

      if (paramIdx >= 0) {
        // d(w * x) / dw = x: overwrite the gradient with the unscaled values.
        var g = a.value.Gradient.Elements[paramIdx];
        for (int i = 0; i < count; i++) g[i] = data[rows[rowIndex + i]];
      }
    }
  }
}
---|