Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeBatchInterpreter.cs @ 17928

Last change on this file since 17928 was 17801, checked in by gkronber, 4 years ago

#3084 updated interpreters to always invalidate their cached dataset when the cached dataset is a ModifiableDataset (as in the PDP)

File size: 10.1 KB
RevLine 
[17402]1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[16285]23using System.Collections.Generic;
24using System.Linq;
25
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
[16565]31using HEAL.Attic;
[16285]32
33using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperations;
34
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Evaluates symbolic expression trees by compiling them into a flat array of
  /// <see cref="BatchInstruction"/>s (breadth-first order) and executing that program on
  /// batches of up to <c>BATCHSIZE</c> rows at a time.
  /// </summary>
  [Item("SymbolicDataAnalysisExpressionTreeBatchInterpreter", "An interpreter that uses batching and vectorization techniques to achieve faster performance.")]
  [StorableType("BEB15146-BB95-4838-83AC-6838543F017B")]
  public class SymbolicDataAnalysisExpressionTreeBatchInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    #region parameters
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Number of trees successfully evaluated since the last <see cref="InitializeState"/>.</summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    // Intentionally empty: this interpreter has no per-instance evaluation state to clear.
    public void ClearState() { }

    public SymbolicDataAnalysisExpressionTreeBatchInterpreter() {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(StorableConstructorFlag _) : base(_) { }
    protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(SymbolicDataAnalysisExpressionTreeBatchInterpreter original, Cloner cloner) : base(original, cloner) {
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeBatchInterpreter(this, cloner);
    }

    /// <summary>
    /// Fills <c>instr.buf[0 .. batchSize)</c> with the weighted values of the variable referenced
    /// by <paramref name="instr"/> for the rows <c>rows[rowIndex] .. rows[rowIndex + batchSize - 1]</c>.
    /// </summary>
    /// <remarks>
    /// Each row index is looked up individually because <paramref name="rows"/> is not required to be
    /// contiguous (e.g. training indices that skip test samples, or randomly sampled rows).
    /// </remarks>
    private void LoadData(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
      for (int i = 0; i < batchSize; ++i) {
        var row = rows[rowIndex + i];
        instr.buf[i] = instr.weight * instr.data[row];
      }
    }

    /// <summary>
    /// Executes the compiled program for one batch of <paramref name="batchSize"/> rows starting at
    /// position <paramref name="rowIndex"/> in <paramref name="rows"/>. The result ends up in <c>code[0].buf</c>.
    /// </summary>
    /// <remarks>
    /// <see cref="Compile"/> lays the instructions out breadth-first, so every child has a higher index
    /// than its parent; iterating from the end therefore evaluates all children before their parent.
    /// </remarks>
    private void Evaluate(BatchInstruction[] code, int[] rows, int rowIndex, int batchSize) {
      for (int i = code.Length - 1; i >= 0; --i) {
        var instr = code[i];
        var c = instr.childIndex; // index of the first child instruction
        var n = instr.narg;       // number of children

        switch (instr.opcode) {
          case OpCodes.Variable: {
              LoadData(instr, rows, rowIndex, batchSize);
              break;
            }
          case OpCodes.Constant: break; // buffer is pre-filled in Compile; listed explicitly so it does not hit the default case.
          case OpCodes.Add: {
              Load(instr.buf, code[c].buf);
              for (int j = 1; j < n; ++j) {
                Add(instr.buf, code[c + j].buf);
              }
              break;
            }

          case OpCodes.Sub: {
              // a single argument means unary negation
              if (n == 1) {
                Neg(instr.buf, code[c].buf);
              } else {
                Load(instr.buf, code[c].buf);
                for (int j = 1; j < n; ++j) {
                  Sub(instr.buf, code[c + j].buf);
                }
              }
              break;
            }

          case OpCodes.Mul: {
              Load(instr.buf, code[c].buf);
              for (int j = 1; j < n; ++j) {
                Mul(instr.buf, code[c + j].buf);
              }
              break;
            }

          case OpCodes.Div: {
              // a single argument means the reciprocal
              if (n == 1) {
                Inv(instr.buf, code[c].buf);
              } else {
                Load(instr.buf, code[c].buf);
                for (int j = 1; j < n; ++j) {
                  Div(instr.buf, code[c + j].buf);
                }
              }
              break;
            }

          case OpCodes.Square: {
              Square(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.Root: {
              Load(instr.buf, code[c].buf);
              Root(instr.buf, code[c + 1].buf);
              break;
            }

          case OpCodes.SquareRoot: {
              Sqrt(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.Cube: {
              Cube(instr.buf, code[c].buf);
              break;
            }
          case OpCodes.CubeRoot: {
              CubeRoot(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.Power: {
              Load(instr.buf, code[c].buf);
              Pow(instr.buf, code[c + 1].buf);
              break;
            }

          case OpCodes.Exp: {
              Exp(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.Log: {
              Log(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.Sin: {
              Sin(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.Cos: {
              Cos(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.Tan: {
              Tan(instr.buf, code[c].buf);
              break;
            }
          case OpCodes.Tanh: {
              Tanh(instr.buf, code[c].buf);
              break;
            }
          case OpCodes.Absolute: {
              Absolute(instr.buf, code[c].buf);
              break;
            }

          case OpCodes.AnalyticQuotient: {
              Load(instr.buf, code[c].buf);
              AnalyticQuotient(instr.buf, code[c + 1].buf);
              break;
            }
          default: throw new NotSupportedException($"This interpreter does not support {(OpCode)instr.opcode}");
        }
      }
    }

    private readonly object syncRoot = new object();

    // Per-thread cache of variable columns (copied into arrays) for the dataset last used on this thread.
    [ThreadStatic]
    private static Dictionary<string, double[]> cachedData;

    // The dataset that cachedData belongs to (compared by reference).
    [ThreadStatic]
    private static IDataset cachedDataset;

    /// <summary>
    /// Resets the calling thread's column cache for <paramref name="dataset"/>.
    /// Columns are copied lazily by <see cref="Compile"/> the first time a variable is referenced,
    /// so only variables that actually occur in evaluated trees are materialized.
    /// </summary>
    private void InitCache(IDataset dataset) {
      cachedDataset = dataset;
      cachedData = new Dictionary<string, double[]>();
    }

    /// <summary>
    /// Clears the calling thread's column cache and resets <see cref="EvaluatedSolutions"/> to 0.
    /// The caches are [ThreadStatic], so caches held by other threads are not affected here.
    /// </summary>
    public void InitializeState() {
      cachedData = null;
      cachedDataset = null;
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and evaluates it for every row in <paramref name="rows"/>,
    /// returning one value per entry of <paramref name="rows"/> in the same order.
    /// </summary>
    private double[] GetValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      // A ModifiableDataset can change its values without changing identity, so its cache is never reused.
      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
        InitCache(dataset);
      }

      var code = Compile(tree, dataset, OpCodes.MapSymbolToOpCode);
      var remainingRows = rows.Length % BATCHSIZE;
      var roundedTotal = rows.Length - remainingRows;

      var result = new double[rows.Length];

      // full batches
      for (int rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
        Evaluate(code, rows, rowIndex, BATCHSIZE);
        Array.Copy(code[0].buf, 0, result, rowIndex, BATCHSIZE);
      }

      // trailing partial batch
      if (remainingRows > 0) {
        Evaluate(code, rows, roundedTotal, remainingRows);
        Array.Copy(code[0].buf, 0, result, roundedTotal, remainingRows);
      }

      // when evaluation took place without any error, we can increment the counter
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      return result;
    }

    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      return GetValues(tree, dataset, rows);
    }

    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
    }

    /// <summary>
    /// Translates the tree below the program root/start nodes into a breadth-first array of instructions.
    /// Each instruction stores the index of its first child (<c>childIndex</c>) and its argument count (<c>narg</c>).
    /// Variable instructions reference a cached column of <paramref name="dataset"/>; constant instructions
    /// get their buffer pre-filled with the constant value.
    /// </summary>
    /// <exception cref="ArgumentException">A node has more than <see cref="ushort.MaxValue"/> subtrees.</exception>
    private BatchInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var code = new BatchInstruction[root.GetLength()];
      int c = 1, i = 0;
      foreach (var node in root.IterateNodesBreadth()) {
        // narg is stored as ushort
        if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
        code[i] = new BatchInstruction {
          opcode = opCodeMapper(node),
          narg = (ushort)node.SubtreeCount,
          buf = new double[BATCHSIZE],
          childIndex = c
        };
        if (node is VariableTreeNode variable) {
          code[i].weight = variable.Weight;
          if (!cachedData.TryGetValue(variable.VariableName, out var columnData)) {
            columnData = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
            cachedData[variable.VariableName] = columnData;
          }
          code[i].data = columnData;
        } else if (node is ConstantTreeNode constant) {
          code[i].value = constant.Value;
          for (int j = 0; j < BATCHSIZE; ++j)
            code[i].buf[j] = code[i].value;
        }
        // children of this node occupy the next SubtreeCount slots in breadth-first order
        c += node.SubtreeCount;
        ++i;
      }
      return code;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.