Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/NativeInterpreter.cs @ 17853

Last change on this file since 17853 was 17853, checked in by bburlacu, 3 years ago

#3087: Add accidentally omitted files.

File size: 6.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26
27using HEAL.Attic;
28
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
33using HeuristicLab.Parameters;
34using HeuristicLab.Problems.DataAnalysis;
35
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter for symbolic expression trees that compiles a tree into a flat array of
  /// <see cref="NativeInstruction"/>s and evaluates it through <c>NativeWrapper.GetValues</c>,
  /// i.e. in native C++ code.
  /// </summary>
  /// <remarks>
  /// The double-valued columns of the most recently used dataset are copied into pinned arrays that are cached
  /// per thread (the cache fields are marked <see cref="ThreadStaticAttribute"/>), so that compiled variable
  /// instructions can refer to the column data by address.
  /// NOTE(review): <see cref="ClearState"/> only releases the pinned arrays of the calling thread; caches created
  /// on other threads remain pinned until they are released on their own thread.
  /// </remarks>
  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
  [Item("NativeInterpreter", "Operator calling into native C++ code for tree interpretation.")]
  public class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    #region parameters
    /// <summary>Parameter holding the counter of evaluated solutions.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Number of trees evaluated without error by <see cref="GetSymbolicExpressionTreeValues(ISymbolicExpressionTree, IDataset, int[])"/>;
    /// reset to 0 by <see cref="ClearState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    public NativeInterpreter() {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    [StorableConstructor]
    protected NativeInterpreter(StorableConstructorFlag _) : base(_) { }

    protected NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NativeInterpreter(this, cloner);
    }

    /// <summary>
    /// Compiles the expression of <paramref name="tree"/>, i.e. the first child of the root's first child
    /// (presumably skipping the ProgramRoot and StartSymbol wrapper nodes of HeuristicLab trees).
    /// </summary>
    /// <param name="tree">Tree whose expression is compiled.</param>
    /// <param name="dataset">Dataset providing the variable values.</param>
    /// <param name="opCodeMapper">Maps each node to the opcode of the emitted instruction.</param>
    /// <param name="nodes">The compiled nodes, in the same order as the returned instructions.</param>
    /// <returns>One instruction per compiled node.</returns>
    public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      return Compile(root, dataset, opCodeMapper, out nodes);
    }

    /// <summary>
    /// Compiles the subtree rooted at <paramref name="root"/> into native instructions.
    /// </summary>
    /// <remarks>
    /// The <c>Data</c> pointers of variable instructions point into the calling thread's pinned column cache and
    /// become invalid once that cache is released (by <see cref="ClearState"/>, or by a later call on the same
    /// thread with a different dataset), so the returned code must be used before that happens.
    /// </remarks>
    /// <param name="root">Root node of the expression to compile.</param>
    /// <param name="dataset">Dataset providing the variable values; its double-valued columns are cached (pinned) for the calling thread.</param>
    /// <param name="opCodeMapper">
    /// Maps each node to the opcode of the emitted instruction (the mapper used by
    /// <see cref="GetSymbolicExpressionTreeValues(ISymbolicExpressionTree, IDataset, int[])"/> throws
    /// <see cref="NotSupportedException"/> for unsupported symbols).
    /// </param>
    /// <param name="nodes">The compiled nodes, in the same order as the returned instructions (reversed prefix order).</param>
    /// <returns>One instruction per node; every node is placed after all nodes of its subtrees.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="dataset"/> is null.</exception>
    /// <exception cref="KeyNotFoundException">A variable node refers to a variable that is not a double-valued variable of <paramref name="dataset"/>.</exception>
    public static NativeInstruction[] Compile(ISymbolicExpressionTreeNode root, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));

      // the cache is keyed on dataset identity (reference comparison)
      if (cachedData == null || cachedDataset != dataset) {
        InitCache(dataset);
      }

      // reversing the prefix order places every node after all nodes of its subtrees
      nodes = root.IterateNodesPrefix().ToList();
      nodes.Reverse();
      var code = new NativeInstruction[nodes.Count];

      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];
        code[i] = new NativeInstruction { Arity = (ushort)node.SubtreeCount, OpCode = opCodeMapper(node), Length = (ushort)node.GetLength(), Optimize = true };

        if (node is VariableTreeNode variable) {
          code[i].Value = variable.Weight;
          code[i].Data = GetPinnedColumnAddress(variable.VariableName);
        } else if (node is ConstantTreeNode constant) {
          code[i].Value = constant.Value;
        }
      }
      return code;
    }

    // Returns the address of the pinned copy of a double-valued column in the calling thread's cache.
    private static IntPtr GetPinnedColumnAddress(string variableName) {
      if (!cachedData.TryGetValue(variableName, out var handle)) {
        throw new KeyNotFoundException($"The native interpreter only supports double-valued variables, but \"{variableName}\" is not a double-valued variable of the dataset.");
      }
      return handle.AddrOfPinnedObject();
    }

    // Guards the increment of EvaluatedSolutions against concurrent evaluations on this instance.
    private readonly object syncRoot = new object();

    // Pinned copies of the double-valued columns of cachedDataset, keyed by variable name.
    // [ThreadStatic]: each thread has its own cache, so each thread must release its own handles.
    [ThreadStatic]
    private static Dictionary<string, GCHandle> cachedData;

    // The dataset whose columns are currently held in cachedData (compared by reference).
    [ThreadStatic]
    private static IDataset cachedDataset;

    // Opcodes the native code can evaluate; any other symbol is rejected before evaluation.
    protected static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
      (byte)OpCode.Constant,
      (byte)OpCode.Variable,
      (byte)OpCode.Add,
      (byte)OpCode.Sub,
      (byte)OpCode.Mul,
      (byte)OpCode.Div,
      (byte)OpCode.Exp,
      (byte)OpCode.Log,
      (byte)OpCode.Sin,
      (byte)OpCode.Cos,
      (byte)OpCode.Tan,
      (byte)OpCode.Tanh,
      (byte)OpCode.Power,
      (byte)OpCode.Root,
      (byte)OpCode.SquareRoot,
      (byte)OpCode.Square,
      (byte)OpCode.CubeRoot,
      (byte)OpCode.Cube,
      (byte)OpCode.Absolute,
      (byte)OpCode.AnalyticQuotient
    };

    /// <summary>Evaluates <paramref name="tree"/> on the given rows of <paramref name="dataset"/>.</summary>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
    }

    // (Re)builds the calling thread's cache of pinned double-valued columns for dataset.
    // The previous cache is released first (keeping peak pinned memory low) and the new cache is only
    // published once it is complete, so a failure while copying columns cannot leave a half-populated
    // cache behind that later Compile calls would consider valid for this dataset.
    private static void InitCache(IDataset dataset) {
      FreeCachedData();

      var data = new Dictionary<string, GCHandle>();
      try {
        foreach (var v in dataset.DoubleVariables) {
          var values = dataset.GetDoubleValues(v).ToArray();
          data[v] = GCHandle.Alloc(values, GCHandleType.Pinned);
        }
      } catch {
        // release whatever was pinned before the failure; the cache stays empty
        foreach (var gch in data.Values) {
          gch.Free();
        }
        throw;
      }

      cachedData = data;
      cachedDataset = dataset;
    }

    // Frees all pinned handles of the calling thread's cache and forgets the cached dataset.
    private static void FreeCachedData() {
      if (cachedData != null) {
        foreach (var gch in cachedData.Values) {
          gch.Free();
        }
        cachedData = null;
      }
      cachedDataset = null;
    }

    /// <summary>
    /// Releases the calling thread's pinned column cache and resets <see cref="EvaluatedSolutions"/> to 0.
    /// </summary>
    public void ClearState() {
      FreeCachedData();
      EvaluatedSolutions = 0;
    }

    /// <summary>Resets the interpreter state; equivalent to <see cref="ClearState"/>.</summary>
    public void InitializeState() {
      ClearState();
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given rows of <paramref name="dataset"/> in native code.
    /// </summary>
    /// <param name="tree">Tree to evaluate.</param>
    /// <param name="dataset">Dataset providing the variable values.</param>
    /// <param name="rows">Row indices to evaluate; the result has one value per entry.</param>
    /// <returns>The evaluated values, or an empty sequence when <paramref name="rows"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="rows"/> is null.</exception>
    /// <exception cref="NotSupportedException">The tree contains a symbol the native interpreter does not support.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      if (rows == null) throw new ArgumentNullException(nameof(rows));
      if (rows.Length == 0) return Enumerable.Empty<double>();

      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
        var opCode = OpCodes.MapSymbolToOpCode(node);
        if (supportedOpCodes.Contains(opCode)) return opCode;
        throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
      }

      var code = Compile(tree, dataset, mapSupportedSymbols, out _);

      var result = new double[rows.Length];
      var options = new SolverOptions { /* not using any options here */ };

      var summary = new OptimizationSummary(); // also not used
      NativeWrapper.GetValues(code, rows, result, null, options, ref summary);

      // only reached when evaluation took place without any error, so the counter can be incremented
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      return result;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.