Free cookie consent management tool by TermsFeed Policy Generator

source: branches/1772_HeuristicLab.EvolutionTracking/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeNativeInterpreter.cs @ 17874

Last change on this file since 17874 was 17434, checked in by bburlacu, 5 years ago

#1772: Merge trunk changes and fix all errors and compilation warnings.

File size: 7.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.InteropServices;
26using HEAL.Attic;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31using HeuristicLab.Parameters;
32
33namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Symbolic expression tree interpreter that compiles a tree into a flat array of
/// <see cref="NativeInstruction"/>s and evaluates it by calling into native C++ code
/// through <see cref="NativeWrapper"/>.
/// </summary>
/// <remarks>
/// The double columns of the dataset are copied into pinned arrays. The pinned handles are
/// cached in [ThreadStatic] fields, so each thread keeps its own cache. A thread reuses its
/// cache as long as it is called with the same dataset instance.
/// </remarks>
[StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
[Item("SymbolicDataAnalysisExpressionTreeNativeInterpreter", "Operator calling into native C++ code for tree interpretation.")]
public class SymbolicDataAnalysisExpressionTreeNativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
  private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

  #region parameters
  /// <summary>Fixed-value parameter that stores the evaluated-solutions counter.</summary>
  public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
    get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
  }
  #endregion

  #region properties
  /// <summary>
  /// Number of trees successfully evaluated by <see cref="GetSymbolicExpressionTreeValues"/>.
  /// <see cref="ClearState"/> resets it to 0.
  /// </summary>
  public int EvaluatedSolutions {
    get { return EvaluatedSolutionsParameter.Value.Value; }
    set { EvaluatedSolutionsParameter.Value.Value = value; }
  }
  #endregion

  public SymbolicDataAnalysisExpressionTreeNativeInterpreter() {
    Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
  }

  [StorableConstructor]
  protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(StorableConstructorFlag _) : base(_) { }

  protected SymbolicDataAnalysisExpressionTreeNativeInterpreter(SymbolicDataAnalysisExpressionTreeNativeInterpreter original, Cloner cloner) : base(original, cloner) {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new SymbolicDataAnalysisExpressionTreeNativeInterpreter(this, cloner);
  }

  /// <summary>
  /// Flattens the expression below the start symbol, <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>,
  /// into an array of native instructions in breadth-first order.
  /// </summary>
  /// <remarks>
  /// The children of instruction <c>i</c> occupy the contiguous slots that begin at
  /// <c>code[i].childIndex</c>. Variable instructions point into this thread's pinned copy of
  /// <paramref name="dataset"/>. Compiling against a different dataset on the same thread
  /// frees the previously pinned arrays. Instruction arrays compiled earlier for another
  /// dataset must therefore not be evaluated afterwards.
  /// </remarks>
  /// <param name="tree">Tree to compile.</param>
  /// <param name="dataset">Dataset whose double variables are referenced by variable instructions.</param>
  /// <param name="opCodeMapper">Maps each node to its native opcode. It may throw to reject unsupported symbols.</param>
  /// <param name="nodes">Receives the tree nodes in breadth-first order. <c>nodes[i]</c> corresponds to the returned <c>code[i]</c>.</param>
  /// <returns>The compiled instruction array.</returns>
  /// <exception cref="ArgumentException">A node has more than <see cref="ushort.MaxValue"/> subtrees.</exception>
  /// <exception cref="KeyNotFoundException">A variable node references a name that is not a double variable of <paramref name="dataset"/>.</exception>
  public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
    if (cachedData == null || cachedDataset != dataset) {
      InitCache(dataset);
    }

    var root = tree.Root.GetSubtree(0).GetSubtree(0);
    var code = new NativeInstruction[root.GetLength()];
    code[0] = CreateInstruction(root, opCodeMapper);
    int c = 1;
    // IterateNodesBreadth is declared as IEnumerable. Materialize it rather than casting and
    // relying on the concrete runtime type being a List.
    nodes = root.IterateNodesBreadth().ToList();
    for (int i = 0; i < nodes.Count; ++i) {
      var node = nodes[i];
      // The children of node i are written to the next free contiguous slots [c, c + SubtreeCount).
      for (int j = 0; j < node.SubtreeCount; ++j) {
        code[c + j] = CreateInstruction(node.GetSubtree(j), opCodeMapper);
      }

      if (node is VariableTreeNode variable) {
        if (!cachedData.TryGetValue(variable.VariableName, out var handle)) {
          throw new KeyNotFoundException($"The dataset does not contain a double variable named {variable.VariableName}.");
        }
        code[i].value = variable.Weight;
        code[i].data = handle.AddrOfPinnedObject();
      } else if (node is ConstantTreeNode constant) {
        code[i].value = constant.Value;
      }

      code[i].childIndex = c;
      c += node.SubtreeCount;
    }
    return code;
  }

  // Builds the instruction header (arity and opcode) for one node. The arity must fit in a ushort.
  private static NativeInstruction CreateInstruction(ISymbolicExpressionTreeNode node, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper) {
    if (node.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
    return new NativeInstruction { narg = (ushort)node.SubtreeCount, opcode = opCodeMapper(node) };
  }

  // Guards the EvaluatedSolutions increment against concurrent evaluations on this instance.
  private readonly object syncRoot = new object();

  // Pinned copies of the double columns of cachedDataset, keyed by variable name (one cache per thread).
  [ThreadStatic]
  private static Dictionary<string, GCHandle> cachedData;

  // The dataset instance that cachedData was built from (one per thread).
  [ThreadStatic]
  private static IDataset cachedDataset;

  // Opcodes that the native evaluator accepts. Any other symbol is rejected during compilation.
  private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
    (byte)OpCode.Constant,
    (byte)OpCode.Variable,
    (byte)OpCode.Add,
    (byte)OpCode.Sub,
    (byte)OpCode.Mul,
    (byte)OpCode.Div,
    (byte)OpCode.Exp,
    (byte)OpCode.Log,
    (byte)OpCode.Sin,
    (byte)OpCode.Cos,
    (byte)OpCode.Tan,
    (byte)OpCode.Tanh,
    (byte)OpCode.Power,
    (byte)OpCode.Root,
    (byte)OpCode.SquareRoot,
    (byte)OpCode.Square,
    (byte)OpCode.CubeRoot,
    (byte)OpCode.Cube,
    (byte)OpCode.Absolute,
    (byte)OpCode.AnalyticQuotient
  };

  // Maps a node to its opcode, or throws NotSupportedException if the native code cannot evaluate it.
  private static byte MapSupportedSymbol(ISymbolicExpressionTreeNode node) {
    var opCode = OpCodes.MapSymbolToOpCode(node);
    if (supportedOpCodes.Contains(opCode)) return opCode;
    throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
  }

  /// <summary>
  /// Evaluates <paramref name="tree"/> on the given <paramref name="rows"/> of <paramref name="dataset"/>
  /// using native code.
  /// </summary>
  /// <returns>One value per row, in the order of <paramref name="rows"/>. The result is empty if there are no rows.</returns>
  /// <exception cref="NotSupportedException">The tree contains a symbol that the native interpreter does not support.</exception>
  public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
    // Materialize once. rows may be a lazily evaluated query.
    var rowsArray = rows.ToArray();
    if (rowsArray.Length == 0) return Enumerable.Empty<double>();

    var code = Compile(tree, dataset, MapSupportedSymbol, out _);
    var result = new double[rowsArray.Length];
    NativeWrapper.GetValues(code, code.Length, rowsArray, rowsArray.Length, result);

    // The counter is only incremented when evaluation completed without an exception.
    lock (syncRoot) {
      EvaluatedSolutions++;
    }

    return result;
  }

  /// <summary>
  /// Compiles <paramref name="tree"/>. If <paramref name="iterations"/> &gt; 0, it also runs the
  /// native constant-optimization routine against <paramref name="targetVariable"/> on <paramref name="rows"/>.
  /// </summary>
  /// <returns>
  /// A dictionary that maps every terminal node to the value stored in its compiled instruction
  /// (the constant value, or the variable weight). NOTE(review): this assumes the native call
  /// writes the optimized values back into <c>code[i].value</c>. Confirm against NativeWrapper.
  /// </returns>
  /// <exception cref="NotSupportedException">The tree contains a symbol that the native interpreter does not support.</exception>
  public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeConstants(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, int iterations) {
    var code = Compile(tree, dataset, MapSupportedSymbol, out List<ISymbolicExpressionTreeNode> nodes);
    if (iterations > 0) {
      // Materialize the rows once so that the target values and the row indices are taken
      // from the same enumeration.
      var rowsArray = rows.ToArray();
      var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
      var result = new double[rowsArray.Length];
      NativeWrapper.GetValues(code, code.Length, rowsArray, rowsArray.Length, result, target, iterations);
    }
    return Enumerable.Range(0, code.Length)
      .Where(i => nodes[i] is SymbolicExpressionTreeTerminalNode)
      .ToDictionary(i => nodes[i], i => code[i].value);
  }

  // Rebuilds this thread's cache for the given dataset. Handles pinned for the previous dataset are
  // freed first. Without this, every dataset switch leaked pinned arrays.
  private static void InitCache(IDataset dataset) {
    FreeCache();
    cachedData = new Dictionary<string, GCHandle>();
    foreach (var v in dataset.DoubleVariables) {
      var values = dataset.GetDoubleValues(v).ToArray();
      cachedData[v] = GCHandle.Alloc(values, GCHandleType.Pinned);
    }
    // Set this last. If pinning fails part-way, cachedDataset stays null, so the next Compile call
    // rebuilds the cache and frees the partially allocated handles.
    cachedDataset = dataset;
  }

  // Frees all pinned handles held by this thread's cache and forgets the cached dataset.
  private static void FreeCache() {
    if (cachedData != null) {
      foreach (var gch in cachedData.Values) {
        gch.Free();
      }
      cachedData = null;
    }
    cachedDataset = null;
  }

  /// <summary>
  /// Frees the pinned dataset cache of the calling thread and resets <see cref="EvaluatedSolutions"/> to 0.
  /// </summary>
  /// <remarks>
  /// The cache is [ThreadStatic]. Caches created on other threads are not released here.
  /// </remarks>
  public void ClearState() {
    FreeCache();
    EvaluatedSolutions = 0;
  }

  /// <summary>Resets the interpreter state. Equivalent to <see cref="ClearState"/>.</summary>
  public void InitializeState() {
    ClearState();
  }
}
192}
Note: See TracBrowser for help on using the repository browser.