Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/NativeInterpreter.cs @ 18007

Last change on this file since 18007 was 18007, checked in by gkronber, 3 years ago

#3087: removed "strings-enums" in ParameterOptimizer and do not derive ParameterOptimizer from NativeInterpreter (+ renamed enum types in CeresTypes)

File size: 7.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26
27using HEAL.Attic;
28
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
33using HeuristicLab.NativeInterpreter;
34using HeuristicLab.Parameters;
35using HeuristicLab.Problems.DataAnalysis;
36
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Symbolic expression tree interpreter that compiles a tree into an array of <see cref="NativeInstruction"/>
  /// and evaluates it by calling <c>NativeWrapper.GetValues</c> (native C++ code).
  /// </summary>
  /// <remarks>
  /// The double columns of the evaluated dataset are copied into pinned arrays. These are cached in
  /// <see cref="ThreadStaticAttribute"/> fields, so each thread holds its own cache. The cache is reused while
  /// the same dataset instance is evaluated, unless it is a <see cref="ModifiableDataset"/>, which is re-copied on every compile.
  /// </remarks>
  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
  [Item("NativeInterpreter", "Operator calling into native C++ code for tree interpretation.")]
  public sealed class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    #region parameters
    /// <summary>Parameter that holds the counter of evaluated solutions.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Number of trees successfully evaluated by this interpreter. It is reset to 0 by <see cref="ClearState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    public NativeInterpreter() {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    #region storable ctor and cloning
    [StorableConstructor]
    private NativeInterpreter(StorableConstructorFlag _) : base(_) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new NativeInterpreter(this, cloner);
    }

    private NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) { }
    #endregion

    /// <summary>
    /// Compiles the model part of <paramref name="tree"/> into native instructions.
    /// </summary>
    /// <param name="tree">The tree. Compilation starts at <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>.</param>
    /// <param name="dataset">Dataset whose double columns are referenced by variable instructions.</param>
    /// <param name="opCodeMapper">Maps each tree node to the op code stored in its instruction.</param>
    /// <param name="nodes">The compiled nodes, in the same order as the returned instructions.</param>
    /// <returns>The instructions; the last element corresponds to the compiled root.</returns>
    public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
      // NOTE(review): presumably skips the ProgramRoot and StartSymbol nodes to reach the actual model.
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      return Compile(root, dataset, opCodeMapper, out nodes);
    }

    /// <summary>
    /// Compiles the subtree rooted at <paramref name="root"/> into native instructions.
    /// </summary>
    /// <param name="root">Root node of the subtree to compile.</param>
    /// <param name="dataset">Dataset whose double columns are referenced by variable instructions.</param>
    /// <param name="opCodeMapper">Maps each tree node to the op code stored in its instruction.</param>
    /// <param name="nodes">
    /// The nodes in reversed prefix order, so every node comes after all nodes of its subtree.
    /// <c>nodes[k]</c> is the node compiled into <c>code[k]</c>.
    /// </param>
    /// <returns>The instructions; the last element corresponds to <paramref name="root"/>.</returns>
    /// <exception cref="ArgumentException">A node has more subtrees than fit into a <see cref="ushort"/>.</exception>
    /// <exception cref="KeyNotFoundException">A variable node references a name that is not a double variable of <paramref name="dataset"/>.</exception>
    public static NativeInstruction[] Compile(ISymbolicExpressionTreeNode root, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
      // A ModifiableDataset's values can change between calls, so its columns are re-copied every time.
      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
        InitCache(dataset);
      }

      nodes = root.IterateNodesPrefix().ToList();
      nodes.Reverse();
      var code = new NativeInstruction[nodes.Count];

      // First pass: walk from the last index down to 0, i.e. in prefix order starting at the root.
      // This calls opCodeMapper in the same order as a direct prefix enumeration would.
      for (int i = code.Length - 1; i >= 0; --i) {
        var n = nodes[i];
        // Arity is stored as ushort, so every node's subtree count (not only the root's) must fit.
        if (n.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
        code[i] = new NativeInstruction { Arity = (ushort)n.SubtreeCount, OpCode = opCodeMapper(n), Length = 1, Optimize = false };
        if (n is VariableTreeNode variable) {
          code[i].Value = variable.Weight;
          if (!cachedData.TryGetValue(variable.VariableName, out var handle)) {
            throw new KeyNotFoundException($"Variable '{variable.VariableName}' is not a double variable of the dataset.");
          }
          code[i].Data = handle.AddrOfPinnedObject();
        } else if (n is ConstantTreeNode constant) {
          code[i].Value = constant.Value;
        }
      }

      // Second pass: compute each instruction's length (instruction plus all instructions of its subtree).
      // The children of code[i] end at i - 1. Each further child ends directly before the previous child's span.
      for (int i = 0; i < code.Length; i++) {
        var c = i - 1;
        for (int j = 0; j < code[i].Arity; ++j) {
          code[i].Length += code[c].Length;
          c -= code[c].Length;
        }
      }

      return code;
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given <paramref name="rows"/> of <paramref name="dataset"/>.
    /// </summary>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given <paramref name="rows"/> of <paramref name="dataset"/> using native code.
    /// </summary>
    /// <returns>One value per entry of <paramref name="rows"/>, or an empty sequence if <paramref name="rows"/> is empty.</returns>
    /// <exception cref="NotSupportedException">The tree contains a symbol the native interpreter does not support.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
      if (rows.Length == 0) return Enumerable.Empty<double>();

      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
        var opCode = OpCodes.MapSymbolToOpCode(node);
        if (supportedOpCodes.Contains(opCode)) return opCode;
        else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
      }
      var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);

      var result = new double[rows.Length];
      var options = new SolverOptions { Iterations = 0 }; // Evaluate only. Do not optimize.

      NativeWrapper.GetValues(code, rows, options, result, target: null, out _);

      // Evaluation completed without an exception, so count it. The lock protects the shared counter
      // when the same interpreter instance is used from several threads.
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      return result;
    }

    private readonly object syncRoot = new object();

    // Pinned copies of the double columns of cachedDataset, keyed by variable name (one cache per thread).
    [ThreadStatic]
    private static Dictionary<string, GCHandle> cachedData;

    // The dataset whose columns are currently held in cachedData (one per thread).
    [ThreadStatic]
    private static IDataset cachedDataset;

    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
      (byte)OpCode.Constant,
      (byte)OpCode.Variable,
      (byte)OpCode.Add,
      (byte)OpCode.Sub,
      (byte)OpCode.Mul,
      (byte)OpCode.Div,
      (byte)OpCode.Exp,
      (byte)OpCode.Log,
      (byte)OpCode.Sin,
      (byte)OpCode.Cos,
      (byte)OpCode.Tan,
      (byte)OpCode.Tanh,
      // (byte)OpCode.Power, // these symbols are handled differently in the NativeInterpreter than in HL
      // (byte)OpCode.Root,
      (byte)OpCode.SquareRoot,
      (byte)OpCode.Square,
      (byte)OpCode.CubeRoot,
      (byte)OpCode.Cube,
      (byte)OpCode.Absolute,
      (byte)OpCode.AnalyticQuotient
    };

    /// <summary>
    /// Replaces the calling thread's cache with pinned copies of all double variables of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// The old cache is released and reset before the new one is built. The new cache is published only after
    /// every column has been copied. If reading the dataset throws, the thread is left with an empty cache
    /// rather than a partially filled one bound to <paramref name="dataset"/>.
    /// </remarks>
    private static void InitCache(IDataset dataset) {
      // free old data first (keeps peak pinned memory at one copy of the data)
      FreeHandles(cachedData);
      cachedData = null;
      cachedDataset = null;

      var newData = new Dictionary<string, GCHandle>();
      try {
        foreach (var v in dataset.DoubleVariables) {
          var values = dataset.GetDoubleValues(v).ToArray();
          newData[v] = GCHandle.Alloc(values, GCHandleType.Pinned);
        }
      } catch {
        // do not leak the handles pinned before the failure
        FreeHandles(newData);
        throw;
      }

      cachedData = newData;
      cachedDataset = dataset;
    }

    /// <summary>Frees all handles in <paramref name="handles"/>; does nothing for <c>null</c>.</summary>
    private static void FreeHandles(Dictionary<string, GCHandle> handles) {
      if (handles == null) return;
      foreach (var gch in handles.Values) {
        gch.Free();
      }
    }

    /// <summary>
    /// Releases the calling thread's cached dataset columns and resets <see cref="EvaluatedSolutions"/> to 0.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the cache fields are [ThreadStatic], so this only frees the cache of the calling thread.
    /// Caches created by other threads are not released here.
    /// </remarks>
    public void ClearState() {
      FreeHandles(cachedData);
      cachedData = null;
      cachedDataset = null;
      EvaluatedSolutions = 0;
    }

    /// <summary>Resets the interpreter's state; identical to <see cref="ClearState"/>.</summary>
    public void InitializeState() {
      ClearState();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.