using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Evaluates a symbolic expression tree over a set of dataset rows in fixed-size
  /// batches of <see cref="BATCHSIZE"/> rows, using <c>AlgebraicDoubleVector</c>
  /// instances as the per-instruction value buffers.
  /// </summary>
  /// <remarks>
  /// The evaluation state (cached column data, current dataset, current batch offset
  /// and row indices) is stored in instance fields, so a single instance must not be
  /// used by several threads at the same time.
  /// NOTE(review): the base type may originally have carried a generic type argument
  /// that is not visible here - confirm against the declaration of InterpreterBase.
  /// </remarks>
  public sealed class VectorEvaluator : InterpreterBase {
    // Number of rows that are processed by one call of Evaluate(code).
    private const int BATCHSIZE = 128;

    // These were previously marked [ThreadStatic]. The attribute only has an effect on
    // static fields, so on instance fields it was silently ignored and merely suggested
    // a thread-safety guarantee that never existed. It has been removed; behavior is
    // unchanged.

    // Materialized column values of the current dataset, keyed by variable name.
    private Dictionary<string, double[]> cachedData;

    // Dataset the cache was built for; used to detect when the cache must be rebuilt.
    private IDataset dataset;

    // Offset into rows at which the batch that is currently evaluated starts.
    private int rowIndex;

    // Dataset row indices requested by the current Evaluate call.
    private int[] rows;

    /// <summary>
    /// Remembers <paramref name="dataset"/> as the current dataset and rebuilds the
    /// column cache with an array copy of each of its double variables.
    /// </summary>
    /// <param name="dataset">The dataset whose double variables are cached.</param>
    private void InitCache(IDataset dataset) {
      this.dataset = dataset;
      cachedData = new Dictionary<string, double[]>();
      foreach (var variableName in dataset.DoubleVariables) {
        cachedData[variableName] = dataset.GetReadOnlyDoubleValues(variableName).ToArray();
      }
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> for the given rows of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="tree">The expression tree to compile and evaluate.</param>
    /// <param name="dataset">The dataset that supplies the variable values.</param>
    /// <param name="rows">Indices of the dataset rows to evaluate, in output order.</param>
    /// <returns>
    /// An array of length <c>rows.Length</c>; element <c>i</c> is the value produced
    /// for dataset row <c>rows[i]</c>.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="dataset"/> or <paramref name="rows"/> is <c>null</c>.
    /// </exception>
    public double[] Evaluate(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      // Fail with a descriptive exception instead of a NullReferenceException further down.
      if (dataset == null) {
        throw new ArgumentNullException("dataset");
      }
      if (rows == null) {
        throw new ArgumentNullException("rows");
      }

      // The cache is only valid for the dataset instance it was built from.
      if (cachedData == null || this.dataset != dataset) {
        InitCache(dataset);
      }

      this.rows = rows;
      var code = Compile(tree);

      var remainingRows = rows.Length % BATCHSIZE;
      var roundedTotal = rows.Length - remainingRows;
      var result = new double[rows.Length];

      // Full batches. rowIndex is a field because LoadVariable reads it to find the
      // start of the batch that is currently being evaluated.
      for (rowIndex = 0; rowIndex < roundedTotal; rowIndex += BATCHSIZE) {
        Evaluate(code);
        // The value of the first instruction is taken as the result of the batch
        // (presumably the root of the compiled tree - confirm in InterpreterBase).
        code[0].value.CopyTo(result, rowIndex, BATCHSIZE);
      }

      // Trailing partial batch. After the loop rowIndex equals roundedTotal, so
      // LoadVariable fills only the first remainingRows slots with fresh values and
      // only those are copied out.
      if (remainingRows > 0) {
        Evaluate(code);
        code[0].value.CopyTo(result, roundedTotal, remainingRows);
      }

      return result;
    }

    /// <summary>
    /// Prepares a constant instruction: stores the constant and pre-fills the
    /// instruction's value buffer with it.
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, ConstantTreeNode constant) {
      instruction.dblVal = constant.Value;
      instruction.value = new AlgebraicDoubleVector(BATCHSIZE);
      instruction.value.AssignConstant(instruction.dblVal);
    }

    /// <summary>
    /// Prepares a variable instruction: stores the variable weight, allocates the value
    /// buffer and attaches the cached column data of the referenced variable.
    /// </summary>
    protected override void InitializeTerminalInstruction(ref Instruction instruction, VariableTreeNode variable) {
      instruction.dblVal = variable.Weight;
      instruction.value = new AlgebraicDoubleVector(BATCHSIZE);

      // Single lookup instead of ContainsKey followed by the indexer.
      double[] values;
      if (!cachedData.TryGetValue(variable.VariableName, out values)) {
        // Not among the variables cached by InitCache: load the column on demand and
        // remember it for subsequent instructions.
        values = dataset.GetDoubleValues(variable.VariableName).ToArray();
        cachedData[variable.VariableName] = values;
      }
      instruction.data = values;
    }

    /// <summary>
    /// Prepares an internal (non-terminal) instruction by allocating its value buffer.
    /// </summary>
    protected override void InitializeInternalInstruction(ref Instruction instruction, ISymbolicExpressionTreeNode node) {
      instruction.value = new AlgebraicDoubleVector(BATCHSIZE);
    }

    /// <summary>
    /// Loads the values of the instruction's variable for the current batch into its
    /// value buffer and scales the buffer by the variable weight.
    /// </summary>
    protected override void LoadVariable(Instruction a) {
      var data = (double[])a.data;

      // A batch covers at most BATCHSIZE rows and never runs past the requested rows.
      var batchEnd = Math.Min(rows.Length, rowIndex + BATCHSIZE);
      for (int i = rowIndex; i < batchEnd; i++) {
        a.value[i - rowIndex] = data[rows[i]];
      }

      // dblVal holds the variable weight (set in InitializeTerminalInstruction).
      a.value.Scale(a.dblVal);
    }
  }
}