Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/13/21 14:38:02 (3 years ago)
Author:
gkronber
Message:

#3087: removed "strings-enums" in ParameterOptimizer and do not derive ParameterOptimizer from NativeInterpreter (+ renamed enum types in CeresTypes)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/NativeInterpreter.cs

    r17989 r18007  
    3838  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
    3939  [Item("NativeInterpreter", "Operator calling into native C++ code for tree interpretation.")]
    40   public class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
     40  public sealed class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    4141    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    4242
     
    5858    }
    5959
     60    #region storable ctor and cloning
    6061    [StorableConstructor]
    61     protected NativeInterpreter(StorableConstructorFlag _) : base(_) { }
    62 
    63     protected NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) {
    64     }
    65 
     62    private NativeInterpreter(StorableConstructorFlag _) : base(_) { }
    6663    public override IDeepCloneable Clone(Cloner cloner) {
    6764      return new NativeInterpreter(this, cloner);
    6865    }
     66
     67    private NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) { }
     68    #endregion
     69
    6970    public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
    7071      var root = tree.Root.GetSubtree(0).GetSubtree(0);
     
    7374
    7475    public static NativeInstruction[] Compile(ISymbolicExpressionTreeNode root, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
    75       if (cachedData == null || cachedDataset != dataset) {
     76      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
    7677        InitCache(dataset);
    7778      }
    78 
     79     
    7980      nodes = root.IterateNodesPrefix().ToList(); nodes.Reverse();
    8081      var code = new NativeInstruction[nodes.Count];
    81 
    82       for (int i = 0; i < nodes.Count; ++i) {
    83         var node = nodes[i];
    84         code[i] = new NativeInstruction { Arity = (ushort)node.SubtreeCount, OpCode = opCodeMapper(node), Length = (ushort)node.GetLength(), Optimize = true };
    85 
    86         if (node is VariableTreeNode variable) {
     82      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
     83      int i = code.Length - 1;
     84      foreach (var n in root.IterateNodesPrefix()) {
     85        code[i] = new NativeInstruction { Arity = (ushort)n.SubtreeCount, OpCode = opCodeMapper(n), Length = 1, Optimize = false };
     86        if (n is VariableTreeNode variable) {
    8787          code[i].Value = variable.Weight;
    8888          code[i].Data = cachedData[variable.VariableName].AddrOfPinnedObject();
    89         } else if (node is ConstantTreeNode constant) {
     89        } else if (n is ConstantTreeNode constant) {
    9090          code[i].Value = constant.Value;
    9191        }
     92        --i;
    9293      }
     94      // second pass to calculate lengths
     95      for (i = 0; i < code.Length; i++) {
     96        var c = i - 1;
     97        for (int j = 0; j < code[i].Arity; ++j) {
     98          code[i].Length += code[c].Length;
     99          c -= code[c].Length;
     100        }
     101      }
     102
    93103      return code;
     104    }
     105
     106    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
     107      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
     108    }
     109
     110    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
     111      if (!rows.Any()) return Enumerable.Empty<double>();
     112
     113      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
     114        var opCode = OpCodes.MapSymbolToOpCode(node);
     115        if (supportedOpCodes.Contains(opCode)) return opCode;
     116        else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
     117      };
     118      var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
     119
     120      var result = new double[rows.Length];
     121      var options = new SolverOptions { Iterations = 0 }; // Evaluate only. Do not optimize.
     122
     123      NativeWrapper.GetValues(code, rows, options, result, target: null, out var summary);
     124
     125      // when evaluation took place without any error, we can increment the counter
     126      lock (syncRoot) {
     127        EvaluatedSolutions++;
     128      }
     129
     130      return result;
    94131    }
    95132
     
    102139    private static IDataset cachedDataset;
    103140
    104     protected static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
     141    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
    105142      (byte)OpCode.Constant,
    106143      (byte)OpCode.Variable,
     
    115152      (byte)OpCode.Tan,
    116153      (byte)OpCode.Tanh,
    117       (byte)OpCode.Power,
    118       (byte)OpCode.Root,
     154      // (byte)OpCode.Power, // these symbols are handled differently in the NativeInterpreter than in HL
     155      // (byte)OpCode.Root,
    119156      (byte)OpCode.SquareRoot,
    120157      (byte)OpCode.Square,
     
    125162    };
    126163
    127     public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
    128       return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
    129     }
    130    
    131164    private static void InitCache(IDataset dataset) {
    132165      cachedDataset = dataset;
     
    159192      ClearState();
    160193    }
    161 
    162     public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
    163       if (!rows.Any()) return Enumerable.Empty<double>();
    164 
    165       byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
    166         var opCode = OpCodes.MapSymbolToOpCode(node);
    167         if (supportedOpCodes.Contains(opCode)) return opCode;
    168         else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
    169       };
    170       var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
    171 
    172       var result = new double[rows.Length];
    173       var options = new SolverOptions { /* not using any options here */ };
    174 
    175       var summary = new OptimizationSummary(); // also not used
    176       NativeWrapper.GetValues(code, rows, options, result, target: null, out summary);
    177 
    178       // when evaluation took place without any error, we can increment the counter
    179       lock (syncRoot) {
    180         EvaluatedSolutions++;
    181       }
    182 
    183       return result;
    184     }
    185194  }
    186195}
Note: See TracChangeset for help on using the changeset viewer.