Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/NativeInterpreter.cs @ 17989

Last change on this file since 17989 was 17989, checked in by gkronber, 3 years ago

#3087: updated native dlls for NativeInterpreter to a version that runs on Hive infrastructure. Some smaller changes because of deviations in the independently developed implementations (in particular enum types).

File size: 6.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26
27using HEAL.Attic;
28
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
33using HeuristicLab.NativeInterpreter;
34using HeuristicLab.Parameters;
35using HeuristicLab.Problems.DataAnalysis;
36
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Symbolic expression tree interpreter that compiles a tree into a flat array of
  /// <c>NativeInstruction</c>s and evaluates it by calling into native C++ code
  /// (<c>NativeWrapper.GetValues</c>).
  /// </summary>
  /// <remarks>
  /// The double-valued columns of the most recently compiled dataset are copied and pinned in a
  /// per-thread cache (the cache fields are marked <see cref="ThreadStaticAttribute"/>) so that the
  /// native code can read them through raw pointers. Only the op codes listed in
  /// <c>supportedOpCodes</c> are accepted by <c>GetSymbolicExpressionTreeValues</c>.
  /// </remarks>
  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
  [Item("NativeInterpreter", "Operator calling into native C++ code for tree interpretation.")]
  public class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    #region parameters
    /// <summary>Parameter holding the counter of evaluated solutions.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Counter incremented after every successful evaluation; reset to 0 by <see cref="ClearState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    public NativeInterpreter() {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    [StorableConstructor]
    protected NativeInterpreter(StorableConstructorFlag _) : base(_) { }

    protected NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NativeInterpreter(this, cloner);
    }

    /// <summary>
    /// Compiles the expression of <paramref name="tree"/> into native instructions.
    /// </summary>
    /// <remarks>
    /// Compilation starts at <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>, presumably skipping the
    /// program-root and start-symbol wrapper nodes of the tree.
    /// </remarks>
    /// <param name="tree">The tree to compile.</param>
    /// <param name="dataset">Dataset whose double variables are referenced by variable nodes.</param>
    /// <param name="opCodeMapper">Maps each node to its native op code.</param>
    /// <param name="nodes">Receives the compiled nodes, in the same order as the returned instructions.</param>
    /// <returns>One instruction per node of the compiled subtree.</returns>
    public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      return Compile(root, dataset, opCodeMapper, out nodes);
    }

    /// <summary>
    /// Compiles the subtree rooted at <paramref name="root"/> into native instructions.
    /// </summary>
    /// <remarks>
    /// Instructions are emitted in reversed prefix order, so every node appears after all nodes of
    /// its subtrees. Variable instructions carry a pointer (<c>Data</c>) into the calling thread's
    /// pinned column cache; that pointer becomes invalid once this thread's cache is rebuilt for
    /// another dataset or released via <see cref="ClearState"/>.
    /// </remarks>
    /// <param name="root">Root of the subtree to compile.</param>
    /// <param name="dataset">Dataset whose double variables are referenced by variable nodes.</param>
    /// <param name="opCodeMapper">Maps each node to its native op code; may throw to reject a node.</param>
    /// <param name="nodes">Receives the compiled nodes, in the same order as the returned instructions.</param>
    /// <returns>One instruction per node of the subtree.</returns>
    /// <exception cref="KeyNotFoundException">
    /// A variable node references a name that is not a double variable of <paramref name="dataset"/>.
    /// </exception>
    public static NativeInstruction[] Compile(ISymbolicExpressionTreeNode root, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
      // (re)build this thread's column cache when it is empty or was built for another dataset
      if (cachedData == null || cachedDataset != dataset) {
        InitCache(dataset);
      }

      nodes = root.IterateNodesPrefix().ToList();
      nodes.Reverse();
      var code = new NativeInstruction[nodes.Count];

      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];
        code[i] = new NativeInstruction { Arity = (ushort)node.SubtreeCount, OpCode = opCodeMapper(node), Length = (ushort)node.GetLength(), Optimize = true };

        if (node is VariableTreeNode variable) {
          code[i].Value = variable.Weight;
          if (!cachedData.TryGetValue(variable.VariableName, out var handle)) {
            throw new KeyNotFoundException($"The dataset does not contain a double variable named '{variable.VariableName}'.");
          }
          code[i].Data = handle.AddrOfPinnedObject();
        } else if (node is ConstantTreeNode constant) {
          code[i].Value = constant.Value;
        }
      }
      return code;
    }

    // guards the EvaluatedSolutions counter, which may be incremented from several threads
    private readonly object syncRoot = new object();

    // per-thread cache: variable name -> handle of the pinned copy of that variable's values
    [ThreadStatic]
    private static Dictionary<string, GCHandle> cachedData;

    // dataset from which cachedData was built (compared by reference in Compile)
    [ThreadStatic]
    private static IDataset cachedDataset;

    protected static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
      (byte)OpCode.Constant,
      (byte)OpCode.Variable,
      (byte)OpCode.Add,
      (byte)OpCode.Sub,
      (byte)OpCode.Mul,
      (byte)OpCode.Div,
      (byte)OpCode.Exp,
      (byte)OpCode.Log,
      (byte)OpCode.Sin,
      (byte)OpCode.Cos,
      (byte)OpCode.Tan,
      (byte)OpCode.Tanh,
      (byte)OpCode.Power,
      (byte)OpCode.Root,
      (byte)OpCode.SquareRoot,
      (byte)OpCode.Square,
      (byte)OpCode.CubeRoot,
      (byte)OpCode.Cube,
      (byte)OpCode.Absolute,
      (byte)OpCode.AnalyticQuotient
    };

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given rows; materializes <paramref name="rows"/>
    /// and delegates to the array overload.
    /// </summary>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
    }

    /// <summary>
    /// Replaces the calling thread's column cache with pinned copies of all double variables of
    /// <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// The previously pinned arrays are released first. The new cache is published only after all
    /// columns were copied. A failure partway therefore unpins what was allocated and leaves the
    /// cache empty, instead of leaving a partially filled cache that is treated as valid for
    /// <paramref name="dataset"/>.
    /// </remarks>
    private static void InitCache(IDataset dataset) {
      FreeCache();

      var data = new Dictionary<string, GCHandle>();
      try {
        foreach (var variableName in dataset.DoubleVariables) {
          var values = dataset.GetDoubleValues(variableName).ToArray();
          data[variableName] = GCHandle.Alloc(values, GCHandleType.Pinned);
        }
      } catch {
        // unpin whatever was allocated before the failure, then propagate the original error
        foreach (var handle in data.Values) {
          handle.Free();
        }
        throw;
      }

      cachedData = data;
      cachedDataset = dataset;
    }

    /// <summary>Unpins all arrays of the calling thread's column cache and resets the cache.</summary>
    private static void FreeCache() {
      if (cachedData != null) {
        foreach (var handle in cachedData.Values) {
          handle.Free();
        }
      }
      cachedData = null;
      cachedDataset = null;
    }

    /// <summary>
    /// Releases the calling thread's pinned column cache and resets <see cref="EvaluatedSolutions"/> to 0.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the cache fields are [ThreadStatic], so caches that were built on other threads
    /// are not released by this call.
    /// </remarks>
    public void ClearState() {
      FreeCache();
      EvaluatedSolutions = 0;
    }

    /// <summary>Resets the interpreter state; equivalent to <see cref="ClearState"/>.</summary>
    public void InitializeState() {
      ClearState();
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given <paramref name="rows"/> of
    /// <paramref name="dataset"/> using native code.
    /// </summary>
    /// <returns>
    /// An array with <c>rows.Length</c> values filled by the native evaluator, or an empty sequence
    /// when <paramref name="rows"/> is empty.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="rows"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">A row index lies outside the dataset.</exception>
    /// <exception cref="NotSupportedException">The tree contains a symbol the native interpreter does not support.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      if (rows == null) throw new ArgumentNullException(nameof(rows));
      if (rows.Length == 0) return Enumerable.Empty<double>();

      // The native code reads the pinned columns through raw pointers; presumably without bounds
      // checks (TODO confirm), so out-of-range rows are rejected here instead of reading foreign memory.
      // NOTE(review): the cache is keyed by dataset reference, so a dataset that grew after caching
      // is not detected here.
      int rowCount = dataset.Rows;
      foreach (var row in rows) {
        if (row < 0 || row >= rowCount) {
          throw new ArgumentOutOfRangeException(nameof(rows), row, $"Row index {row} is outside the dataset (valid range 0..{rowCount - 1}).");
        }
      }

      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
        var opCode = OpCodes.MapSymbolToOpCode(node);
        if (supportedOpCodes.Contains(opCode)) return opCode;
        throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
      }
      var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> _);

      var result = new double[rows.Length];
      var options = new SolverOptions { /* not using any options here */ };

      // the optimization summary is not used here
      NativeWrapper.GetValues(code, rows, options, result, target: null, out OptimizationSummary _);

      // when evaluation took place without any error, we can increment the counter
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      return result;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.