Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeBatchInterpreter.cs @ 16444

Last change on this file since 16444 was 16378, checked in by bburlacu, 6 years ago

#2958: Batch and Native interpreter: keep a cached reference to the dataset so we can detect when it changes.

File size: 8.9 KB
RevLine 
[16285]1using System;
2using System.Collections.Generic;
3using System.Linq;
4
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Parameters;
10using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
11
12using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperations;
13
14namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
15  [Item("SymbolicDataAnalysisExpressionTreeBatchInterpreter", "An interpreter that uses batching and vectorization techniques to achieve faster performance.")]
16  [StorableClass]
17  public class SymbolicDataAnalysisExpressionTreeBatchInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
18    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
19
20    #region parameters
21    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
22      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
23    }
24    #endregion
25
26    #region properties
27    public int EvaluatedSolutions {
28      get { return EvaluatedSolutionsParameter.Value.Value; }
29      set { EvaluatedSolutionsParameter.Value.Value = value; }
30    }
31    #endregion
32
33    public void ClearState() { }
34
35    public SymbolicDataAnalysisExpressionTreeBatchInterpreter() {
36      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
37    }
38
39    [StorableConstructor]
40    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(bool deserializing) : base(deserializing) { }
41    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(SymbolicDataAnalysisExpressionTreeBatchInterpreter original, Cloner cloner) : base(original, cloner) {
42    }
43    public override IDeepCloneable Clone(Cloner cloner) {
44      return new SymbolicDataAnalysisExpressionTreeBatchInterpreter(this, cloner);
45    }
46
47    private void LoadData(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
48      for (int i = 0; i < batchSize; ++i) {
49        var row = rows[rowIndex] + i;
50        instr.buf[i] = instr.weight * instr.data[row];
51      }
52    }
53
54    private void Evaluate(BatchInstruction[] code, int[] rows, int rowIndex, int batchSize) {
55      for (int i = code.Length - 1; i >= 0; --i) {
56        var instr = code[i];
57        var c = instr.childIndex;
58        var n = instr.narg;
59
60        switch (instr.opcode) {
61          case OpCodes.Variable: {
62              LoadData(instr, rows, rowIndex, batchSize);
63              break;
64            }
[16293]65
[16285]66          case OpCodes.Add: {
67              Load(instr.buf, code[c].buf);
68              for (int j = 1; j < n; ++j) {
69                Add(instr.buf, code[c + j].buf);
70              }
71              break;
72            }
73
74          case OpCodes.Sub: {
75              if (n == 1) {
76                Neg(instr.buf, code[c].buf);
77              } else {
78                Load(instr.buf, code[c].buf);
79                for (int j = 1; j < n; ++j) {
80                  Sub(instr.buf, code[c + j].buf);
81                }
82              }
[16293]83              break;
[16285]84            }
85
86          case OpCodes.Mul: {
87              Load(instr.buf, code[c].buf);
88              for (int j = 1; j < n; ++j) {
89                Mul(instr.buf, code[c + j].buf);
90              }
91              break;
92            }
93
94          case OpCodes.Div: {
95              if (n == 1) {
96                Inv(instr.buf, code[c].buf);
97              } else {
98                Load(instr.buf, code[c].buf);
99                for (int j = 1; j < n; ++j) {
100                  Div(instr.buf, code[c + j].buf);
101                }
102              }
[16293]103              break;
[16285]104            }
105
[16293]106          case OpCodes.Square: {
107              Square(instr.buf, code[c].buf);
108              break;
109            }
110
111          case OpCodes.Root: {
[16356]112              Load(instr.buf, code[c].buf);
113              Root(instr.buf, code[c + 1].buf);
[16293]114              break;
115            }
116
117          case OpCodes.SquareRoot: {
118              Sqrt(instr.buf, code[c].buf);
119              break;
120            }
121
[16356]122          case OpCodes.Cube: {
123              Cube(instr.buf, code[c].buf);
124              break;
125            }
126          case OpCodes.CubeRoot: {
127              CubeRoot(instr.buf, code[c].buf);
128              break;
129            }
130
[16293]131          case OpCodes.Power: {
[16356]132              Load(instr.buf, code[c].buf);
133              Pow(instr.buf, code[c + 1].buf);
[16293]134              break;
135            }
136
[16285]137          case OpCodes.Exp: {
138              Exp(instr.buf, code[c].buf);
139              break;
140            }
141
142          case OpCodes.Log: {
143              Log(instr.buf, code[c].buf);
144              break;
145            }
[16293]146
147          case OpCodes.Sin: {
148              Sin(instr.buf, code[c].buf);
149              break;
150            }
151
152          case OpCodes.Cos: {
153              Cos(instr.buf, code[c].buf);
154              break;
155            }
156
157          case OpCodes.Tan: {
158              Tan(instr.buf, code[c].buf);
159              break;
160            }
[16356]161
162          case OpCodes.Absolute: {
163              Absolute(instr.buf, code[c].buf);
164              break;
165            }
166
[16360]167          case OpCodes.AnalyticQuotient: {
[16356]168              Load(instr.buf, code[c].buf);
169              AnalyticQuotient(instr.buf, code[c + 1].buf);
170              break;
171            }
[16285]172        }
173      }
174    }
175
[16378]176    private readonly object syncRoot = new object();
177
[16296]178    [ThreadStatic]
179    private Dictionary<string, double[]> cachedData;
180
[16378]181    [ThreadStatic]
182    private IDataset dataset;
183
[16296]184    private void InitCache(IDataset dataset) {
[16378]185      this.dataset = dataset;
[16296]186      cachedData = new Dictionary<string, double[]>();
187      foreach (var v in dataset.DoubleVariables) {
[16301]188        cachedData[v] = dataset.GetDoubleValues(v).ToArray();
[16296]189      }
190    }
191
192    public void InitializeState() {
193      cachedData = null;
[16378]194      dataset = null;
[16296]195      EvaluatedSolutions = 0;
196    }
197
[16285]198    private double[] GetValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
[16378]199      if (cachedData == null || this.dataset != dataset) {
200        InitCache(dataset);
201      }
202
[16285]203      var code = Compile(tree, dataset, OpCodes.MapSymbolToOpCode);
204      var remainingRows = rows.Length % BATCHSIZE;
205      var roundedTotal = rows.Length - remainingRows;
206
207      var result = new double[rows.Length];
208
209      for (int rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
210        Evaluate(code, rows, rowIndex, BATCHSIZE);
211        Array.Copy(code[0].buf, 0, result, rowIndex, BATCHSIZE);
212      }
213
214      if (remainingRows > 0) {
215        Evaluate(code, rows, roundedTotal, remainingRows);
216        Array.Copy(code[0].buf, 0, result, roundedTotal, remainingRows);
217      }
218
[16378]219      // when evaluation took place without any error, we can increment the counter
220      lock (syncRoot) {
221        EvaluatedSolutions++;
222      }
223
[16285]224      return result;
225    }
226
[16293]227    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
228      return GetValues(tree, dataset, rows);
229    }
230
[16285]231    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
[16296]232      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
[16285]233    }
234
235    private BatchInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
236      var root = tree.Root.GetSubtree(0).GetSubtree(0);
237      var code = new BatchInstruction[root.GetLength()];
238      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
239      int c = 1, i = 0;
240      foreach (var node in root.IterateNodesBreadth()) {
[16296]241        if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
242        code[i] = new BatchInstruction {
243          opcode = opCodeMapper(node),
244          narg = (ushort)node.SubtreeCount,
245          buf = new double[BATCHSIZE],
246          childIndex = c
247        };
[16285]248        if (node is VariableTreeNode variable) {
249          code[i].weight = variable.Weight;
[16296]250          if (cachedData.ContainsKey(variable.VariableName)) {
251            code[i].data = cachedData[variable.VariableName];
252          } else {
253            code[i].data = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
254            cachedData[variable.VariableName] = code[i].data;
255          }
[16285]256        } else if (node is ConstantTreeNode constant) {
257          code[i].value = constant.Value;
[16287]258          for (int j = 0; j < BATCHSIZE; ++j)
259            code[i].buf[j] = code[i].value;
[16285]260        }
261        c += node.SubtreeCount;
262        ++i;
263      }
264      return code;
265    }
266  }
267}
Note: See TracBrowser for help on using the repository browser.