Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeBatchInterpreter.cs @ 16296

Last change on this file since 16296 was 16296, checked in by bburlacu, 5 years ago

#2958: SymbolicDataAnalysisExpressionTreeBatchInterpreter: simplify Compile, add cache for variable values (helps a lot with performance).

File size: 8.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Parameters;
10using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
11
12using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperations;
13//using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperationsVector;
14
15namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
16  [Item("SymbolicDataAnalysisExpressionTreeBatchInterpreter", "An interpreter that uses batching and vectorization techniques to achieve faster performance.")]
17  [StorableClass]
18  public class SymbolicDataAnalysisExpressionTreeBatchInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
19    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
20
21    #region parameters
22    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
23      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
24    }
25    #endregion
26
27    #region properties
28    public int EvaluatedSolutions {
29      get { return EvaluatedSolutionsParameter.Value.Value; }
30      set { EvaluatedSolutionsParameter.Value.Value = value; }
31    }
32    #endregion
33
34    public void ClearState() { }
35
36    public SymbolicDataAnalysisExpressionTreeBatchInterpreter() {
37      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
38    }
39
40
41    [StorableConstructor]
42    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(bool deserializing) : base(deserializing) { }
43    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(SymbolicDataAnalysisExpressionTreeBatchInterpreter original, Cloner cloner) : base(original, cloner) {
44    }
45    public override IDeepCloneable Clone(Cloner cloner) {
46      return new SymbolicDataAnalysisExpressionTreeBatchInterpreter(this, cloner);
47    }
48
49    private void LoadData(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
50      for (int i = 0; i < batchSize; ++i) {
51        var row = rows[rowIndex] + i;
52        instr.buf[i] = instr.weight * instr.data[row];
53      }
54    }
55
56    private void Evaluate(BatchInstruction[] code, int[] rows, int rowIndex, int batchSize) {
57      for (int i = code.Length - 1; i >= 0; --i) {
58        var instr = code[i];
59        var c = instr.childIndex;
60        var n = instr.narg;
61
62        switch (instr.opcode) {
63          case OpCodes.Variable: {
64              LoadData(instr, rows, rowIndex, batchSize);
65              break;
66            }
67
68          case OpCodes.Add: {
69              Load(instr.buf, code[c].buf);
70              for (int j = 1; j < n; ++j) {
71                Add(instr.buf, code[c + j].buf);
72              }
73              break;
74            }
75
76          case OpCodes.Sub: {
77              if (n == 1) {
78                Neg(instr.buf, code[c].buf);
79              } else {
80                Load(instr.buf, code[c].buf);
81                for (int j = 1; j < n; ++j) {
82                  Sub(instr.buf, code[c + j].buf);
83                }
84              }
85              break;
86            }
87
88          case OpCodes.Mul: {
89              Load(instr.buf, code[c].buf);
90              for (int j = 1; j < n; ++j) {
91                Mul(instr.buf, code[c + j].buf);
92              }
93              break;
94            }
95
96          case OpCodes.Div: {
97              if (n == 1) {
98                Inv(instr.buf, code[c].buf);
99              } else {
100                Load(instr.buf, code[c].buf);
101                for (int j = 1; j < n; ++j) {
102                  Div(instr.buf, code[c + j].buf);
103                }
104              }
105              break;
106            }
107
108          case OpCodes.Square: {
109              Square(instr.buf, code[c].buf);
110              break;
111            }
112
113          case OpCodes.Root: {
114              Root(instr.buf, code[c].buf);
115              break;
116            }
117
118          case OpCodes.SquareRoot: {
119              Sqrt(instr.buf, code[c].buf);
120              break;
121            }
122
123          case OpCodes.Power: {
124              Pow(instr.buf, code[c].buf);
125              break;
126            }
127
128          case OpCodes.Exp: {
129              Exp(instr.buf, code[c].buf);
130              break;
131            }
132
133          case OpCodes.Log: {
134              Log(instr.buf, code[c].buf);
135              break;
136            }
137
138          case OpCodes.Sin: {
139              Sin(instr.buf, code[c].buf);
140              break;
141            }
142
143          case OpCodes.Cos: {
144              Cos(instr.buf, code[c].buf);
145              break;
146            }
147
148          case OpCodes.Tan: {
149              Tan(instr.buf, code[c].buf);
150              break;
151            }
152        }
153      }
154    }
155
156    [ThreadStatic]
157    private Dictionary<string, double[]> cachedData;
158
159    private void InitCache(IDataset dataset) {
160      cachedData = new Dictionary<string, double[]>();
161      foreach (var v in dataset.DoubleVariables) {
162        var values = dataset.GetDoubleValues(v).ToArray();
163      }
164    }
165
166    public void InitializeState() {
167      cachedData = null;
168      EvaluatedSolutions = 0;
169    }
170
171    private double[] GetValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
172      var code = Compile(tree, dataset, OpCodes.MapSymbolToOpCode);
173      var remainingRows = rows.Length % BATCHSIZE;
174      var roundedTotal = rows.Length - remainingRows;
175
176      var result = new double[rows.Length];
177
178      for (int rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
179        Evaluate(code, rows, rowIndex, BATCHSIZE);
180        Array.Copy(code[0].buf, 0, result, rowIndex, BATCHSIZE);
181      }
182
183      if (remainingRows > 0) {
184        Evaluate(code, rows, roundedTotal, remainingRows);
185        Array.Copy(code[0].buf, 0, result, roundedTotal, remainingRows);
186      }
187
188      return result;
189    }
190
191    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
192      if (cachedData == null) {
193        InitCache(dataset);
194      }
195      return GetValues(tree, dataset, rows);
196    }
197
198    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
199      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
200    }
201
202    private BatchInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
203      var root = tree.Root.GetSubtree(0).GetSubtree(0);
204      var code = new BatchInstruction[root.GetLength()];
205      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
206      code[0] = new BatchInstruction { narg = (ushort)root.SubtreeCount, opcode = opCodeMapper(root) };
207      int c = 1, i = 0;
208      foreach (var node in root.IterateNodesBreadth()) {
209        if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
210        code[i] = new BatchInstruction {
211          opcode = opCodeMapper(node),
212          narg = (ushort)node.SubtreeCount,
213          buf = new double[BATCHSIZE],
214          childIndex = c
215        };
216        if (node is VariableTreeNode variable) {
217          code[i].weight = variable.Weight;
218          if (cachedData.ContainsKey(variable.VariableName)) {
219            code[i].data = cachedData[variable.VariableName];
220          } else {
221            code[i].data = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
222            cachedData[variable.VariableName] = code[i].data;
223          }
224        } else if (node is ConstantTreeNode constant) {
225          code[i].value = constant.Value;
226          for (int j = 0; j < BATCHSIZE; ++j)
227            code[i].buf[j] = code[i].value;
228        }
229        c += node.SubtreeCount;
230        ++i;
231      }
232      return code;
233    }
234  }
235}
Note: See TracBrowser for help on using the repository browser.