Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeNativeInterpreter.cs @ 17989

Last change on this file since 17989 was 17989, checked in by gkronber, 3 years ago

#3087: updated the native DLLs for NativeInterpreter to a version that runs on the Hive infrastructure. Made some smaller changes to account for deviations between the independently developed implementations (in particular, in the enum types).

File size: 6.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.NativeInterpreter;
31using HeuristicLab.Parameters;
32using HEAL.Attic;
33
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Symbolic expression tree interpreter that compiles a tree into an array of
  /// <see cref="NativeInstruction"/>s and evaluates it through <see cref="NativeWrapper"/>.
  /// </summary>
  /// <remarks>
  /// The double columns of the most recently used dataset are copied and pinned in a
  /// [ThreadStatic] cache, so each thread keeps its own copy; the native code receives raw
  /// pointers to these pinned arrays. A <see cref="ModifiableDataset"/> is re-cached on every
  /// call. Only the symbols whose op codes are listed in <c>supportedOpCodes</c> are accepted.
  /// </remarks>
  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
  [Item("SymbolicDataAnalysisExpressionTreeNativeInterpreter", "An interpreter that wraps a native dll")]
  public class SymbolicDataAnalysisExpressionTreeNativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    #region parameters
    /// <summary>Parameter that stores the evaluated-solutions counter.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Number of evaluations that completed the native call without an exception since the
    /// last call to <see cref="InitializeState"/> (which resets it to 0).
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>
    /// Does nothing. The per-thread data cache is only released by <see cref="InitializeState"/>.
    /// </summary>
    public void ClearState() { }

    public SymbolicDataAnalysisExpressionTreeNativeInterpreter() {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(StorableConstructorFlag _) : base(_) { }

    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(SymbolicDataAnalysisExpressionTreeNativeInterpreter original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeNativeInterpreter(this, cloner);
    }

    /// <summary>
    /// Translates the expression below the tree's start symbol into native instructions.
    /// </summary>
    /// <param name="tree">The tree to compile; the expression is taken from Root -> first subtree -> first subtree.</param>
    /// <param name="opCodeMapper">Maps each node to its native op code (may throw for unsupported symbols).</param>
    /// <returns>
    /// The instructions in reversed prefix order. Every node therefore appears after all of its
    /// descendants, and its children are laid out right-to-left directly before it.
    /// <c>Length</c> holds the size of the subtree rooted at each instruction.
    /// </returns>
    /// <exception cref="ArgumentException">A node has more than <see cref="ushort.MaxValue"/> subtrees.</exception>
    /// <exception cref="KeyNotFoundException">A variable is not a cached double variable of the current dataset.</exception>
    private NativeInstruction[] Compile(ISymbolicExpressionTree tree, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var code = new NativeInstruction[root.GetLength()];
      int i = code.Length - 1;
      foreach (var n in root.IterateNodesPrefix()) {
        // Arity is stored as ushort: check every node (not only the root), otherwise the cast silently truncates.
        if (n.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
        code[i] = new NativeInstruction { Arity = (ushort)n.SubtreeCount, OpCode = opCodeMapper(n), Length = 1, Optimize = false };
        if (n is VariableTreeNode variable) {
          code[i].Value = variable.Weight;
          if (!cachedData.TryGetValue(variable.VariableName, out var handle)) {
            throw new KeyNotFoundException($"The dataset does not contain a double variable named '{variable.VariableName}'.");
          }
          // pointer to the pinned column array; it stays valid until the handle is freed
          code[i].Data = handle.AddrOfPinnedObject();
        } else if (n is ConstantTreeNode constant) {
          code[i].Value = constant.Value;
        }
        --i;
      }
      // second pass: subtree length = 1 + sum of the children's lengths; the children of
      // code[i] occupy the slots directly before i, walked from the nearest child backwards
      for (i = 0; i < code.Length; i++) {
        var c = i - 1;
        for (int j = 0; j < code[i].Arity; ++j) {
          code[i].Length += code[c].Length;
          c -= code[c].Length;
        }
      }

      return code;
    }

    // guards the EvaluatedSolutions increment, which may run on several threads concurrently
    private readonly object syncRoot = new object();

    // pinned copies of the double columns of cachedDataset, keyed by variable name (one cache per thread)
    [ThreadStatic]
    private static Dictionary<string, GCHandle> cachedData;

    // dataset whose columns are currently held in cachedData on this thread
    [ThreadStatic]
    private static IDataset cachedDataset;

    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
      (byte)OpCode.Constant,
      (byte)OpCode.Variable,
      (byte)OpCode.Add,
      (byte)OpCode.Sub,
      (byte)OpCode.Mul,
      (byte)OpCode.Div,
      (byte)OpCode.Exp,
      (byte)OpCode.Log,
      (byte)OpCode.Sin,
      (byte)OpCode.Cos,
      (byte)OpCode.Tan,
      (byte)OpCode.Tanh,
      // Power and Root are deliberately excluded (NOTE(review): presumably not yet supported natively — confirm)
      // (byte)OpCode.Power,
      // (byte)OpCode.Root,
      (byte)OpCode.SquareRoot,
      (byte)OpCode.Square,
      (byte)OpCode.CubeRoot,
      (byte)OpCode.Cube,
      (byte)OpCode.Absolute,
      (byte)OpCode.AnalyticQuotient
    };

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given rows of <paramref name="dataset"/> using the native library.
    /// </summary>
    /// <param name="tree">Tree to evaluate.</param>
    /// <param name="dataset">Dataset providing the double variables referenced by the tree.</param>
    /// <param name="rows">Row indices to evaluate; enumerated exactly once.</param>
    /// <returns>One value per row, in the order of <paramref name="rows"/>; empty if there are no rows.</returns>
    /// <exception cref="NotSupportedException">The tree contains a symbol the native interpreter does not support.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      // materialize once: rows may be a lazily evaluated sequence
      var rowsArray = rows.ToArray();
      if (rowsArray.Length == 0) return Enumerable.Empty<double>();

      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
        InitCache(dataset);
      }

      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
        var opCode = OpCodes.MapSymbolToOpCode(node);
        if (supportedOpCodes.Contains(opCode)) return opCode;
        else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
      }
      var code = Compile(tree, mapSupportedSymbols);

      var result = new double[rowsArray.Length];
      // prevent optimization of parameters
      var options = new SolverOptions {
        Iterations = 0
      };
      NativeWrapper.GetValues(code, rowsArray, options, result, target: null, optSummary: out _); // target is only used when optimizing parameters

      // when evaluation took place without any error, we can increment the counter
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      return result;
    }

    /// <summary>
    /// Replaces this thread's cache with pinned copies of all double variables of <paramref name="dataset"/>.
    /// The cache is published only after every column has been pinned successfully; on failure the
    /// partially pinned handles are freed and the cache stays empty.
    /// </summary>
    private void InitCache(IDataset dataset) {
      // free handles to old data
      FreeCachedData();

      // cache new data
      var data = new Dictionary<string, GCHandle>();
      try {
        foreach (var v in dataset.DoubleVariables) {
          var values = dataset.GetDoubleValues(v).ToArray();
          data[v] = GCHandle.Alloc(values, GCHandleType.Pinned);
        }
      } catch {
        // do not leak pinned arrays if reading or pinning a column fails
        foreach (var gch in data.Values) {
          gch.Free();
        }
        throw;
      }

      cachedData = data;
      cachedDataset = dataset;
    }

    /// <summary>
    /// Frees this thread's pinned data, forgets the cached dataset and resets <see cref="EvaluatedSolutions"/> to 0.
    /// Caches held by other threads are not affected.
    /// </summary>
    public void InitializeState() {
      FreeCachedData();
      EvaluatedSolutions = 0;
    }

    /// <summary>Frees all pinned handles of the current thread's cache and clears the cache references.</summary>
    private static void FreeCachedData() {
      if (cachedData != null) {
        foreach (var gch in cachedData.Values) {
          gch.Free();
        }
        cachedData = null;
      }
      cachedDataset = null;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.