Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2915-AbsoluteSymbol/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeBatchInterpreter.cs @ 16347

Last change on this file since 16347 was 16347, checked in by gkronber, 5 years ago

#2915: fixed bug in BatchInterpreter for Root symbol

File size: 8.5 KB
RevLine 
[16285]1using System;
2using System.Collections.Generic;
3using System.Linq;
4
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Parameters;
10using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
11
12using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperations;
13
14namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
15  [Item("SymbolicDataAnalysisExpressionTreeBatchInterpreter", "An interpreter that uses batching and vectorization techniques to achieve faster performance.")]
16  [StorableClass]
17  public class SymbolicDataAnalysisExpressionTreeBatchInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
18    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
19
20    #region parameters
21    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
22      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
23    }
24    #endregion
25
26    #region properties
27    public int EvaluatedSolutions {
28      get { return EvaluatedSolutionsParameter.Value.Value; }
29      set { EvaluatedSolutionsParameter.Value.Value = value; }
30    }
31    #endregion
32
33    public void ClearState() { }
34
35    public SymbolicDataAnalysisExpressionTreeBatchInterpreter() {
36      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
37    }
38
39
40    [StorableConstructor]
41    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(bool deserializing) : base(deserializing) { }
42    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(SymbolicDataAnalysisExpressionTreeBatchInterpreter original, Cloner cloner) : base(original, cloner) {
43    }
44    public override IDeepCloneable Clone(Cloner cloner) {
45      return new SymbolicDataAnalysisExpressionTreeBatchInterpreter(this, cloner);
46    }
47
48    private void LoadData(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
49      for (int i = 0; i < batchSize; ++i) {
50        var row = rows[rowIndex] + i;
51        instr.buf[i] = instr.weight * instr.data[row];
52      }
53    }
54
55    private void Evaluate(BatchInstruction[] code, int[] rows, int rowIndex, int batchSize) {
56      for (int i = code.Length - 1; i >= 0; --i) {
57        var instr = code[i];
58        var c = instr.childIndex;
59        var n = instr.narg;
60
61        switch (instr.opcode) {
62          case OpCodes.Variable: {
63              LoadData(instr, rows, rowIndex, batchSize);
64              break;
65            }
[16293]66
[16285]67          case OpCodes.Add: {
68              Load(instr.buf, code[c].buf);
69              for (int j = 1; j < n; ++j) {
70                Add(instr.buf, code[c + j].buf);
71              }
72              break;
73            }
74
75          case OpCodes.Sub: {
76              if (n == 1) {
77                Neg(instr.buf, code[c].buf);
78              } else {
79                Load(instr.buf, code[c].buf);
80                for (int j = 1; j < n; ++j) {
81                  Sub(instr.buf, code[c + j].buf);
82                }
83              }
[16293]84              break;
[16285]85            }
86
87          case OpCodes.Mul: {
88              Load(instr.buf, code[c].buf);
89              for (int j = 1; j < n; ++j) {
90                Mul(instr.buf, code[c + j].buf);
91              }
92              break;
93            }
94
95          case OpCodes.Div: {
96              if (n == 1) {
97                Inv(instr.buf, code[c].buf);
98              } else {
99                Load(instr.buf, code[c].buf);
100                for (int j = 1; j < n; ++j) {
101                  Div(instr.buf, code[c + j].buf);
102                }
103              }
[16293]104              break;
[16285]105            }
106
[16293]107          case OpCodes.Square: {
108              Square(instr.buf, code[c].buf);
109              break;
110            }
111
112          case OpCodes.Root: {
[16347]113              Load(instr.buf, code[c].buf);
[16293]114              Root(instr.buf, code[c].buf);
115              break;
116            }
117
118          case OpCodes.SquareRoot: {
119              Sqrt(instr.buf, code[c].buf);
120              break;
121            }
122
[16345]123          case OpCodes.Cube: {
124              Cube(instr.buf, code[c].buf);
125              break;
126            }
127          case OpCodes.CubeRoot: {
128              CubeRoot(instr.buf, code[c].buf);
129              break;
130            }
131
[16293]132          case OpCodes.Power: {
[16345]133              Load(instr.buf, code[c].buf);
134              Pow(instr.buf, code[c + 1].buf);
[16293]135              break;
136            }
137
[16285]138          case OpCodes.Exp: {
139              Exp(instr.buf, code[c].buf);
140              break;
141            }
142
143          case OpCodes.Log: {
144              Log(instr.buf, code[c].buf);
145              break;
146            }
[16293]147
148          case OpCodes.Sin: {
149              Sin(instr.buf, code[c].buf);
150              break;
151            }
152
153          case OpCodes.Cos: {
154              Cos(instr.buf, code[c].buf);
155              break;
156            }
157
158          case OpCodes.Tan: {
159              Tan(instr.buf, code[c].buf);
160              break;
161            }
[16346]162
[16345]163          case OpCodes.Absolute: {
164              Absolute(instr.buf, code[c].buf);
165              break;
166            }
[16346]167
[16345]168          case OpCodes.AnalyticalQuotient: {
169              Load(instr.buf, code[c].buf);
[16346]170              AnalyticQuotient(instr.buf, code[c + 1].buf);
[16345]171              break;
172            }
[16285]173        }
174      }
175    }
176
[16296]177    [ThreadStatic]
178    private Dictionary<string, double[]> cachedData;
179
180    private void InitCache(IDataset dataset) {
181      cachedData = new Dictionary<string, double[]>();
182      foreach (var v in dataset.DoubleVariables) {
[16301]183        cachedData[v] = dataset.GetDoubleValues(v).ToArray();
[16296]184      }
185    }
186
187    public void InitializeState() {
188      cachedData = null;
189      EvaluatedSolutions = 0;
190    }
191
[16285]192    private double[] GetValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
193      var code = Compile(tree, dataset, OpCodes.MapSymbolToOpCode);
194      var remainingRows = rows.Length % BATCHSIZE;
195      var roundedTotal = rows.Length - remainingRows;
196
197      var result = new double[rows.Length];
198
199      for (int rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
200        Evaluate(code, rows, rowIndex, BATCHSIZE);
201        Array.Copy(code[0].buf, 0, result, rowIndex, BATCHSIZE);
202      }
203
204      if (remainingRows > 0) {
205        Evaluate(code, rows, roundedTotal, remainingRows);
206        Array.Copy(code[0].buf, 0, result, roundedTotal, remainingRows);
207      }
208
209      return result;
210    }
211
[16293]212    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
[16296]213      if (cachedData == null) {
214        InitCache(dataset);
215      }
[16293]216      return GetValues(tree, dataset, rows);
217    }
218
[16285]219    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
[16296]220      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
[16285]221    }
222
223    private BatchInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
224      var root = tree.Root.GetSubtree(0).GetSubtree(0);
225      var code = new BatchInstruction[root.GetLength()];
226      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
227      int c = 1, i = 0;
228      foreach (var node in root.IterateNodesBreadth()) {
[16296]229        if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
230        code[i] = new BatchInstruction {
231          opcode = opCodeMapper(node),
232          narg = (ushort)node.SubtreeCount,
233          buf = new double[BATCHSIZE],
234          childIndex = c
235        };
[16285]236        if (node is VariableTreeNode variable) {
237          code[i].weight = variable.Weight;
[16296]238          if (cachedData.ContainsKey(variable.VariableName)) {
239            code[i].data = cachedData[variable.VariableName];
240          } else {
241            code[i].data = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
242            cachedData[variable.VariableName] = code[i].data;
243          }
[16285]244        } else if (node is ConstantTreeNode constant) {
245          code[i].value = constant.Value;
[16287]246          for (int j = 0; j < BATCHSIZE; ++j)
247            code[i].buf[j] = code[i].value;
[16285]248        }
249        c += node.SubtreeCount;
250        ++i;
251      }
252      return code;
253    }
254  }
255}
Note: See TracBrowser for help on using the repository browser.