Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3136_Structural_GP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeNativeInterpreter.cs @ 18146

Last change on this file since 18146 was 18146, checked in by mkommend, 2 years ago

#3136: Merged trunk changes into branch.

File size: 6.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HEAL.Attic;
32
33namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
34  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
35  [Item("SymbolicDataAnalysisExpressionTreeNativeInterpreter", "An interpreter that wraps a native dll")]
36  public class SymbolicDataAnalysisExpressionTreeNativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
37    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
38
39    #region parameters
40    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
41      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
42    }
43    #endregion
44
45    #region properties
46    public int EvaluatedSolutions {
47      get { return EvaluatedSolutionsParameter.Value.Value; }
48      set { EvaluatedSolutionsParameter.Value.Value = value; }
49    }
50    #endregion
51
52    public void ClearState() { }
53
54    public SymbolicDataAnalysisExpressionTreeNativeInterpreter() {
55      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
56    }
57
58    [StorableConstructor]
59    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(StorableConstructorFlag _) : base(_) { }
60
61    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(SymbolicDataAnalysisExpressionTreeNativeInterpreter original, Cloner cloner) : base(original, cloner) {
62    }
63
64    public override IDeepCloneable Clone(Cloner cloner) {
65      return new SymbolicDataAnalysisExpressionTreeNativeInterpreter(this, cloner);
66    }
67
68    private NativeInstruction[] Compile(ISymbolicExpressionTree tree, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
69      var root = tree.Root.GetSubtree(0).GetSubtree(0);
70      var code = new NativeInstruction[root.GetLength()];
71      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
72      code[0] = new NativeInstruction { narg = (ushort)root.SubtreeCount, opcode = opCodeMapper(root) };
73      int c = 1, i = 0;
74      foreach (var node in root.IterateNodesBreadth()) {
75        for (int j = 0; j < node.SubtreeCount; ++j) {
76          var s = node.GetSubtree(j);
77          if (s.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
78          code[c + j] = new NativeInstruction { narg = (ushort)s.SubtreeCount, opcode = opCodeMapper(s) };
79        }
80
81        if (node is VariableTreeNode variable) {
82          code[i].weight = variable.Weight;
83          code[i].data = cachedData[variable.VariableName].AddrOfPinnedObject();
84        } else if (node is INumericTreeNode numeric) {
85          code[i].value = numeric.Value;
86        }
87
88        code[i].childIndex = c;
89        c += node.SubtreeCount;
90        ++i;
91      }
92      return code;
93    }
94
95    private readonly object syncRoot = new object();
96
97    [ThreadStatic]
98    private static Dictionary<string, GCHandle> cachedData;
99
100    [ThreadStatic]
101    private static IDataset cachedDataset;
102
103    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
104      (byte)OpCode.Constant,
105      (byte)OpCode.Variable,
106      (byte)OpCode.Add,
107      (byte)OpCode.Sub,
108      (byte)OpCode.Mul,
109      (byte)OpCode.Div,
110      (byte)OpCode.Exp,
111      (byte)OpCode.Log,
112      (byte)OpCode.Sin,
113      (byte)OpCode.Cos,
114      (byte)OpCode.Tan,
115      (byte)OpCode.Tanh,
116      (byte)OpCode.Power,
117      (byte)OpCode.Root,
118      (byte)OpCode.SquareRoot,
119      (byte)OpCode.Square,
120      (byte)OpCode.CubeRoot,
121      (byte)OpCode.Cube,
122      (byte)OpCode.Absolute,
123      (byte)OpCode.AnalyticQuotient,
124      (byte)OpCode.SubFunction
125    };
126
127    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
128      if (!rows.Any()) return Enumerable.Empty<double>();
129
130      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
131        InitCache(dataset);
132      }
133
134      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {       
135        var opCode = OpCodes.MapSymbolToOpCode(node);
136        if (supportedOpCodes.Contains(opCode)) return opCode;
137        else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
138      };
139      var code = Compile(tree, mapSupportedSymbols);
140
141      var rowsArray = rows.ToArray();
142      var result = new double[rowsArray.Length];
143
144      NativeWrapper.GetValuesVectorized(code, code.Length, rowsArray, rowsArray.Length, result);
145
146      // when evaluation took place without any error, we can increment the counter
147      lock (syncRoot) {
148        EvaluatedSolutions++;
149      }
150
151      return result;
152    }
153
154    private void InitCache(IDataset dataset) {
155      cachedDataset = dataset;
156
157      // free handles to old data
158      if (cachedData != null) {
159        foreach (var gch in cachedData.Values) {
160          gch.Free();
161        }
162        cachedData = null;
163      }
164
165      // cache new data
166      cachedData = new Dictionary<string, GCHandle>();
167      foreach (var v in dataset.DoubleVariables) {
168        var values = dataset.GetDoubleValues(v).ToArray();
169        var gch = GCHandle.Alloc(values, GCHandleType.Pinned);
170        cachedData[v] = gch;
171      }
172    }
173
174    public void InitializeState() {
175      if (cachedData != null) {
176        foreach (var gch in cachedData.Values) {
177          gch.Free();
178        }
179        cachedData = null;
180      }
181      cachedDataset = null;
182      EvaluatedSolutions = 0;
183    }
184  }
185}
Note: See TracBrowser for help on using the repository browser.