Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeNativeInterpreter.cs

Last change on this file was 18220, checked in by gkronber, 2 years ago

#3136: reintegrated structure-template GP branch into trunk

File size: 6.8 KB
Line 
1#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HEAL.Attic;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Symbolic expression tree interpreter that compiles a tree into a flat array of
  /// <c>NativeInstruction</c>s and evaluates it through <c>NativeWrapper.GetValuesVectorized</c>.
  /// Only the symbols whose opcodes are listed in <see cref="supportedOpCodes"/> are accepted.
  /// </summary>
  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
  [Item("SymbolicDataAnalysisExpressionTreeNativeInterpreter", "An interpreter that wraps a native dll")]
  public class SymbolicDataAnalysisExpressionTreeNativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    #region parameters
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Number of successful calls to <see cref="GetSymbolicExpressionTreeValues"/> since the last <see cref="InitializeState"/>.</summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    public void ClearState() { }

    public SymbolicDataAnalysisExpressionTreeNativeInterpreter() {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(StorableConstructorFlag _) : base(_) { }

    protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(SymbolicDataAnalysisExpressionTreeNativeInterpreter original, Cloner cloner) : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeNativeInterpreter(this, cloner);
    }

    /// <summary>
    /// Flattens the subtree below the start symbol into an instruction array in breadth-first order.
    /// </summary>
    /// <remarks>
    /// Instruction 0 is the first child of the start symbol. Because nodes are emitted breadth-first,
    /// the children of instruction <c>i</c> are stored contiguously starting at <c>code[i].childIndex</c>
    /// (<c>code[i].narg</c> of them). Variable instructions receive a pointer to the pinned column in
    /// <see cref="cachedData"/>, so the cache must be initialized for the current dataset beforehand.
    /// </remarks>
    /// <param name="tree">The tree to compile.</param>
    /// <param name="opCodeMapper">Maps each node to its opcode; may throw for unsupported symbols.</param>
    /// <returns>One instruction per node of the compiled subtree.</returns>
    /// <exception cref="ArgumentException">
    /// A node has more than <see cref="ushort.MaxValue"/> subtrees, or a variable is not a double variable of the cached dataset.
    /// </exception>
    private NativeInstruction[] Compile(ISymbolicExpressionTree tree, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var code = new NativeInstruction[root.GetLength()];
      code[0] = CreateInstruction(root, opCodeMapper);
      int c = 1, i = 0;
      foreach (var node in root.IterateNodesBreadth()) {
        // reserve the slots of this node's children right after the ones already emitted
        for (int j = 0; j < node.SubtreeCount; ++j) {
          code[c + j] = CreateInstruction(node.GetSubtree(j), opCodeMapper);
        }

        if (node is VariableTreeNode variable) {
          if (!cachedData.TryGetValue(variable.VariableName, out var handle)) {
            throw new ArgumentException($"The native interpreter only supports double variables, but variable {variable.VariableName} is not a double variable of the dataset.");
          }
          code[i].weight = variable.Weight;
          code[i].data = handle.AddrOfPinnedObject();
        } else if (node is INumericTreeNode numeric) {
          code[i].value = numeric.Value;
        }

        code[i].childIndex = c;
        c += node.SubtreeCount;
        ++i;
      }
      return code;
    }

    /// <summary>Creates the instruction header (argument count and opcode) for a single node.</summary>
    /// <exception cref="ArgumentException">The node has more subtrees than fit into the <c>ushort</c> argument count.</exception>
    private static NativeInstruction CreateInstruction(ISymbolicExpressionTreeNode node, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
      if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
      return new NativeInstruction { narg = (ushort)node.SubtreeCount, opcode = opCodeMapper(node) };
    }

    // guards the EvaluatedSolutions counter when several threads evaluate with the same instance
    private readonly object syncRoot = new object();

    // Per-thread pinned copies of the double columns of cachedDataset, keyed by variable name.
    // NOTE(review): handles are only released by InitCache/InitializeState on the owning thread;
    // handles held by threads that terminate are never freed.
    [ThreadStatic]
    private static Dictionary<string, GCHandle> cachedData;

    // the dataset whose columns are currently stored in cachedData (per thread)
    [ThreadStatic]
    private static IDataset cachedDataset;

    // opcodes the native dll can evaluate; any other symbol is rejected before native code is called
    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
      (byte)OpCode.Constant,
      (byte)OpCode.Number,
      (byte)OpCode.Variable,
      (byte)OpCode.Add,
      (byte)OpCode.Sub,
      (byte)OpCode.Mul,
      (byte)OpCode.Div,
      (byte)OpCode.Exp,
      (byte)OpCode.Log,
      (byte)OpCode.Sin,
      (byte)OpCode.Cos,
      (byte)OpCode.Tan,
      (byte)OpCode.Tanh,
      (byte)OpCode.Power,
      (byte)OpCode.Root,
      (byte)OpCode.SquareRoot,
      (byte)OpCode.Square,
      (byte)OpCode.CubeRoot,
      (byte)OpCode.Cube,
      (byte)OpCode.Absolute,
      (byte)OpCode.AnalyticQuotient,
      (byte)OpCode.SubFunction
    };

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given rows of <paramref name="dataset"/> using the native dll.
    /// </summary>
    /// <param name="tree">The tree to evaluate; all symbols must map to an opcode in <see cref="supportedOpCodes"/>.</param>
    /// <param name="dataset">The dataset providing variable values.</param>
    /// <param name="rows">Row indices to evaluate; enumerated exactly once.</param>
    /// <returns>One value per row, in the order of <paramref name="rows"/>; empty if no rows are given.</returns>
    /// <exception cref="ArgumentOutOfRangeException">A row index is negative or not smaller than <c>dataset.Rows</c>.</exception>
    /// <exception cref="NotSupportedException">The tree contains a symbol that the native interpreter does not support.</exception>
    /// <exception cref="ArgumentException">The tree is too wide or references a variable that is not a double variable of the dataset.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      // materialize once: the rows are needed for the emptiness check, the validation and the native call
      var rowsArray = rows.ToArray();
      if (rowsArray.Length == 0) return Enumerable.Empty<double>();

      // The native code receives raw pointers into the pinned column arrays and presumably performs no
      // bounds checks, so reject invalid rows here instead of risking reads outside of the arrays.
      int datasetRows = dataset.Rows;
      foreach (var row in rowsArray) {
        if (row < 0 || row >= datasetRows) {
          throw new ArgumentOutOfRangeException(nameof(rows), $"Row index {row} is outside of the dataset (0..{datasetRows - 1}).");
        }
      }

      // a ModifiableDataset is always re-cached (presumably because its values can change in place)
      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
        InitCache(dataset);
      }

      byte MapSupportedSymbol(ISymbolicExpressionTreeNode node) {
        var opCode = OpCodes.MapSymbolToOpCode(node);
        if (supportedOpCodes.Contains(opCode)) return opCode;
        throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
      }
      var code = Compile(tree, MapSupportedSymbol);

      var result = new double[rowsArray.Length];
      NativeWrapper.GetValuesVectorized(code, code.Length, rowsArray, rowsArray.Length, result);

      // when evaluation took place without any error, we can increment the counter
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      return result;
    }

    /// <summary>
    /// Replaces the calling thread's cache with pinned copies of all double columns of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// The new cache is built in a local dictionary and published only after every column has been pinned.
    /// If copying a column throws, the handles allocated so far are freed and the cache stays empty, so the
    /// next call rebuilds it instead of reusing a partial cache.
    /// </remarks>
    private static void InitCache(IDataset dataset) {
      FreeCachedData();

      var data = new Dictionary<string, GCHandle>();
      try {
        foreach (var v in dataset.DoubleVariables) {
          var values = dataset.GetDoubleValues(v).ToArray();
          data[v] = GCHandle.Alloc(values, GCHandleType.Pinned);
        }
      } catch {
        foreach (var gch in data.Values) {
          gch.Free();
        }
        throw;
      }

      cachedData = data;
      cachedDataset = dataset;
    }

    /// <summary>Frees all pinned handles of the calling thread's cache and forgets the cached dataset.</summary>
    private static void FreeCachedData() {
      if (cachedData != null) {
        foreach (var gch in cachedData.Values) {
          gch.Free();
        }
        cachedData = null;
      }
      cachedDataset = null;
    }

    /// <summary>
    /// Releases the calling thread's cached data and resets <see cref="EvaluatedSolutions"/> to zero.
    /// </summary>
    public void InitializeState() {
      FreeCachedData();
      EvaluatedSolutions = 0;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.