Changeset 18007 for branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/NativeInterpreter.cs
- Timestamp:
- 07/13/21 14:38:02 (3 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/NativeInterpreter.cs
r17989 r18007 38 38 [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")] 39 39 [Item("NativeInterpreter", "Operator calling into native C++ code for tree interpretation.")] 40 public class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {40 public sealed class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter { 41 41 private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions"; 42 42 … … 58 58 } 59 59 60 #region storable ctor and cloning 60 61 [StorableConstructor] 61 protected NativeInterpreter(StorableConstructorFlag _) : base(_) { } 62 63 protected NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) { 64 } 65 62 private NativeInterpreter(StorableConstructorFlag _) : base(_) { } 66 63 public override IDeepCloneable Clone(Cloner cloner) { 67 64 return new NativeInterpreter(this, cloner); 68 65 } 66 67 private NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) { } 68 #endregion 69 69 70 public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) { 70 71 var root = tree.Root.GetSubtree(0).GetSubtree(0); … … 73 74 74 75 public static NativeInstruction[] Compile(ISymbolicExpressionTreeNode root, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) { 75 if (cachedData == null || cachedDataset != dataset ) {76 if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) { 76 77 InitCache(dataset); 77 78 } 78 79 79 80 nodes = root.IterateNodesPrefix().ToList(); nodes.Reverse(); 80 81 var code = new NativeInstruction[nodes.Count]; 81 82 for (int i = 0; i < nodes.Count; ++i) { 83 var node = nodes[i]; 84 code[i] = new NativeInstruction { Arity = (ushort)node.SubtreeCount, OpCode = opCodeMapper(node), Length = (ushort)node.GetLength(), Optimize = true }; 85 86 if (node is VariableTreeNode variable) { 82 if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)"); 83 int i = code.Length - 1; 84 foreach (var n in root.IterateNodesPrefix()) { 85 code[i] = new NativeInstruction { Arity = (ushort)n.SubtreeCount, OpCode = opCodeMapper(n), Length = 1, Optimize = false }; 86 if (n is VariableTreeNode variable) { 87 87 code[i].Value = variable.Weight; 88 88 code[i].Data = cachedData[variable.VariableName].AddrOfPinnedObject(); 89 } else if (n odeis ConstantTreeNode constant) {89 } else if (n is ConstantTreeNode constant) { 90 90 code[i].Value = constant.Value; 91 91 } 92 --i; 92 93 } 94 // second pass to calculate lengths 95 for (i = 0; i < code.Length; i++) { 96 var c = i - 1; 97 for (int j = 0; j < code[i].Arity; ++j) { 98 code[i].Length += code[c].Length; 99 c -= code[c].Length; 100 } 101 } 102 93 103 return code; 104 } 105 106 public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) { 107 return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray()); 108 } 109 110 public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) { 111 if (!rows.Any()) return Enumerable.Empty<double>(); 112 113 byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) { 114 var opCode = OpCodes.MapSymbolToOpCode(node); 115 if (supportedOpCodes.Contains(opCode)) return opCode; 116 else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}"); 117 }; 118 var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes); 119 120 var result = new double[rows.Length]; 121 var options = new SolverOptions { Iterations = 0 }; // Evaluate only. Do not optimize. 122 123 NativeWrapper.GetValues(code, rows, options, result, target: null, out var summary); 124 125 // when evaluation took place without any error, we can increment the counter 126 lock (syncRoot) { 127 EvaluatedSolutions++; 128 } 129 130 return result; 94 131 } 95 132 … … 102 139 private static IDataset cachedDataset; 103 140 104 pr otectedstatic readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {141 private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() { 105 142 (byte)OpCode.Constant, 106 143 (byte)OpCode.Variable, … … 115 152 (byte)OpCode.Tan, 116 153 (byte)OpCode.Tanh, 117 (byte)OpCode.Power,118 (byte)OpCode.Root,154 // (byte)OpCode.Power, // these symbols are handled differently in the NativeInterpreter than in HL 155 // (byte)OpCode.Root, 119 156 (byte)OpCode.SquareRoot, 120 157 (byte)OpCode.Square, … … 125 162 }; 126 163 127 public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {128 return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());129 }130 131 164 private static void InitCache(IDataset dataset) { 132 165 cachedDataset = dataset; … … 159 192 ClearState(); 160 193 } 161 162 public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {163 if (!rows.Any()) return Enumerable.Empty<double>();164 165 byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {166 var opCode = OpCodes.MapSymbolToOpCode(node);167 if (supportedOpCodes.Contains(opCode)) return opCode;168 else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");169 };170 var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);171 172 var result = new double[rows.Length];173 var options = new SolverOptions { /* not using any options here */ };174 175 var summary = new OptimizationSummary(); // also not used176 NativeWrapper.GetValues(code, rows, options, result, target: null, out summary);177 178 // when evaluation took place without any error, we can increment the counter179 lock (syncRoot) {180 EvaluatedSolutions++;181 }182 183 return result;184 }185 194 } 186 195 }
Note: See TracChangeset
for help on using the changeset viewer.