source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/VectorAutoDiffEvaluator.cs @ 17296

Last change on this file since 17296 was 17296, checked in by gkronber, 3 years ago

#2994 rename of abstract class Interpreter -> InterpreterBase

File size: 5.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
5
6namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
7  public sealed class VectorAutoDiffEvaluator : InterpreterBase<MultivariateDual<AlgebraicDoubleVector>> {
8    private const int BATCHSIZE = 128;
9    [ThreadStatic]
10    private Dictionary<string, double[]> cachedData;
11
12    [ThreadStatic]
13    private IDataset dataset;
14
15    [ThreadStatic]
16    private int rowIndex;
17
18    [ThreadStatic]
19    private int[] rows;
20
21    [ThreadStatic]
22    private Dictionary<ISymbolicExpressionTreeNode, int> node2paramIdx;
23
24    private void InitCache(IDataset dataset) {
25      this.dataset = dataset;
26      cachedData = new Dictionary<string, double[]>();
27      foreach (var v in dataset.DoubleVariables) {
28        cachedData[v] = dataset.GetDoubleValues(v).ToArray();
29      }
30    }
31
32    /// <summary>
33    ///
34    /// </summary>
35    /// <param name="tree"></param>
36    /// <param name="dataset"></param>
37    /// <param name="rows"></param>
38    /// <param name="parameterNodes"></param>
39    /// <param name="fi">Function output. Must be allocated by the caller.</param>
40    /// <param name="jac">Jacobian matrix. Must be allocated by the caller.</param>
41    public void Evaluate(ISymbolicExpressionTree tree, IDataset dataset, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes, double[] fi, double[,] jac) {
42      if (cachedData == null || this.dataset != dataset) {
43        InitCache(dataset);
44      }
45
46      int nParams = parameterNodes.Length;
47      node2paramIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();
48      for (int i = 0; i < parameterNodes.Length; i++) node2paramIdx.Add(parameterNodes[i], i);
49
50      var code = Compile(tree);
51
52      var remainingRows = rows.Length % BATCHSIZE;
53      var roundedTotal = rows.Length - remainingRows;
54
55      this.rows = rows;
56
57      for (rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
58        Evaluate(code);
59        code[0].value.Value.CopyTo(fi, rowIndex, BATCHSIZE);
60
61        // TRANSPOSE into JAC
62        var g = code[0].value.Gradient;
63        for (int j = 0; j < nParams; ++j) {
64          if (g.Elements.TryGetValue(j, out AlgebraicDoubleVector v)) {
65            v.CopyColumnTo(jac, j, rowIndex, BATCHSIZE);
66          } else {
67            for (int r = 0; r < BATCHSIZE; r++) jac[rowIndex + r, j] = 0.0;
68          }
69        }
70      }
71
72      if (remainingRows > 0) {
73        Evaluate(code);
74        code[0].value.Value.CopyTo(fi, roundedTotal, remainingRows);
75
76        var g = code[0].value.Gradient;
77        for (int j = 0; j < nParams; ++j)
78          if (g.Elements.TryGetValue(j, out AlgebraicDoubleVector v)) {
79            v.CopyColumnTo(jac, j, roundedTotal, remainingRows);
80          } else {
81            for (int r = 0; r < remainingRows; r++) jac[roundedTotal + r, j] = 0.0;
82          }
83      }
84    }
85
86    protected override void InitializeInternalInstruction(ref Instruction instruction, ISymbolicExpressionTreeNode node) {
87      var zero = new AlgebraicDoubleVector(BATCHSIZE);
88      instruction.value = new MultivariateDual<AlgebraicDoubleVector>(zero);
89    }
90
91    protected override void InitializeTerminalInstruction(ref Instruction instruction, ConstantTreeNode constant) {
92      var g_arr = new double[BATCHSIZE];
93      if (node2paramIdx.TryGetValue(constant, out var paramIdx)) {
94        for (int i = 0; i < BATCHSIZE; i++) g_arr[i] = 1.0;
95        var g = new AlgebraicDoubleVector(g_arr);
96        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(new AlgebraicDoubleVector(BATCHSIZE), paramIdx, g); // only a single column for the gradient
97      } else {
98        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(new AlgebraicDoubleVector(BATCHSIZE));
99      }
100
101      instruction.dblVal = constant.Value;
102      instruction.value.Value.AssignConstant(instruction.dblVal);
103    }
104
105    protected override void InitializeTerminalInstruction(ref Instruction instruction, VariableTreeNode variable) {
106      double[] data;
107      if (cachedData.ContainsKey(variable.VariableName)) {
108        data = cachedData[variable.VariableName];
109      } else {
110        data = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
111        cachedData[variable.VariableName] = (double[])instruction.data;
112      }
113
114      var paramIdx = -1;
115      if (node2paramIdx.ContainsKey(variable)) {
116        paramIdx = node2paramIdx[variable];
117        var f = new AlgebraicDoubleVector(BATCHSIZE);
118        var g = new AlgebraicDoubleVector(BATCHSIZE);
119        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(f, paramIdx, g);
120      } else {
121        var f = new AlgebraicDoubleVector(BATCHSIZE);
122        instruction.value = new MultivariateDual<AlgebraicDoubleVector>(f);
123      }
124
125      instruction.dblVal = variable.Weight;
126      instruction.data = new object[] { data, paramIdx };
127    }
128
129    protected override void LoadVariable(Instruction a) {
130      var paramIdx = (int)((object[])a.data)[1];
131      var data = (double[])((object[])a.data)[0];
132
133      for (int i = rowIndex; i < rows.Length && i - rowIndex < BATCHSIZE; i++) a.value.Value[i - rowIndex] = data[rows[i]];
134      a.value.Scale(a.dblVal);
135
136      if (paramIdx >= 0) {
137        // update gradient with variable values
138        var g = a.value.Gradient.Elements[paramIdx];
139        for (int i = rowIndex; i < rows.Length && i - rowIndex < BATCHSIZE; i++) {
140          g[i - rowIndex] = data[rows[i]];
141        }
142      }
143    }
144  }
145}
Note: See TracBrowser for help on using the repository browser.