#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter that flattens a symbolic expression tree into an array of
  /// <see cref="NativeInstruction"/>s and hands it to <see cref="NativeWrapper"/> for evaluation.
  /// </summary>
  /// <remarks>
  /// The double columns of the most recently used dataset are copied into pinned arrays and cached
  /// per thread (the cache fields are <c>[ThreadStatic]</c>). Compiled instructions store raw
  /// addresses of these arrays, so they are only usable until the calling thread rebuilds or
  /// clears its cache.
  /// </remarks>
  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
  [Item("SymbolicDataAnalysisExpressionTreeNativeInterpreter", "Operator calling into native C++ code for tree interpretation.")]
  public class SymbolicDataAnalysisExpressionTreeNativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    #region parameters
    /// <summary>Parameter holding the counter of evaluated solutions.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Number of solutions this interpreter instance has evaluated so far.</summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>Creates a new interpreter whose evaluation counter starts at zero.</summary>
    public SymbolicDataAnalysisExpressionTreeNativeInterpreter() {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(StorableConstructorFlag _) : base(_) { }

    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(SymbolicDataAnalysisExpressionTreeNativeInterpreter original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeNativeInterpreter(this, cloner);
    }

    /// <summary>
    /// Flattens the expression below the two outermost nodes of <paramref name="tree"/> into an
    /// instruction array laid out in breadth-first order.
    /// </summary>
    /// <param name="tree">Tree to compile; <c>tree.Root.GetSubtree(0).GetSubtree(0)</c> is used as the expression root.</param>
    /// <param name="dataset">Dataset whose double columns back the variable nodes.</param>
    /// <param name="opCodeMapper">Maps a tree node to the opcode stored in its instruction.</param>
    /// <param name="nodes">Receives the nodes in breadth-first order; <c>nodes[i]</c> corresponds to instruction <c>i</c>.</param>
    /// <returns>One instruction per node of the compiled expression.</returns>
    /// <exception cref="ArgumentException">A node has more than <see cref="ushort.MaxValue"/> subtrees.</exception>
    /// <remarks>
    /// Variable instructions hold the address of a pinned array from the calling thread's cache;
    /// they become invalid once that cache is rebuilt for another dataset or cleared.
    /// </remarks>
    public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
      if (cachedData == null || cachedDataset != dataset) {
        InitCache(dataset);
      }

      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var code = new NativeInstruction[root.GetLength()];
      code[0] = CreateInstruction(root, opCodeMapper);

      // Do not depend on the concrete collection type returned by IterateNodesBreadth.
      var breadthFirst = root.IterateNodesBreadth();
      nodes = breadthFirst as List<ISymbolicExpressionTreeNode> ?? breadthFirst.ToList();

      // c is the index of the next free slot; the children of node i occupy code[c .. c + SubtreeCount - 1].
      int c = 1;
      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];

        for (int j = 0; j < node.SubtreeCount; ++j) {
          code[c + j] = CreateInstruction(node.GetSubtree(j), opCodeMapper);
        }

        if (node is VariableTreeNode variable) {
          code[i].value = variable.Weight;
          code[i].data = cachedData[variable.VariableName].AddrOfPinnedObject();
        } else if (node is ConstantTreeNode constant) {
          code[i].value = constant.Value;
        }

        code[i].childIndex = c;
        c += node.SubtreeCount;
      }

      return code;
    }

    // Builds the instruction header (argument count and opcode) for a single node.
    private static NativeInstruction CreateInstruction(ISymbolicExpressionTreeNode node, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
      if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
      return new NativeInstruction { narg = (ushort)node.SubtreeCount, opcode = opCodeMapper(node) };
    }

    private readonly object syncRoot = new object();

    // Pinned copies of the double columns of cachedDataset, keyed by variable name (one cache per thread).
    [ThreadStatic]
    private static Dictionary<string, GCHandle> cachedData;

    // Dataset the current thread's cachedData was built from.
    [ThreadStatic]
    private static IDataset cachedDataset;

    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
      (byte)OpCode.Constant,
      (byte)OpCode.Variable,
      (byte)OpCode.Add,
      (byte)OpCode.Sub,
      (byte)OpCode.Mul,
      (byte)OpCode.Div,
      (byte)OpCode.Exp,
      (byte)OpCode.Log,
      (byte)OpCode.Sin,
      (byte)OpCode.Cos,
      (byte)OpCode.Tan,
      (byte)OpCode.Tanh,
      (byte)OpCode.Power,
      (byte)OpCode.Root,
      (byte)OpCode.SquareRoot,
      (byte)OpCode.Square,
      (byte)OpCode.CubeRoot,
      (byte)OpCode.Cube,
      (byte)OpCode.Absolute,
      (byte)OpCode.AnalyticQuotient
    };

    // Maps a node to its opcode and rejects symbols that are not listed in supportedOpCodes.
    private static byte MapSupportedSymbols(ISymbolicExpressionTreeNode node) {
      var opCode = OpCodes.MapSymbolToOpCode(node);
      if (supportedOpCodes.Contains(opCode)) return opCode;
      throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given rows of <paramref name="dataset"/> through
    /// <see cref="NativeWrapper"/> and increments <see cref="EvaluatedSolutions"/> on success.
    /// </summary>
    /// <param name="tree">Tree to evaluate.</param>
    /// <param name="dataset">Dataset providing the variable values.</param>
    /// <param name="rows">Row indices to evaluate; enumerated exactly once.</param>
    /// <returns>One value per requested row, or an empty sequence when <paramref name="rows"/> is empty.</returns>
    /// <exception cref="NotSupportedException">The tree contains a symbol that is not in the supported opcode set.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      // Materialize once so that lazily generated row sequences are not enumerated twice.
      var rowsArray = rows.ToArray();
      if (rowsArray.Length == 0) return Enumerable.Empty<double>();

      var code = Compile(tree, dataset, MapSupportedSymbols, out _);
      var result = new double[rowsArray.Length];

      NativeWrapper.GetValues(code, code.Length, rowsArray, rowsArray.Length, result);

      // when evaluation took place without any error, we can increment the counter
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      return result;
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and, when <paramref name="iterations"/> is positive, calls the
    /// <see cref="NativeWrapper"/> overload that additionally receives the target values and the iteration count.
    /// </summary>
    /// <param name="tree">Tree whose terminal values are reported.</param>
    /// <param name="dataset">Dataset providing the variable and target values.</param>
    /// <param name="rows">Row indices to use; enumerated at most once.</param>
    /// <param name="targetVariable">Name of the target variable in <paramref name="dataset"/>.</param>
    /// <param name="iterations">Iteration count passed to the native code; values &lt;= 0 skip the native call.</param>
    /// <returns>
    /// The <c>value</c> field of the compiled instruction for every terminal node of the tree.
    /// NOTE(review): presumably the native call updates these values in place - confirm against NativeWrapper.
    /// </returns>
    /// <exception cref="NotSupportedException">The tree contains a symbol that is not in the supported opcode set.</exception>
    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeConstants(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, int iterations) {
      var code = Compile(tree, dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);

      if (iterations > 0) {
        // Materialize once; the same row indices select the targets and are passed to the native call.
        var rowsArray = rows.ToArray();
        var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
        var result = new double[rowsArray.Length];

        NativeWrapper.GetValues(code, code.Length, rowsArray, rowsArray.Length, result, target, iterations);
      }

      var terminalValues = new Dictionary<ISymbolicExpressionTreeNode, double>();
      for (int i = 0; i < code.Length; ++i) {
        if (nodes[i] is SymbolicExpressionTreeTerminalNode) {
          terminalValues.Add(nodes[i], code[i].value);
        }
      }
      return terminalValues;
    }

    // Rebuilds the calling thread's cache for the given dataset. Handles of a previously cached
    // dataset are released first so that switching datasets does not leak pinned arrays.
    private static void InitCache(IDataset dataset) {
      FreeCache();
      cachedDataset = null;

      var pinned = new Dictionary<string, GCHandle>();
      try {
        foreach (var v in dataset.DoubleVariables) {
          var values = dataset.GetDoubleValues(v).ToArray();
          pinned[v] = GCHandle.Alloc(values, GCHandleType.Pinned);
        }
      } catch {
        // Do not leave pinned arrays behind when the cache could not be built completely.
        foreach (var handle in pinned.Values) {
          handle.Free();
        }
        throw;
      }

      // Publish only a completely built cache.
      cachedData = pinned;
      cachedDataset = dataset;
    }

    // Releases all pinned handles of the calling thread's cache and drops the cache.
    private static void FreeCache() {
      if (cachedData == null) return;
      foreach (var handle in cachedData.Values) {
        handle.Free();
      }
      cachedData = null;
    }

    /// <summary>
    /// Releases the calling thread's cached dataset columns and resets <see cref="EvaluatedSolutions"/> to zero.
    /// </summary>
    /// <remarks>The cache fields are thread-static, so caches built on other threads are not touched.</remarks>
    public void ClearState() {
      FreeCache();
      cachedDataset = null;
      EvaluatedSolutions = 0;
    }

    /// <summary>Resets the interpreter; equivalent to <see cref="ClearState"/>.</summary>
    public void InitializeState() {
      ClearState();
    }
  }
}