Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/13/21 14:38:02 (3 years ago)
Author:
gkronber
Message:

#3087: Removed the string-based enum parameters ("strings-enums") in ParameterOptimizer in favor of typed EnumValue parameters; ParameterOptimizer no longer derives from NativeInterpreter (also renamed the enum types in CeresTypes).

Location:
branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter
Files:
1 deleted
2 edited

Legend:

Unmodified
Added
Removed
  • branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/NativeInterpreter.cs

    r17989 r18007  
    3838  [StorableType("91723319-8F15-4D33-B277-40AC7C7CF9AE")]
    3939  [Item("NativeInterpreter", "Operator calling into native C++ code for tree interpretation.")]
    40   public class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
     40  public sealed class NativeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    4141    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    4242
     
    5858    }
    5959
     60    #region storable ctor and cloning
    6061    [StorableConstructor]
    61     protected NativeInterpreter(StorableConstructorFlag _) : base(_) { }
    62 
    63     protected NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) {
    64     }
    65 
     62    private NativeInterpreter(StorableConstructorFlag _) : base(_) { }
    6663    public override IDeepCloneable Clone(Cloner cloner) {
    6764      return new NativeInterpreter(this, cloner);
    6865    }
     66
     67    private NativeInterpreter(NativeInterpreter original, Cloner cloner) : base(original, cloner) { }
     68    #endregion
     69
    6970    public static NativeInstruction[] Compile(ISymbolicExpressionTree tree, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
    7071      var root = tree.Root.GetSubtree(0).GetSubtree(0);
     
    7374
    7475    public static NativeInstruction[] Compile(ISymbolicExpressionTreeNode root, IDataset dataset, Func<ISymbolicExpressionTreeNode, byte> opCodeMapper, out List<ISymbolicExpressionTreeNode> nodes) {
    75       if (cachedData == null || cachedDataset != dataset) {
     76      if (cachedData == null || cachedDataset != dataset || cachedDataset is ModifiableDataset) {
    7677        InitCache(dataset);
    7778      }
    78 
     79     
    7980      nodes = root.IterateNodesPrefix().ToList(); nodes.Reverse();
    8081      var code = new NativeInstruction[nodes.Count];
    81 
    82       for (int i = 0; i < nodes.Count; ++i) {
    83         var node = nodes[i];
    84         code[i] = new NativeInstruction { Arity = (ushort)node.SubtreeCount, OpCode = opCodeMapper(node), Length = (ushort)node.GetLength(), Optimize = true };
    85 
    86         if (node is VariableTreeNode variable) {
     82      if (root.SubtreeCount > ushort.MaxValue) throw new ArgumentException("Number of subtrees is too big (>65.535)");
     83      int i = code.Length - 1;
     84      foreach (var n in root.IterateNodesPrefix()) {
     85        code[i] = new NativeInstruction { Arity = (ushort)n.SubtreeCount, OpCode = opCodeMapper(n), Length = 1, Optimize = false };
     86        if (n is VariableTreeNode variable) {
    8787          code[i].Value = variable.Weight;
    8888          code[i].Data = cachedData[variable.VariableName].AddrOfPinnedObject();
    89         } else if (node is ConstantTreeNode constant) {
     89        } else if (n is ConstantTreeNode constant) {
    9090          code[i].Value = constant.Value;
    9191        }
     92        --i;
    9293      }
     94      // second pass to calculate lengths
     95      for (i = 0; i < code.Length; i++) {
     96        var c = i - 1;
     97        for (int j = 0; j < code[i].Arity; ++j) {
     98          code[i].Length += code[c].Length;
     99          c -= code[c].Length;
     100        }
     101      }
     102
    93103      return code;
     104    }
     105
     106    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
     107      return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
     108    }
     109
     110    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
     111      if (!rows.Any()) return Enumerable.Empty<double>();
     112
     113      byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
     114        var opCode = OpCodes.MapSymbolToOpCode(node);
     115        if (supportedOpCodes.Contains(opCode)) return opCode;
     116        else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
     117      };
     118      var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
     119
     120      var result = new double[rows.Length];
     121      var options = new SolverOptions { Iterations = 0 }; // Evaluate only. Do not optimize.
     122
     123      NativeWrapper.GetValues(code, rows, options, result, target: null, out var summary);
     124
     125      // when evaluation took place without any error, we can increment the counter
     126      lock (syncRoot) {
     127        EvaluatedSolutions++;
     128      }
     129
     130      return result;
    94131    }
    95132
     
    102139    private static IDataset cachedDataset;
    103140
    104     protected static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
     141    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
    105142      (byte)OpCode.Constant,
    106143      (byte)OpCode.Variable,
     
    115152      (byte)OpCode.Tan,
    116153      (byte)OpCode.Tanh,
    117       (byte)OpCode.Power,
    118       (byte)OpCode.Root,
     154      // (byte)OpCode.Power, // these symbols are handled differently in the NativeInterpreter than in HL
     155      // (byte)OpCode.Root,
    119156      (byte)OpCode.SquareRoot,
    120157      (byte)OpCode.Square,
     
    125162    };
    126163
    127     public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
    128       return GetSymbolicExpressionTreeValues(tree, dataset, rows.ToArray());
    129     }
    130    
    131164    private static void InitCache(IDataset dataset) {
    132165      cachedDataset = dataset;
     
    159192      ClearState();
    160193    }
    161 
    162     public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, int[] rows) {
    163       if (!rows.Any()) return Enumerable.Empty<double>();
    164 
    165       byte mapSupportedSymbols(ISymbolicExpressionTreeNode node) {
    166         var opCode = OpCodes.MapSymbolToOpCode(node);
    167         if (supportedOpCodes.Contains(opCode)) return opCode;
    168         else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
    169       };
    170       var code = Compile(tree, dataset, mapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
    171 
    172       var result = new double[rows.Length];
    173       var options = new SolverOptions { /* not using any options here */ };
    174 
    175       var summary = new OptimizationSummary(); // also not used
    176       NativeWrapper.GetValues(code, rows, options, result, target: null, out summary);
    177 
    178       // when evaluation took place without any error, we can increment the counter
    179       lock (syncRoot) {
    180         EvaluatedSolutions++;
    181       }
    182 
    183       return result;
    184     }
    185194  }
    186195}
  • branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/ParameterOptimizer.cs

    r17989 r18007  
    1313  [StorableType("A624630B-0CEB-4D06-9B26-708987A7AE8F")]
    1414  [Item("ParameterOptimizer", "Operator calling into native C++ code for tree interpretation.")]
    15   public class ParameterOptimizer : NativeInterpreter {
     15  public sealed class ParameterOptimizer : ParameterizedNamedItem {
    1616    private const string UseNonmonotonicStepsParameterName = "UseNonmonotonicSteps";
    1717    private const string OptimizerIterationsParameterName = "OptimizerIterations";
    1818
    19     private const string MinimizerTypeParameterName = "MinimizerType";
    20     private const string LinearSolverTypeParameterName = "LinearSolverType";
    21     private const string TrustRegionStrategyTypeParameterName = "TrustRegionStrategyType";
    22     private const string DogLegTypeParameterName = "DogLegType";
    23     private const string LineSearchDirectionTypeParameterName = "LineSearchDirectionType";
    24 
    25     private static readonly string[] MinimizerType = new[] { "LineSearch", "TrustRegion" };
    26     private static readonly string[] LinerSolverType = new[]
    27     {
    28       "DenseNormalCholesky",
    29       "DenseQR",
    30       "SparseNormalCholesky",
    31       "DenseSchur",
    32       "SparseSchur",
    33       "IterativeSchur",
    34       "ConjugateGradients"
    35     };
    36     private static readonly string[] TrustRegionStrategyType = new[]
    37     {
    38       "LevenbergMarquardt",
    39       "Dogleg"
    40     };
    41     private static readonly string[] DoglegType = new[]
    42     {
    43       "Traditional",
    44       "Subspace"
    45     };
    46     private static readonly string[] LinearSearchDirectionType = new[]
    47     {
    48       "SteepestDescent",
    49       "NonlinearConjugateGradient",
    50       "LBFGS",
    51       "BFGS"
    52     };
     19    private const string MinimizerParameterName = "Minimizer";
     20    private const string LinearSolverParameterName = "LinearSolver";
     21    private const string TrustRegionStrategyParameterName = "TrustRegionStrategy";
     22    private const string DogLegParameterName = "DogLeg";
     23    private const string LineSearchDirectionParameterName = "LineSearchDirection";
    5324
    5425    #region parameters
     
    5930      get { return (IFixedValueParameter<BoolValue>)Parameters[UseNonmonotonicStepsParameterName]; }
    6031    }
    61     public IConstrainedValueParameter<StringValue> MinimizerTypeParameter {
    62       get { return (IConstrainedValueParameter<StringValue>)Parameters[MinimizerTypeParameterName]; }
    63     }
    64     public IConstrainedValueParameter<StringValue> LinearSolverTypeParameter {
    65       get { return (IConstrainedValueParameter<StringValue>)Parameters[LinearSolverTypeParameterName]; }
    66     }
    67     public IConstrainedValueParameter<StringValue> TrustRegionStrategyTypeParameter {
    68       get { return (IConstrainedValueParameter<StringValue>)Parameters[TrustRegionStrategyTypeParameterName]; }
    69     }
    70     public IConstrainedValueParameter<StringValue> DogLegTypeParameter {
    71       get { return (IConstrainedValueParameter<StringValue>)Parameters[DogLegTypeParameterName]; }
    72     }
    73     public IConstrainedValueParameter<StringValue> LineSearchDirectionTypeParameter {
    74       get { return (IConstrainedValueParameter<StringValue>)Parameters[LineSearchDirectionTypeParameterName]; }
     32    public IFixedValueParameter<EnumValue<CeresTypes.Minimizer>> MinimizerTypeParameter {
     33      get { return (IFixedValueParameter<EnumValue<CeresTypes.Minimizer>>)Parameters[MinimizerParameterName]; }
     34    }
     35    public IFixedValueParameter<EnumValue<CeresTypes.LinearSolver>> LinearSolverTypeParameter {
     36      get { return (IFixedValueParameter<EnumValue<CeresTypes.LinearSolver>>)Parameters[LinearSolverParameterName]; }
     37    }
     38    public IFixedValueParameter<EnumValue<CeresTypes.TrustRegionStrategy>> TrustRegionStrategyTypeParameter {
     39      get { return (IFixedValueParameter<EnumValue<CeresTypes.TrustRegionStrategy>>)Parameters[TrustRegionStrategyParameterName]; }
     40    }
     41    public IFixedValueParameter<EnumValue<CeresTypes.DogLeg>> DogLegTypeParameter {
     42      get { return (IFixedValueParameter<EnumValue<CeresTypes.DogLeg>>)Parameters[DogLegParameterName]; }
     43    }
     44    public IFixedValueParameter<EnumValue<CeresTypes.LineSearchDirection>> LineSearchDirectionTypeParameter {
     45      get { return (IFixedValueParameter<EnumValue<CeresTypes.LineSearchDirection>>)Parameters[LineSearchDirectionParameterName]; }
    7546    }
    7647    #endregion
     
    8556      set { UseNonmonotonicStepsParameter.Value.Value = value; }
    8657    }
    87     private CeresTypes.MinimizerType Minimizer {
    88       get { return (CeresTypes.MinimizerType)Enum.Parse(typeof(CeresTypes.MinimizerType), MinimizerTypeParameter.Value.Value); }
    89     }
    90     private CeresTypes.LinearSolverType LinearSolver {
    91       get { return (CeresTypes.LinearSolverType)Enum.Parse(typeof(CeresTypes.LinearSolverType), LinearSolverTypeParameter.Value.Value); }
    92     }
    93     private CeresTypes.TrustRegionStrategyType TrustRegionStrategy {
    94       get { return (CeresTypes.TrustRegionStrategyType)Enum.Parse(typeof(CeresTypes.TrustRegionStrategyType), TrustRegionStrategyTypeParameter.Value.Value); }
    95     }
    96     private CeresTypes.DoglegType Dogleg {
    97       get { return (CeresTypes.DoglegType)Enum.Parse(typeof(CeresTypes.DoglegType), DogLegTypeParameter.Value.Value); }
    98     }
    99     private CeresTypes.LineSearchDirectionType LineSearchDirection {
    100       get { return (CeresTypes.LineSearchDirectionType)Enum.Parse(typeof(CeresTypes.LineSearchDirectionType), LineSearchDirectionTypeParameter.Value.Value); }
     58    private CeresTypes.Minimizer Minimizer {
     59      get { return MinimizerTypeParameter.Value.Value; }
     60      set { MinimizerTypeParameter.Value.Value = value; }
     61    }
     62    private CeresTypes.LinearSolver LinearSolver {
     63      get { return LinearSolverTypeParameter.Value.Value; }
     64      set { LinearSolverTypeParameter.Value.Value = value; }
     65    }
     66    private CeresTypes.TrustRegionStrategy TrustRegionStrategy {
     67      get { return TrustRegionStrategyTypeParameter.Value.Value; }
     68      set { TrustRegionStrategyTypeParameter.Value.Value = value; }
     69    }
     70    private CeresTypes.DogLeg DogLeg {
     71      get { return DogLegTypeParameter.Value.Value; }
     72      set { DogLegTypeParameter.Value.Value = value; }
     73    }
     74    private CeresTypes.LineSearchDirection LineSearchDirection {
     75      get { return LineSearchDirectionTypeParameter.Value.Value; }
     76      set { LineSearchDirectionTypeParameter.Value.Value = value; }
    10177    }
    10278    #endregion
    10379
    104     private static IConstrainedValueParameter<StringValue> InitializeParameter(string name, string[] validValues, string value, bool hidden = true) {
    105       var parameter = new ConstrainedValueParameter<StringValue>(name, new ItemSet<StringValue>(validValues.Select(x => new StringValue(x))));
    106       parameter.Value = parameter.ValidValues.Single(x => x.Value == value);
    107       parameter.Hidden = hidden;
    108       return parameter;
    109     }
    110 
     80    #region storable ctor and cloning
    11181    [StorableConstructor]
    112     protected ParameterOptimizer(StorableConstructorFlag _) : base(_) { }
    113 
    114     public ParameterOptimizer() {
    115       var minimizerTypeParameter = InitializeParameter(MinimizerTypeParameterName, MinimizerType, "TrustRegion");
    116       var linearSolverTypeParameter = InitializeParameter(LinearSolverTypeParameterName, LinerSolverType, "DenseQR");
    117       var trustRegionStrategyTypeParameter = InitializeParameter(TrustRegionStrategyTypeParameterName, TrustRegionStrategyType, "LevenbergMarquardt");
    118       var dogLegTypeParameter = InitializeParameter(DogLegTypeParameterName, DoglegType, "Traditional");
    119       var lineSearchDirectionTypeParameter = InitializeParameter(LineSearchDirectionTypeParameterName, LinearSearchDirectionType, "SteepestDescent");
    120 
    121       Parameters.Add(new FixedValueParameter<IntValue>(OptimizerIterationsParameterName, "The number of iterations for the nonlinear least squares optimizer.", new IntValue(10)));
    122       Parameters.Add(new FixedValueParameter<BoolValue>(UseNonmonotonicStepsParameterName, "Allow the non linear least squares optimizer to make steps in parameter space that don't necessarily decrease the error, but might improve overall convergence.", new BoolValue(false)));
    123       Parameters.AddRange(new[] { minimizerTypeParameter, linearSolverTypeParameter, trustRegionStrategyTypeParameter, dogLegTypeParameter, lineSearchDirectionTypeParameter });
    124     }
     82    private ParameterOptimizer(StorableConstructorFlag _) : base(_) { }
    12583
    12684    public ParameterOptimizer(ParameterOptimizer original, Cloner cloner) : base(original, cloner) { }
     
    12886    public override IDeepCloneable Clone(Cloner cloner) {
    12987      return new ParameterOptimizer(this, cloner);
     88    }
     89    #endregion
     90
     91    public ParameterOptimizer() {
     92      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.Minimizer>>(MinimizerParameterName, new EnumValue<CeresTypes.Minimizer>(CeresTypes.Minimizer.TRUST_REGION)));
     93      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.LinearSolver>>(LinearSolverParameterName, new EnumValue<CeresTypes.LinearSolver>(CeresTypes.LinearSolver.DENSE_QR)));
     94      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.TrustRegionStrategy>>(TrustRegionStrategyParameterName, new EnumValue<CeresTypes.TrustRegionStrategy>(CeresTypes.TrustRegionStrategy.LEVENBERG_MARQUARDT)));
     95      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.DogLeg>>(DogLegParameterName, new EnumValue<CeresTypes.DogLeg>(CeresTypes.DogLeg.TRADITIONAL_DOGLEG)));
     96      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.LineSearchDirection>>(LineSearchDirectionParameterName, new EnumValue<CeresTypes.LineSearchDirection>(CeresTypes.LineSearchDirection.STEEPEST_DESCENT)));
     97      Parameters.Add(new FixedValueParameter<IntValue>(OptimizerIterationsParameterName, "The number of iterations for the nonlinear least squares optimizer.", new IntValue(10)));
     98      Parameters.Add(new FixedValueParameter<BoolValue>(UseNonmonotonicStepsParameterName, "Allow the non linear least squares optimizer to make steps in parameter space that do not necessarily decrease the error, but might improve overall convergence.", new BoolValue(false)));
    13099    }
    131100
     
    136105    }
    137106
    138     public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, ref OptimizationSummary summary) {
    139       var code = Compile(tree, dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
     107    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable,
     108      HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, ref OptimizationSummary summary) {
     109      var code = NativeInterpreter.Compile(tree, dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
    140110
    141111      for (int i = 0; i < code.Length; ++i) {
     
    153123    }
    154124
    155     public Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize = null) {
     125    public Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable,
     126      HashSet<ISymbolicExpressionTreeNode> nodesToOptimize = null) {
    156127      var options = new SolverOptions {
    157128        Iterations = OptimizerIterations,
     
    159130        LinearSolver = LinearSolver,
    160131        TrustRegionStrategy = TrustRegionStrategy,
    161         Dogleg = Dogleg,
     132        DogLeg = DogLeg,
    162133        LineSearchDirection = LineSearchDirection,
    163134        UseNonmonotonicSteps = UseNonmonotonicSteps ? 1 : 0
     
    187158      // internally the native wrapper takes a single array of NativeInstructions where the indices point to the individual terms
    188159      for (int i = 0; i < terms.Length; ++i) {
    189         var code = Compile(terms[i], dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
     160        var code = NativeInterpreter.Compile(terms[i], dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
    190161        for (int j = 0; j < code.Length; ++j) {
    191162          code[j].Optimize = nodesToOptimize.Contains(nodes[j]);
     
    202173      var codeArray = totalCode.ToArray();
    203174
    204       NativeWrapper.GetValuesVarPro(codeArray, termIndices,rowsArray, coeff, options, result, target, out summary);
     175      NativeWrapper.GetValuesVarPro(codeArray, termIndices, rowsArray, coeff, options, result, target, out summary);
    205176      return Enumerable.Range(0, totalCodeSize).Where(i => codeArray[i].Optimize).ToDictionary(i => totalNodes[i], i => codeArray[i].Value);
    206177    }
     178
     179    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
     180      (byte)OpCode.Constant,
     181      (byte)OpCode.Variable,
     182      (byte)OpCode.Add,
     183      (byte)OpCode.Sub,
     184      (byte)OpCode.Mul,
     185      (byte)OpCode.Div,
     186      (byte)OpCode.Exp,
     187      (byte)OpCode.Log,
     188      (byte)OpCode.Sin,
     189      (byte)OpCode.Cos,
     190      (byte)OpCode.Tan,
     191      (byte)OpCode.Tanh,
     192      // (byte)OpCode.Power, // these symbols are handled differently in the NativeInterpreter than in HL
     193      // (byte)OpCode.Root,
     194      (byte)OpCode.SquareRoot,
     195      (byte)OpCode.Square,
     196      (byte)OpCode.CubeRoot,
     197      (byte)OpCode.Cube,
     198      (byte)OpCode.Absolute,
     199      (byte)OpCode.AnalyticQuotient
     200    };
    207201  }
    208202}
Note: See TracChangeset for help on using the changeset viewer.