Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2988_ModelsOfModels2/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeBatchInterpreter.cs @ 17198

Last change on this file since 17198 was 16899, checked in by msemenki, 6 years ago

#2988: New version of class structure.

File size: 9.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Parameters;
10using HEAL.Attic;
11
12using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperations;
13
14namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
15  [Item("SymbolicDataAnalysisExpressionTreeBatchInterpreter", "An interpreter that uses batching and vectorization techniques to achieve faster performance.")]
16  [StorableType("BEB15146-BB95-4838-83AC-6838543F017B")]
17  public class SymbolicDataAnalysisExpressionTreeBatchInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
18    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
19
20    #region parameters
21    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
22      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
23    }
24    #endregion
25
26    #region properties
27    public int EvaluatedSolutions {
28      get { return EvaluatedSolutionsParameter.Value.Value; }
29      set { EvaluatedSolutionsParameter.Value.Value = value; }
30    }
31    #endregion
32
33    public void ClearState() { }
34
35    public SymbolicDataAnalysisExpressionTreeBatchInterpreter() {
36      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
37    }
38
39    [StorableConstructor]
40    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(StorableConstructorFlag _) : base(_) { }
41    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(SymbolicDataAnalysisExpressionTreeBatchInterpreter original, Cloner cloner) : base(original, cloner) {
42    }
43    public override IDeepCloneable Clone(Cloner cloner) {
44      return new SymbolicDataAnalysisExpressionTreeBatchInterpreter(this, cloner);
45    }
46
47    private void LoadData(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
48      for (int i = 0; i < batchSize; ++i) {
49        var row = rows[rowIndex] + i;
50        instr.buf[i] = instr.weight * instr.data[row];
51      }
52    }
53    private void SubTreeEvoluate(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
54      for (int i = 0; i < batchSize; ++i) {
55        var row = rows[rowIndex] + i;
56        instr.buf[i] = instr.data[row];  // не забыть заполнить
57      }
58    }
59    private void Evaluate(BatchInstruction[] code, int[] rows, int rowIndex, int batchSize) {
60      for (int i = code.Length - 1; i >= 0; --i) {
61        var instr = code[i];
62        var c = instr.childIndex;
63        var n = instr.narg;
64
65        switch (instr.opcode) {
66          case OpCode.Variable: {
67              LoadData(instr, rows, rowIndex, batchSize);
68              break;
69            }
70          case OpCode.TreeModel: {
71              SubTreeEvoluate(instr, rows, rowIndex, batchSize);
72              break;
73            }
74
75          case OpCode.Add: {
76              Load(instr.buf, code[c].buf);
77              for (int j = 1; j < n; ++j) {
78                Add(instr.buf, code[c + j].buf);
79              }
80              break;
81            }
82
83          case OpCode.Sub: {
84              if (n == 1) {
85                Neg(instr.buf, code[c].buf);
86              } else {
87                Load(instr.buf, code[c].buf);
88                for (int j = 1; j < n; ++j) {
89                  Sub(instr.buf, code[c + j].buf);
90                }
91              }
92              break;
93            }
94
95          case OpCode.Mul: {
96              Load(instr.buf, code[c].buf);
97              for (int j = 1; j < n; ++j) {
98                Mul(instr.buf, code[c + j].buf);
99              }
100              break;
101            }
102
103          case OpCode.Div: {
104              if (n == 1) {
105                Inv(instr.buf, code[c].buf);
106              } else {
107                Load(instr.buf, code[c].buf);
108                for (int j = 1; j < n; ++j) {
109                  Div(instr.buf, code[c + j].buf);
110                }
111              }
112              break;
113            }
114
115          case OpCode.Square: {
116              Square(instr.buf, code[c].buf);
117              break;
118            }
119
120          case OpCode.Root: {
121              Load(instr.buf, code[c].buf);
122              Root(instr.buf, code[c + 1].buf);
123              break;
124            }
125
126          case OpCode.SquareRoot: {
127              Sqrt(instr.buf, code[c].buf);
128              break;
129            }
130
131          case OpCode.Cube: {
132              Cube(instr.buf, code[c].buf);
133              break;
134            }
135          case OpCode.CubeRoot: {
136              CubeRoot(instr.buf, code[c].buf);
137              break;
138            }
139
140          case OpCode.Power: {
141              Load(instr.buf, code[c].buf);
142              Pow(instr.buf, code[c + 1].buf);
143              break;
144            }
145
146          case OpCode.Exp: {
147              Exp(instr.buf, code[c].buf);
148              break;
149            }
150
151          case OpCode.Log: {
152              Log(instr.buf, code[c].buf);
153              break;
154            }
155
156          case OpCode.Sin: {
157              Sin(instr.buf, code[c].buf);
158              break;
159            }
160
161          case OpCode.Cos: {
162              Cos(instr.buf, code[c].buf);
163              break;
164            }
165
166          case OpCode.Tan: {
167              Tan(instr.buf, code[c].buf);
168              break;
169            }
170          case OpCode.Tanh: {
171              Tanh(instr.buf, code[c].buf);
172              break;
173            }
174          case OpCode.Absolute: {
175              Absolute(instr.buf, code[c].buf);
176              break;
177            }
178
179          case OpCode.AnalyticQuotient: {
180              Load(instr.buf, code[c].buf);
181              AnalyticQuotient(instr.buf, code[c + 1].buf);
182              break;
183            }
184        }
185      }
186    }
187
188    private readonly object syncRoot = new object();
189
190    [ThreadStatic]
191    private Dictionary<string, double[]> cachedData;
192
193    [ThreadStatic]
194    private IDataset dataset;
195
196    private void InitCache(IDataset dataset) {
197      this.dataset = dataset;
198      cachedData = new Dictionary<string, double[]>();
199      foreach (var v in dataset.DoubleVariables) {
200        cachedData[v] = dataset.GetDoubleValues(v).ToArray();
201      }
202    }
203
204    public void InitializeState() {
205      cachedData = null;
206      dataset = null;
207      EvaluatedSolutions = 0;
208    }
209
210    private double[] GetValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
211      if (cachedData == null || this.dataset != dataset) {
212        InitCache(dataset);
213      }
214
215      var code = Compile(tree, dataset, OpCode.MapSymbolToOpCode, rows);
216      var remainingRows = rows.Length % BATCHSIZE;
217      var roundedTotal = rows.Length - remainingRows;
218
219      var result = new double[rows.Length];
220
221      for (int rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
222        Evaluate(code, rows, rowIndex, BATCHSIZE);
223        Array.Copy(code[0].buf, 0, result, rowIndex, BATCHSIZE);
224      }
225
226      if (remainingRows > 0) {
227        Evaluate(code, rows, roundedTotal, remainingRows);
228        Array.Copy(code[0].buf, 0, result, roundedTotal, remainingRows);
229      }
230
231      // when evaluation took place without any error, we can increment the counter
232      lock (syncRoot) {
233        EvaluatedSolutions++;
234      }
235
236      return result;
237    }
238
239    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
240      return GetValues(tree, dataset, rows);
241    }
242
243    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
244      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
245    }
246
247    private BatchInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, int[] rows) {
248      var root = tree.Root.GetSubtree(0).GetSubtree(0);
249      var code = new BatchInstruction[root.GetLength()];
250      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
251      int c = 1, i = 0;
252      var allRows = Enumerable.Range(0, dataset.Rows).ToArray();
253      foreach (var node in root.IterateNodesBreadth()) {
254        if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
255        code[i] = new BatchInstruction {
256          opcode = opCodeMapper(node),
257          narg = (ushort)node.SubtreeCount,
258          buf = new double[BATCHSIZE],
259          childIndex = c
260        };
261        if (node is VariableTreeNode variable) {
262          code[i].weight = variable.Weight;
263          if (cachedData.ContainsKey(variable.VariableName)) {
264            code[i].data = cachedData[variable.VariableName];
265          } else {
266            code[i].data = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
267            cachedData[variable.VariableName] = code[i].data;
268          }
269        } else if (node is ConstantTreeNode constant) {
270          code[i].value = constant.Value;
271          for (int j = 0; j < BATCHSIZE; ++j)
272            code[i].buf[j] = code[i].value;
273        } else if (node is TreeModelTreeNode subtree) {
274          code[i].data = GetValues(subtree.Tree, dataset, allRows);
275          EvaluatedSolutions--;
276        }
277        c += node.SubtreeCount;
278        ++i;
279      }
280      return code;
281    }
282  }
283}
Note: See TracBrowser for help on using the repository browser.