1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
5 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Evaluates a symbolic expression tree for a set of dataset rows in batches of
  /// <c>BATCHSIZE</c> rows, using <see cref="AlgebraicDoubleVector"/> as the value type
  /// of every instruction.
  /// </summary>
  /// <remarks>
  /// The evaluator keeps the current dataset, its cached columns, the requested rows and the
  /// current batch offset in instance fields. Concurrent calls to
  /// <see cref="Evaluate(ISymbolicExpressionTree, IDataset, int[])"/> on the same instance
  /// would therefore overwrite each other's state; use one instance per thread.
  /// </remarks>
  public sealed class VectorEvaluator : Interpreter<AlgebraicDoubleVector> {
    // Number of rows processed per interpreter pass; every instruction's value vector has this length.
    private const int BATCHSIZE = 128;

    // NOTE: these fields were previously annotated with [ThreadStatic]. That attribute only
    // has an effect on static fields, so on instance fields it was silently ignored and merely
    // suggested a thread isolation that never existed. It has been removed; the runtime
    // behavior (state is per instance) is unchanged.

    // Column data of the current dataset, keyed by variable name.
    private Dictionary<string, double[]> cachedData;

    // Dataset the cache was built for; compared by reference to detect a dataset change.
    private IDataset dataset;

    // Offset into rows of the first row of the batch currently being evaluated.
    private int rowIndex;

    // Row indices (into the dataset columns) requested by the current Evaluate call.
    private int[] rows;

    /// <summary>
    /// Rebuilds the column cache for <paramref name="dataset"/> by copying every double
    /// variable of the dataset into an array.
    /// </summary>
    /// <param name="dataset">Dataset whose double variables are cached.</param>
    private void InitCache(IDataset dataset) {
      // Fill a local dictionary first so the fields are only replaced once the cache is complete.
      var cache = new Dictionary<string, double[]>();
      foreach (var v in dataset.DoubleVariables) {
        cache[v] = dataset.GetReadOnlyDoubleValues(v).ToArray();
      }

      this.dataset = dataset;
      cachedData = cache;
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> for the given rows of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="tree">Expression tree to compile and evaluate.</param>
    /// <param name="dataset">Dataset providing the variable values.</param>
    /// <param name="rows">Indices of the dataset rows to evaluate, in output order.</param>
    /// <returns>
    /// An array with one value per entry of <paramref name="rows"/>, taken from the value
    /// vector of the first compiled instruction after each batch.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="dataset"/> or <paramref name="rows"/> is <c>null</c>.
    /// </exception>
    public double[] Evaluate(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      if (dataset == null) throw new ArgumentNullException("dataset");
      if (rows == null) throw new ArgumentNullException("rows");

      // The cache is only valid for the dataset instance it was built from.
      if (cachedData == null || this.dataset != dataset) {
        InitCache(dataset);
      }

      this.rows = rows;
      var code = Compile(tree);
      var remainingRows = rows.Length % BATCHSIZE;
      var roundedTotal = rows.Length - remainingRows;

      var result = new double[rows.Length];

      // Full batches: rowIndex is read by LoadVariable to select the rows of the batch.
      for (rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
        Evaluate(code);
        code[0].value.CopyTo(result, rowIndex, BATCHSIZE);
      }

      // Partial tail batch: LoadVariable only overwrites the first remainingRows entries of
      // each variable vector; the rest keeps values from the previous pass and is not copied.
      if (remainingRows > 0) {
        rowIndex = roundedTotal; // already the case after the loop; stated explicitly for clarity
        Evaluate(code);
        code[0].value.CopyTo(result, roundedTotal, remainingRows);
      }

      return result;
    }

    /// <summary>
    /// Prepares a constant instruction: stores the constant and fills the whole value vector with it.
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, ConstantTreeNode constant) {
      instruction.dblVal = constant.Value;
      instruction.value = new AlgebraicDoubleVector(BATCHSIZE);
      instruction.value.AssignConstant(instruction.dblVal);
    }

    /// <summary>
    /// Prepares a variable instruction: stores the variable weight and attaches the column
    /// data of the variable, reading it from the dataset if it is not cached yet.
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, VariableTreeNode variable) {
      instruction.dblVal = variable.Weight;
      instruction.value = new AlgebraicDoubleVector(BATCHSIZE);

      double[] values;
      if (!cachedData.TryGetValue(variable.VariableName, out values)) {
        // Variable was not among the columns cached by InitCache; load and remember it.
        values = dataset.GetDoubleValues(variable.VariableName).ToArray();
        cachedData[variable.VariableName] = values;
      }
      instruction.data = values;
    }

    /// <summary>
    /// Allocates the value vector of an internal (non-terminal) instruction.
    /// </summary>
    protected override void InitializeInternalInstruction(ref Instruction instruction, ISymbolicExpressionTreeNode node) {
      instruction.value = new AlgebraicDoubleVector(BATCHSIZE);
    }

    /// <summary>
    /// Copies the values of the instruction's variable for the current batch of rows into its
    /// value vector and scales the vector by the variable weight.
    /// </summary>
    /// <param name="a">Variable instruction whose <c>data</c> holds the column as <c>double[]</c>.</param>
    protected override void LoadVariable(Instruction a) {
      var data = (double[])a.data;

      // At most BATCHSIZE rows, fewer for the tail batch.
      var count = Math.Min(BATCHSIZE, rows.Length - rowIndex);
      for (int i = 0; i < count; i++) {
        a.value[i] = data[rows[rowIndex + i]];
      }

      a.value.Scale(a.dblVal);
    }
  }
}