source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/VectorEvaluator.cs @ 17295

Last change on this file since 17295 was 17295, checked in by gkronber, 3 years ago

#2994: refactoring: moved types into separate files

File size: 2.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
5
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Evaluates a symbolic expression tree over a set of dataset rows, processing the rows in
  /// batches of <see cref="BatchSize"/> and using <see cref="AlgebraicDoubleVector"/> as the
  /// value type of the underlying <see cref="Interpreter{T}"/>.
  /// </summary>
  /// <remarks>
  /// Instances are not thread-safe. The current batch offset, the row indices and the column
  /// cache are stored in ordinary instance fields, so use one instance per thread.
  /// Do not mark these fields [ThreadStatic]: that attribute only has an effect on static
  /// fields and is silently ignored on instance fields.
  /// The column cache is rebuilt only when a different dataset instance (by reference) is
  /// passed to <see cref="Evaluate(ISymbolicExpressionTree, IDataset, int[])"/>.
  /// </remarks>
  public sealed class VectorEvaluator : Interpreter<AlgebraicDoubleVector> {
    // Number of rows handled per pass of the base interpreter. Every instruction value is a
    // vector of this length.
    private const int BatchSize = 128;

    // Column cache: variable name -> array copy of that column of 'dataset'.
    private Dictionary<string, double[]> cachedData;

    // Dataset that 'cachedData' was built for. Compared by reference in Evaluate.
    private IDataset dataset;

    // Offset into 'rows' of the first row of the batch currently being evaluated.
    private int rowIndex;

    // Row indices being evaluated, as passed to the public Evaluate method.
    private int[] rows;

    /// <summary>
    /// Resets the column cache for <paramref name="dataset"/> and eagerly copies all of its
    /// double variables into it.
    /// </summary>
    /// <param name="dataset">The dataset whose double columns are cached.</param>
    private void InitCache(IDataset dataset) {
      this.dataset = dataset;
      cachedData = new Dictionary<string, double[]>();
      foreach (var variableName in dataset.DoubleVariables) {
        cachedData[variableName] = dataset.GetReadOnlyDoubleValues(variableName).ToArray();
      }
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> for every row index in <paramref name="rows"/>.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">The dataset providing the variable values.</param>
    /// <param name="rows">Row indices into <paramref name="dataset"/>. Entry i of the result belongs to rows[i].</param>
    /// <returns>
    /// An array with the same length as <paramref name="rows"/>, holding the values of the first
    /// compiled instruction (code[0]) for each row.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// If <paramref name="tree"/>, <paramref name="dataset"/> or <paramref name="rows"/> is null.
    /// </exception>
    public double[] Evaluate(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      if (tree == null) throw new ArgumentNullException(nameof(tree));
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));
      if (rows == null) throw new ArgumentNullException(nameof(rows));

      // (Re)build the column cache when it is missing or a different dataset instance is passed in.
      if (cachedData == null || this.dataset != dataset) {
        InitCache(dataset);
      }

      this.rows = rows;
      // Compile after the cache is set up. The variable-instruction initializer below reads
      // cachedData (assumed to be invoked from Compile).
      var code = Compile(tree);
      var remainingRows = rows.Length % BatchSize;
      var roundedTotal = rows.Length - remainingRows;

      var result = new double[rows.Length];

      // Full batches.
      for (rowIndex = 0; rowIndex < roundedTotal; rowIndex += BatchSize) {
        Evaluate(code);
        code[0].value.CopyTo(result, rowIndex, BatchSize);
      }

      if (remainingRows > 0) {
        // Last partial batch. LoadVariable fills only the first remainingRows entries of each
        // vector; the remaining entries hold stale values and are not copied.
        rowIndex = roundedTotal;
        Evaluate(code);
        code[0].value.CopyTo(result, roundedTotal, remainingRows);
      }

      return result;
    }

    /// <summary>
    /// Constant node: allocates a batch vector and fills every entry with the constant's value once.
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, ConstantTreeNode constant) {
      instruction.dblVal = constant.Value;
      instruction.value = new AlgebraicDoubleVector(BatchSize);
      instruction.value.AssignConstant(instruction.dblVal);
    }

    /// <summary>
    /// Variable node: allocates a batch vector, stores the variable weight in dblVal (applied in
    /// <see cref="LoadVariable"/>) and attaches the variable's column data, taken from the cache
    /// when available.
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, VariableTreeNode variable) {
      instruction.dblVal = variable.Weight;
      instruction.value = new AlgebraicDoubleVector(BatchSize);
      double[] values;
      if (!cachedData.TryGetValue(variable.VariableName, out values)) {
        // Not among the eagerly cached columns; load it on demand and remember it.
        values = dataset.GetDoubleValues(variable.VariableName).ToArray();
        cachedData[variable.VariableName] = values;
      }
      instruction.data = values;
    }

    /// <summary>
    /// Internal (non-terminal) node: only allocates the batch vector that holds its result.
    /// </summary>
    protected override void InitializeInternalInstruction(ref Instruction instruction, ISymbolicExpressionTreeNode node) {
      instruction.value = new AlgebraicDoubleVector(BatchSize);
    }

    /// <summary>
    /// Copies the variable's values for the current batch into a.value and scales them by the
    /// variable weight (a.dblVal).
    /// </summary>
    /// <remarks>
    /// The batch covers rows[rowIndex], rows[rowIndex + 1], ... up to BatchSize entries or the end
    /// of rows. In a partial batch, the trailing vector entries are left unchanged.
    /// </remarks>
    protected override void LoadVariable(Instruction a) {
      var data = (double[])a.data;
      for (int i = rowIndex; i < rows.Length && i - rowIndex < BatchSize; i++) {
        a.value[i - rowIndex] = data[rows[i]];
      }
      a.value.Scale(a.dblVal);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.