#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HEAL.Attic;
using static HeuristicLab.Problems.DataAnalysis.Symbolic.BatchOperations;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Interpreter for symbolic expression trees that compiles a tree into a flat array of
/// <c>BatchInstruction</c>s and evaluates it in batches of up to <c>BATCHSIZE</c> rows,
/// using the element-wise helpers imported from <c>BatchOperations</c>.
/// </summary>
/// <remarks>
/// Dataset columns are cached per thread (the cache fields are <c>[ThreadStatic]</c>), so every
/// evaluating thread builds and uses its own copy of the column data.
/// </remarks>
[Item("SymbolicDataAnalysisExpressionTreeBatchInterpreter", "An interpreter that uses batching and vectorization techniques to achieve faster performance.")]
[StorableType("BEB15146-BB95-4838-83AC-6838543F017B")]
public class SymbolicDataAnalysisExpressionTreeBatchInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
  private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

  #region parameters
  /// <summary>Parameter holding the counter of evaluated solutions.</summary>
  public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
    get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
  }
  #endregion

  #region properties
  /// <summary>Number of trees evaluated since the counter was last reset by <see cref="InitializeState"/>.</summary>
  public int EvaluatedSolutions {
    get { return EvaluatedSolutionsParameter.Value.Value; }
    set { EvaluatedSolutionsParameter.Value.Value = value; }
  }
  #endregion

  /// <summary>Intentionally empty: this interpreter has nothing to release here.</summary>
  public void ClearState() { }

  /// <summary>Creates a new interpreter with the evaluated-solutions counter set to zero.</summary>
  public SymbolicDataAnalysisExpressionTreeBatchInterpreter() {
    Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
  }

  [StorableConstructor]
  protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(StorableConstructorFlag _) : base(_) { }

  /// <summary>Cloning constructor; all cloned state is handled by the base class.</summary>
  protected SymbolicDataAnalysisExpressionTreeBatchInterpreter(SymbolicDataAnalysisExpressionTreeBatchInterpreter original, Cloner cloner) : base(original, cloner) {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new SymbolicDataAnalysisExpressionTreeBatchInterpreter(this, cloner);
  }

  /// <summary>
  /// Fills the buffer of a variable instruction with the weighted variable values of one batch.
  /// </summary>
  /// <param name="instr">Variable instruction whose <c>data</c>, <c>weight</c> and <c>buf</c> are used.</param>
  /// <param name="rows">Dataset row indices requested by the caller.</param>
  /// <param name="rowIndex">Position in <paramref name="rows"/> at which the batch starts.</param>
  /// <param name="batchSize">Number of rows in the batch.</param>
  private void LoadData(BatchInstruction instr, int[] rows, int rowIndex, int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      // Look up every row individually: the requested rows are not guaranteed to be
      // consecutive dataset indices, so "rows[rowIndex] + i" would read the wrong rows
      // (or run past the end of the column) for non-contiguous selections.
      var row = rows[rowIndex + i];
      instr.buf[i] = instr.weight * instr.data[row];
    }
  }

  /// <summary>
  /// Evaluates the compiled code for one batch of rows. After the call, <c>code[0].buf</c>
  /// holds the values of the root instruction for the batch.
  /// </summary>
  /// <remarks>
  /// Instructions are processed from last to first. Since <see cref="Compile"/> lays the nodes
  /// out breadth-first, the children of an instruction always have larger indices than the
  /// instruction itself and are therefore computed before it.
  /// </remarks>
  /// <exception cref="NotSupportedException">The code contains an opcode this interpreter does not handle.</exception>
  private void Evaluate(BatchInstruction[] code, int[] rows, int rowIndex, int batchSize) {
    for (int i = code.Length - 1; i >= 0; --i) {
      var instr = code[i];
      var c = instr.childIndex; // index of the first child instruction
      var n = instr.narg;       // number of child instructions

      switch (instr.opcode) {
        case OpCodes.Variable: {
            LoadData(instr, rows, rowIndex, batchSize);
            break;
          }
        case OpCodes.Constant: // fall through
        case OpCodes.Number:
          break; // nothing to do here, don't remove because we want to prevent falling into the default case here.
        case OpCodes.Add: {
            Load(instr.buf, code[c].buf);
            for (int j = 1; j < n; ++j) {
              Add(instr.buf, code[c + j].buf);
            }
            break;
          }
        case OpCodes.Sub: {
            // a single argument means negation
            if (n == 1) {
              Neg(instr.buf, code[c].buf);
            } else {
              Load(instr.buf, code[c].buf);
              for (int j = 1; j < n; ++j) {
                Sub(instr.buf, code[c + j].buf);
              }
            }
            break;
          }
        case OpCodes.Mul: {
            Load(instr.buf, code[c].buf);
            for (int j = 1; j < n; ++j) {
              Mul(instr.buf, code[c + j].buf);
            }
            break;
          }
        case OpCodes.Div: {
            // a single argument means inversion
            if (n == 1) {
              Inv(instr.buf, code[c].buf);
            } else {
              Load(instr.buf, code[c].buf);
              for (int j = 1; j < n; ++j) {
                Div(instr.buf, code[c + j].buf);
              }
            }
            break;
          }
        case OpCodes.Square: {
            Square(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Root: {
            Load(instr.buf, code[c].buf);
            Root(instr.buf, code[c + 1].buf);
            break;
          }
        case OpCodes.SquareRoot: {
            Sqrt(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Cube: {
            Cube(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.CubeRoot: {
            CubeRoot(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Power: {
            Load(instr.buf, code[c].buf);
            Pow(instr.buf, code[c + 1].buf);
            break;
          }
        case OpCodes.Exp: {
            Exp(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Log: {
            Log(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Sin: {
            Sin(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Cos: {
            Cos(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Tan: {
            Tan(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Tanh: {
            Tanh(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.Absolute: {
            Absolute(instr.buf, code[c].buf);
            break;
          }
        case OpCodes.AnalyticQuotient: {
            Load(instr.buf, code[c].buf);
            AnalyticQuotient(instr.buf, code[c + 1].buf);
            break;
          }
        case OpCodes.SubFunction: {
            // only the first child is passed through
            Load(instr.buf, code[c].buf);
            break;
          }
        default: throw new NotSupportedException($"This interpreter does not support {(OpCode)instr.opcode}");
      }
    }
  }

  // guards the increment of EvaluatedSolutions
  private readonly object syncRoot = new object();

  // Per-thread cache of dataset columns, keyed by variable name.
  [ThreadStatic]
  private static Dictionary<string, double[]> cachedData;

  // The dataset the per-thread cache was built from.
  [ThreadStatic]
  private static IDataset cachedDataset;

  /// <summary>
  /// Rebuilds the calling thread's cache with a copy of every double variable of <paramref name="dataset"/>.
  /// </summary>
  private void InitCache(IDataset dataset) {
    cachedDataset = dataset;
    cachedData = new Dictionary<string, double[]>();
    foreach (var v in dataset.DoubleVariables) {
      cachedData[v] = dataset.GetDoubleValues(v).ToArray();
    }
  }

  /// <summary>
  /// Resets the evaluated-solutions counter and drops the data cache.
  /// Because the cache fields are thread-static, only the calling thread's cache is dropped.
  /// </summary>
  public void InitializeState() {
    cachedData = null;
    cachedDataset = null;
    EvaluatedSolutions = 0;
  }

  /// <summary>
  /// Compiles <paramref name="tree"/> and evaluates it for the given rows, batch by batch.
  /// </summary>
  /// <returns>One value per entry of <paramref name="rows"/>, in the same order.</returns>
  private double[] GetValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
    // The cache is rebuilt when it is missing or belongs to another dataset, and on every
    // call for a ModifiableDataset (presumably because its contents may change between calls).
    if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
      InitCache(dataset);
    }

    var code = Compile(tree, dataset, OpCodes.MapSymbolToOpCode);
    var remainingRows = rows.Length % BATCHSIZE;
    var roundedTotal = rows.Length - remainingRows;

    var result = new double[rows.Length];

    // full batches
    for (int rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
      Evaluate(code, rows, rowIndex, BATCHSIZE);
      Array.Copy(code[0].buf, 0, result, rowIndex, BATCHSIZE);
    }

    // trailing partial batch
    if (remainingRows > 0) {
      Evaluate(code, rows, roundedTotal, remainingRows);
      Array.Copy(code[0].buf, 0, result, roundedTotal, remainingRows);
    }

    // when evaluation took place without any error, we can increment the counter
    lock (syncRoot) {
      EvaluatedSolutions++;
    }

    return result;
  }

  /// <summary>Evaluates <paramref name="tree"/> on the given dataset rows.</summary>
  /// <returns>One value per entry of <paramref name="rows"/>, in the same order.</returns>
  public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
    return GetValues(tree, dataset, rows);
  }

  /// <summary>
  /// Evaluates <paramref name="tree"/> on the given dataset rows; the rows are materialized into an array first.
  /// </summary>
  public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
    return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
  }

  /// <summary>
  /// Translates the expression below the tree's root and start nodes into a breadth-first
  /// array of instructions. Index 0 is the expression root; each instruction stores the index
  /// of its first child in <c>childIndex</c>.
  /// </summary>
  /// <param name="tree">Tree to compile; the expression is taken from <c>Root.GetSubtree(0).GetSubtree(0)</c>.</param>
  /// <param name="dataset">Dataset used to load columns that are not in the thread's cache yet.</param>
  /// <param name="opCodeMapper">Maps a tree node to its opcode.</param>
  /// <exception cref="ArgumentException">A node has more than <see cref="ushort.MaxValue"/> subtrees.</exception>
  private BatchInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
    var root = tree.Root.GetSubtree(0).GetSubtree(0);
    var code = new BatchInstruction[root.GetLength()];

    // c: index at which the children of the current node start; i: index of the current node
    int c = 1, i = 0;
    foreach (var node in root.IterateNodesBreadth()) {
      // narg is stored as ushort; this check also covers the root, which is visited first
      if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");

      code[i] = new BatchInstruction {
        opcode = opCodeMapper(node),
        narg = (ushort)node.SubtreeCount,
        buf = new double[BATCHSIZE],
        childIndex = c
      };

      if (node is VariableTreeNode variable) {
        code[i].weight = variable.Weight;
        // single lookup; load and cache the column on a miss
        if (!cachedData.TryGetValue(variable.VariableName, out var values)) {
          values = dataset.GetReadOnlyDoubleValues(variable.VariableName).ToArray();
          cachedData[variable.VariableName] = values;
        }
        code[i].data = values;
      } else if (node is INumericTreeNode numeric) {
        code[i].value = numeric.Value;
        // pre-fill the buffer once so that Evaluate has nothing to do for numeric nodes
        for (int j = 0; j < BATCHSIZE; ++j) {
          code[i].buf[j] = code[i].value;
        }
      }

      c += node.SubtreeCount;
      ++i;
    }
    return code;
  }
}
}