Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeBatchInterpreter.cs @ 18220

Last change on this file since 18220 was 18220, checked in by gkronber, 3 years ago

#3136: reintegrated structure-template GP branch into trunk

File size: 10.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HEAL.Attic;
32
33using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperations;
34
35namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Interpreter for symbolic data analysis expression trees that evaluates a tree for up to
/// BATCHSIZE rows at a time, using array-wise operations (Load, Add, Mul, ...) that are not
/// defined in this class (presumably provided by the statically imported BatchOperations).
/// </summary>
/// <remarks>
/// The variable data cache is stored in [ThreadStatic] fields, i.e. each thread has its own cache.
/// The EvaluatedSolutions counter is incremented while holding a private lock.
/// </remarks>
[Item("SymbolicDataAnalysisExpressionTreeBatchInterpreter", "An interpreter that uses batching and vectorization techniques to achieve faster performance.")]
[StorableType("BEB15146-BB95-4838-83AC-6838543F017B")]
public class SymbolicDataAnalysisExpressionTreeBatchInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
  private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

  #region parameters
  /// <summary>Parameter holding the counter of evaluated solutions.</summary>
  public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
    get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
  }
  #endregion

  #region properties
  /// <summary>Number of trees evaluated by <see cref="GetValues"/> since the counter was last reset.</summary>
  public int EvaluatedSolutions {
    get { return EvaluatedSolutionsParameter.Value.Value; }
    set { EvaluatedSolutionsParameter.Value.Value = value; }
  }
  #endregion

  /// <summary>Intentionally empty; this interpreter has nothing to clear here.</summary>
  public void ClearState() { }

  /// <summary>Creates a new interpreter with the evaluated-solutions counter set to zero.</summary>
  public SymbolicDataAnalysisExpressionTreeBatchInterpreter() {
    Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
  }

  [StorableConstructor]
  protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(StorableConstructorFlag _) : base(_) { }

  /// <summary>Cloning constructor; all cloned state is handled by the base class.</summary>
  protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(SymbolicDataAnalysisExpressionTreeBatchInterpreter original, Cloner cloner) : base(original, cloner) {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new SymbolicDataAnalysisExpressionTreeBatchInterpreter(this, cloner);
  }

  /// <summary>
  /// Fills the buffer of a variable instruction with the weighted variable values of the rows
  /// <c>rows[rowIndex] .. rows[rowIndex + batchSize - 1]</c>.
  /// </summary>
  /// <param name="instr">Variable instruction; its <c>data</c>, <c>weight</c> and <c>buf</c> are used.</param>
  /// <param name="rows">Dataset row indices to evaluate.</param>
  /// <param name="rowIndex">Position in <paramref name="rows"/> at which the current batch starts.</param>
  /// <param name="batchSize">Number of rows in the current batch.</param>
  private void LoadData(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      // Look up every row index individually. Using rows[rowIndex] + i would only be correct
      // when the requested rows are consecutive and could read past the end of the data otherwise.
      var row = rows[rowIndex + i];
      instr.buf[i] = instr.weight * instr.data[row];
    }
  }

  /// <summary>
  /// Evaluates the compiled instructions for one batch of rows. Afterwards the result of the
  /// whole expression is in <c>code[0].buf</c>.
  /// </summary>
  /// <param name="code">Instructions as produced by <see cref="Compile"/>.</param>
  /// <param name="rows">Dataset row indices to evaluate.</param>
  /// <param name="rowIndex">Position in <paramref name="rows"/> at which the current batch starts.</param>
  /// <param name="batchSize">Number of rows in the current batch.</param>
  /// <exception cref="NotSupportedException">An instruction has an opcode that is not handled below.</exception>
  private void Evaluate(BatchInstruction[] code, int[] rows, int rowIndex, int batchSize) {
    // Compile assigns child indices that are larger than the index of the parent instruction,
    // so iterating backwards computes the buffers of all children before their parent needs them.
    for (int i = code.Length - 1; i >= 0; --i) {
      var instr = code[i];
      var c = instr.childIndex; // index of the first child instruction
      var n = instr.narg;       // number of child instructions

      switch (instr.opcode) {
        case OpCodes.Variable: {
            LoadData(instr, rows, rowIndex, batchSize);
            break;
          }
        case OpCodes.Constant: // fall through
        case OpCodes.Number:
          break; // nothing to do here, don't remove because we want to prevent falling into the default case here.
        case OpCodes.Add: {
            Load(instr.buf, code[c].buf);
            for (int j = 1; j < n; ++j) {
              Add(instr.buf, code[c + j].buf);
            }
            break;
          }

        case OpCodes.Sub: {
            if (n == 1) {
              // a single argument means negation
              Neg(instr.buf, code[c].buf);
            } else {
              Load(instr.buf, code[c].buf);
              for (int j = 1; j < n; ++j) {
                Sub(instr.buf, code[c + j].buf);
              }
            }
            break;
          }

        case OpCodes.Mul: {
            Load(instr.buf, code[c].buf);
            for (int j = 1; j < n; ++j) {
              Mul(instr.buf, code[c + j].buf);
            }
            break;
          }

        case OpCodes.Div: {
            if (n == 1) {
              // a single argument is handled by Inv instead of a division chain
              Inv(instr.buf, code[c].buf);
            } else {
              Load(instr.buf, code[c].buf);
              for (int j = 1; j < n; ++j) {
                Div(instr.buf, code[c + j].buf);
              }
            }
            break;
          }

        case OpCodes.Square: {
            Square(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Root: {
            Load(instr.buf, code[c].buf);
            Root(instr.buf, code[c + 1].buf);
            break;
          }

        case OpCodes.SquareRoot: {
            Sqrt(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Cube: {
            Cube(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.CubeRoot: {
            CubeRoot(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Power: {
            Load(instr.buf, code[c].buf);
            Pow(instr.buf, code[c + 1].buf);
            break;
          }

        case OpCodes.Exp: {
            Exp(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Log: {
            Log(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Sin: {
            Sin(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Cos: {
            Cos(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Tan: {
            Tan(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Tanh: {
            Tanh(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.Absolute: {
            Absolute(instr.buf, code[c].buf);
            break;
          }

        case OpCodes.AnalyticQuotient: {
            Load(instr.buf, code[c].buf);
            AnalyticQuotient(instr.buf, code[c + 1].buf);
            break;
          }

        case OpCodes.SubFunction: {
            // a sub-function node only passes the values of its first child through
            Load(instr.buf, code[c].buf);
            break;
          }
        default: throw new NotSupportedException($"This interpreter does not support {(OpCode)instr.opcode}");
      }
    }
  }

  // guards the increment of EvaluatedSolutions in GetValues
  private readonly object syncRoot = new object();

  // per-thread cache: variable name -> values of that variable as an array
  [ThreadStatic]
  private static Dictionary<string, double[]> cachedData;

  // dataset the per-thread cache was built from
  [ThreadStatic]
  private static IDataset cachedDataset;

  /// <summary>
  /// Rebuilds the calling thread's cache with the values of all double variables of
  /// <paramref name="dataset"/>.
  /// </summary>
  private void InitCache(IDataset dataset) {
    cachedDataset = dataset;
    cachedData = new Dictionary<string, double[]>();
    foreach (var v in dataset.DoubleVariables) {
      cachedData[v] = dataset.GetDoubleValues(v).ToArray();
    }
  }

  /// <summary>
  /// Discards the calling thread's data cache and resets <see cref="EvaluatedSolutions"/> to zero.
  /// </summary>
  public void InitializeState() {
    cachedData = null;
    cachedDataset = null;
    EvaluatedSolutions = 0;
  }

  /// <summary>
  /// Evaluates <paramref name="tree"/> for the given rows, processing BATCHSIZE rows per call
  /// to <see cref="Evaluate"/> plus one final smaller batch for the remainder.
  /// </summary>
  /// <returns>An array with one value per entry of <paramref name="rows"/>, in the same order.</returns>
  private double[] GetValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
    // The cache is rebuilt when there is none, when a different dataset is passed, and on every
    // call when the cached dataset is a ModifiableDataset (presumably because its values may change).
    if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
      InitCache(dataset);
    }

    var code = Compile(tree, dataset, OpCodes.MapSymbolToOpCode);
    var remainingRows = rows.Length % BATCHSIZE;
    var roundedTotal = rows.Length - remainingRows; // number of rows covered by full batches

    var result = new double[rows.Length];

    for (int rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
      Evaluate(code, rows, rowIndex, BATCHSIZE);
      Array.Copy(code[0].buf, 0, result, rowIndex, BATCHSIZE);
    }

    if (remainingRows > 0) {
      Evaluate(code, rows, roundedTotal, remainingRows);
      Array.Copy(code[0].buf, 0, result, roundedTotal, remainingRows);
    }

    // when evaluation took place without any error, we can increment the counter
    lock (syncRoot) {
      EvaluatedSolutions++;
    }

    return result;
  }

  /// <summary>Evaluates <paramref name="tree"/> on the given rows of <paramref name="dataset"/>.</summary>
  /// <returns>One value per entry of <paramref name="rows"/>.</returns>
  public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
    return GetValues(tree, dataset, rows);
  }

  /// <summary>
  /// Evaluates <paramref name="tree"/> on the given rows of <paramref name="dataset"/>;
  /// <paramref name="rows"/> is materialized into an array first.
  /// </summary>
  public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
    return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
  }

  /// <summary>
  /// Translates the expression below the two topmost nodes of <paramref name="tree"/> into an
  /// array of instructions, one per node, in the order returned by <c>IterateNodesBreadth</c>.
  /// </summary>
  /// <param name="tree">Tree to compile; <c>tree.Root.GetSubtree(0).GetSubtree(0)</c> is used as the expression root.</param>
  /// <param name="dataset">Source of variable values that are not in the thread's cache yet.</param>
  /// <param name="opCodeMapper">Maps a tree node to its opcode.</param>
  /// <returns>The instructions; index 0 belongs to the expression root.</returns>
  /// <exception cref="ArgumentException">A node has more than <see cref="ushort.MaxValue"/> subtrees.</exception>
  private BatchInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
    var root = tree.Root.GetSubtree(0).GetSubtree(0);
    var code = new BatchInstruction[root.GetLength()];
    if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
    // c is the index at which the children of the current node will be stored, i the index of the node itself
    int c = 1, i = 0;
    foreach (var node in root.IterateNodesBreadth()) {
      // narg is stored as ushort, so larger subtree counts cannot be represented
      if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
      code[i] = new BatchInstruction {
        opcode = opCodeMapper(node),
        narg = (ushort)node.SubtreeCount,
        buf = new double[BATCHSIZE],
        childIndex = c
      };
      if (node is VariableTreeNode variable) {
        code[i].weight = variable.Weight;
        // single lookup; variables missing from the cache are read from the dataset and cached
        if (!cachedData.TryGetValue(variable.VariableName, out var values)) {
          values = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
          cachedData[variable.VariableName] = values;
        }
        code[i].data = values;
      } else if (node is INumericTreeNode numeric) {
        code[i].value = numeric.Value;
        // numeric nodes never change, so their buffer is filled once here and not touched by Evaluate
        for (int j = 0; j < BATCHSIZE; ++j)
          code[i].buf[j] = code[i].value;
      }
      c += node.SubtreeCount;
      ++i;
    }
    return code;
  }
}
301}
Note: See TracBrowser for help on using the repository browser.