Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeNativeInterpreter.cs @ 17801

Last change on this file since 17801 was 17801, checked in by gkronber, 4 years ago

#3084 updated interpreters to always invalidate their cached dataset when the cached dataset is a ModifiableDataset (as in the PDP)

File size: 6.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HEAL.Attic;
32
33namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
34  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
35  [Item("SymbolicDataAnalysisExpressionTreeNativeInterpreter", "An interpreter that wraps a native dll")]
36  public class SymbolicDataAnalysisExpressionTreeNativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
37    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
38
39    #region parameters
40    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
41      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
42    }
43    #endregion
44
45    #region properties
46    public int EvaluatedSolutions {
47      get { return EvaluatedSolutionsParameter.Value.Value; }
48      set { EvaluatedSolutionsParameter.Value.Value = value; }
49    }
50    #endregion
51
52    public void ClearState() { }
53
54    public SymbolicDataAnalysisExpressionTreeNativeInterpreter() {
55      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
56    }
57
58    [StorableConstructor]
59    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(StorableConstructorFlag _) : base(_) { }
60
61    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(SymbolicDataAnalysisExpressionTreeNativeInterpreter original, Cloner cloner) : base(original, cloner) {
62    }
63
64    public override IDeepCloneable Clone(Cloner cloner) {
65      return new SymbolicDataAnalysisExpressionTreeNativeInterpreter(this, cloner);
66    }
67
68    private NativeInstruction[] Compile(ISymbolicExpressionTree tree, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
69      var root = tree.Root.GetSubtree(0).GetSubtree(0);
70      var code = new NativeInstruction[root.GetLength()];
71      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
72      code[0] = new NativeInstruction { narg = (ushort)root.SubtreeCount, opcode = opCodeMapper(root) };
73      int c = 1, i = 0;
74      foreach (var node in root.IterateNodesBreadth()) {
75        for (int j = 0; j < node.SubtreeCount; ++j) {
76          var s = node.GetSubtree(j);
77          if (s.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
78          code[c + j] = new NativeInstruction { narg = (ushort)s.SubtreeCount, opcode = opCodeMapper(s) };
79        }
80
81        if (node is VariableTreeNode variable) {
82          code[i].weight = variable.Weight;
83          code[i].data = cachedData[variable.VariableName].AddrOfPinnedObject();
84        } else if (node is ConstantTreeNode constant) {
85          code[i].value = constant.Value;
86        }
87
88        code[i].childIndex = c;
89        c += node.SubtreeCount;
90        ++i;
91      }
92      return code;
93    }
94
95    private readonly object syncRoot = new object();
96
97    [ThreadStatic]
98    private static Dictionary<string, GCHandle> cachedData;
99
100    [ThreadStatic]
101    private static IDataset cachedDataset;
102
103    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
104      (byte)OpCode.Constant,
105      (byte)OpCode.Variable,
106      (byte)OpCode.Add,
107      (byte)OpCode.Sub,
108      (byte)OpCode.Mul,
109      (byte)OpCode.Div,
110      (byte)OpCode.Exp,
111      (byte)OpCode.Log,
112      (byte)OpCode.Sin,
113      (byte)OpCode.Cos,
114      (byte)OpCode.Tan,
115      (byte)OpCode.Tanh,
116      (byte)OpCode.Power,
117      (byte)OpCode.Root,
118      (byte)OpCode.SquareRoot,
119      (byte)OpCode.Square,
120      (byte)OpCode.CubeRoot,
121      (byte)OpCode.Cube,
122      (byte)OpCode.Absolute,
123      (byte)OpCode.AnalyticQuotient
124    };
125
126    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
127      if (!rows.Any()) return Enumerable.Empty<double>();
128
129      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
130        InitCache(dataset);
131      }
132
133      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {       
134        var opCode = OpCodes.MapSymbolToOpCode(node);
135        if (supportedOpCodes.Contains(opCode)) return opCode;
136        else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
137      };
138      var code = Compile(tree, mapSupportedSymbols);
139
140      var rowsArray = rows.ToArray();
141      var result = new double[rowsArray.Length];
142
143      NativeWrapper.GetValuesVectorized(code, code.Length, rowsArray, rowsArray.Length, result);
144
145      // when evaluation took place without any error, we can increment the counter
146      lock (syncRoot) {
147        EvaluatedSolutions++;
148      }
149
150      return result;
151    }
152
153    private void InitCache(IDataset dataset) {
154      cachedDataset = dataset;
155
156      // free handles to old data
157      if (cachedData != null) {
158        foreach (var gch in cachedData.Values) {
159          gch.Free();
160        }
161        cachedData = null;
162      }
163
164      // cache new data
165      cachedData = new Dictionary<string, GCHandle>();
166      foreach (var v in dataset.DoubleVariables) {
167        var values = dataset.GetDoubleValues(v).ToArray();
168        var gch = GCHandle.Alloc(values, GCHandleType.Pinned);
169        cachedData[v] = gch;
170      }
171    }
172
173    public void InitializeState() {
174      if (cachedData != null) {
175        foreach (var gch in cachedData.Values) {
176          gch.Free();
177        }
178        cachedData = null;
179      }
180      cachedDataset = null;
181      EvaluatedSolutions = 0;
182    }
183  }
184}
Note: See TracBrowser for help on using the repository browser.