Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/ParameterOptimizer.cs @ 18007

Last change on this file since 18007 was 18007, checked in by gkronber, 3 years ago

#3087: removed "strings-enums" in ParameterOptimizer and do not derive ParameterOptimizer from NativeInterpreter (+ renamed enum types in CeresTypes)

File size: 10.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HEAL.Attic;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.NativeInterpreter;
10using HeuristicLab.Parameters;
11
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Parameterized item that holds the solver settings (iteration count, minimizer, linear solver,
  /// trust-region strategy, dogleg variant, line-search direction, non-monotonic steps) and hands
  /// compiled symbolic expression trees to <c>NativeWrapper</c> for parameter optimization.
  /// </summary>
  [StorableType("A624630B-0CEB-4D06-9B26-708987A7AE8F")]
  [Item("ParameterOptimizer", "Operator calling into native C++ code for tree interpretation.")]
  public sealed class ParameterOptimizer : ParameterizedNamedItem {
    private const string UseNonmonotonicStepsParameterName = "UseNonmonotonicSteps";
    private const string OptimizerIterationsParameterName = "OptimizerIterations";

    private const string MinimizerParameterName = "Minimizer";
    private const string LinearSolverParameterName = "LinearSolver";
    private const string TrustRegionStrategyParameterName = "TrustRegionStrategy";
    private const string DogLegParameterName = "DogLeg";
    private const string LineSearchDirectionParameterName = "LineSearchDirection";

    #region parameters
    /// <summary>Parameter holding the iteration count handed to the solver.</summary>
    public IFixedValueParameter<IntValue> OptimizerIterationsParameter =>
      (IFixedValueParameter<IntValue>)Parameters[OptimizerIterationsParameterName];

    /// <summary>Parameter holding the flag that allows non-monotonic steps.</summary>
    public IFixedValueParameter<BoolValue> UseNonmonotonicStepsParameter =>
      (IFixedValueParameter<BoolValue>)Parameters[UseNonmonotonicStepsParameterName];

    /// <summary>Parameter holding the selected <see cref="CeresTypes.Minimizer"/>.</summary>
    public IFixedValueParameter<EnumValue<CeresTypes.Minimizer>> MinimizerTypeParameter =>
      (IFixedValueParameter<EnumValue<CeresTypes.Minimizer>>)Parameters[MinimizerParameterName];

    /// <summary>Parameter holding the selected <see cref="CeresTypes.LinearSolver"/>.</summary>
    public IFixedValueParameter<EnumValue<CeresTypes.LinearSolver>> LinearSolverTypeParameter =>
      (IFixedValueParameter<EnumValue<CeresTypes.LinearSolver>>)Parameters[LinearSolverParameterName];

    /// <summary>Parameter holding the selected <see cref="CeresTypes.TrustRegionStrategy"/>.</summary>
    public IFixedValueParameter<EnumValue<CeresTypes.TrustRegionStrategy>> TrustRegionStrategyTypeParameter =>
      (IFixedValueParameter<EnumValue<CeresTypes.TrustRegionStrategy>>)Parameters[TrustRegionStrategyParameterName];

    /// <summary>Parameter holding the selected <see cref="CeresTypes.DogLeg"/> variant.</summary>
    public IFixedValueParameter<EnumValue<CeresTypes.DogLeg>> DogLegTypeParameter =>
      (IFixedValueParameter<EnumValue<CeresTypes.DogLeg>>)Parameters[DogLegParameterName];

    /// <summary>Parameter holding the selected <see cref="CeresTypes.LineSearchDirection"/>.</summary>
    public IFixedValueParameter<EnumValue<CeresTypes.LineSearchDirection>> LineSearchDirectionTypeParameter =>
      (IFixedValueParameter<EnumValue<CeresTypes.LineSearchDirection>>)Parameters[LineSearchDirectionParameterName];
    #endregion

    #region parameter properties
    /// <summary>Iteration count copied into <c>SolverOptions.Iterations</c> by the instance overload of OptimizeTree.</summary>
    public int OptimizerIterations {
      get => OptimizerIterationsParameter.Value.Value;
      set => OptimizerIterationsParameter.Value.Value = value;
    }

    /// <summary>Whether non-monotonic steps are allowed; forwarded to the solver options as 1 (true) or 0 (false).</summary>
    public bool UseNonmonotonicSteps {
      get => UseNonmonotonicStepsParameter.Value.Value;
      set => UseNonmonotonicStepsParameter.Value.Value = value;
    }

    private CeresTypes.Minimizer Minimizer {
      get => MinimizerTypeParameter.Value.Value;
      set => MinimizerTypeParameter.Value.Value = value;
    }

    private CeresTypes.LinearSolver LinearSolver {
      get => LinearSolverTypeParameter.Value.Value;
      set => LinearSolverTypeParameter.Value.Value = value;
    }

    private CeresTypes.TrustRegionStrategy TrustRegionStrategy {
      get => TrustRegionStrategyTypeParameter.Value.Value;
      set => TrustRegionStrategyTypeParameter.Value.Value = value;
    }

    private CeresTypes.DogLeg DogLeg {
      get => DogLegTypeParameter.Value.Value;
      set => DogLegTypeParameter.Value.Value = value;
    }

    private CeresTypes.LineSearchDirection LineSearchDirection {
      get => LineSearchDirectionTypeParameter.Value.Value;
      set => LineSearchDirectionTypeParameter.Value.Value = value;
    }
    #endregion

    #region storable ctor and cloning
    [StorableConstructor]
    private ParameterOptimizer(StorableConstructorFlag _) : base(_) { }

    /// <summary>Cloning constructor; all state is copied by the base class.</summary>
    public ParameterOptimizer(ParameterOptimizer original, Cloner cloner) : base(original, cloner) { }

    /// <inheritdoc/>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new ParameterOptimizer(this, cloner);
    }
    #endregion

    /// <summary>
    /// Creates an optimizer with the defaults: TRUST_REGION minimizer, DENSE_QR linear solver,
    /// LEVENBERG_MARQUARDT trust-region strategy, TRADITIONAL_DOGLEG, STEEPEST_DESCENT line search,
    /// 10 iterations and non-monotonic steps disabled.
    /// </summary>
    public ParameterOptimizer() {
      // NOTE: the registration order is kept exactly as before; do not reorder these calls.
      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.Minimizer>>(MinimizerParameterName, new EnumValue<CeresTypes.Minimizer>(CeresTypes.Minimizer.TRUST_REGION)));
      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.LinearSolver>>(LinearSolverParameterName, new EnumValue<CeresTypes.LinearSolver>(CeresTypes.LinearSolver.DENSE_QR)));
      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.TrustRegionStrategy>>(TrustRegionStrategyParameterName, new EnumValue<CeresTypes.TrustRegionStrategy>(CeresTypes.TrustRegionStrategy.LEVENBERG_MARQUARDT)));
      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.DogLeg>>(DogLegParameterName, new EnumValue<CeresTypes.DogLeg>(CeresTypes.DogLeg.TRADITIONAL_DOGLEG)));
      Parameters.Add(new FixedValueParameter<EnumValue<CeresTypes.LineSearchDirection>>(LineSearchDirectionParameterName, new EnumValue<CeresTypes.LineSearchDirection>(CeresTypes.LineSearchDirection.STEEPEST_DESCENT)));
      Parameters.Add(new FixedValueParameter<IntValue>(OptimizerIterationsParameterName, "The number of iterations for the nonlinear least squares optimizer.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseNonmonotonicStepsParameterName, "Allow the non linear least squares optimizer to make steps in parameter space that do not necessarily decrease the error, but might improve overall convergence.", new BoolValue(false)));
    }

    /// <summary>
    /// Maps a tree node to its op code and rejects every symbol that is not listed in
    /// <see cref="supportedOpCodes"/>.
    /// </summary>
    /// <param name="node">The node whose symbol is mapped.</param>
    /// <returns>The op code of <paramref name="node"/>.</returns>
    /// <exception cref="NotSupportedException">The symbol of the node is not in the supported set.</exception>
    private static byte MapSupportedSymbols(ISymbolicExpressionTreeNode node) {
      var opCode = OpCodes.MapSymbolToOpCode(node);
      if (!supportedOpCodes.Contains(opCode)) {
        throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
      }
      return opCode;
    }

    /// <summary>
    /// Compiles <paramref name="tree"/>, flags the instructions that belong to
    /// <paramref name="nodesToOptimize"/> and, if <c>options.Iterations</c> is positive, passes the
    /// code to <c>NativeWrapper.GetValues</c> together with the target values of the selected rows.
    /// </summary>
    /// <param name="tree">The tree to compile.</param>
    /// <param name="dataset">The dataset used for compilation and for reading the target values.</param>
    /// <param name="rows">The row indices to use; enumerated exactly once.</param>
    /// <param name="targetVariable">Name of the target variable in <paramref name="dataset"/>.</param>
    /// <param name="nodesToOptimize">The nodes whose instructions get the Optimize flag.</param>
    /// <param name="options">The solver options; with a non-positive iteration count the native call is skipped.</param>
    /// <param name="summary">Set by the native call; left untouched when the call is skipped.</param>
    /// <returns>
    /// A dictionary from every terminal node of the tree to the value stored in its instruction after
    /// the (optional) native call. The native call is expected to write the optimized values back
    /// into the instruction array (NOTE(review): confirm against NativeWrapper).
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> or <paramref name="nodesToOptimize"/> is null.</exception>
    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable,
      HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, ref OptimizationSummary summary) {
      if (tree == null) {
        throw new ArgumentNullException(nameof(tree));
      }
      if (nodesToOptimize == null) {
        throw new ArgumentNullException(nameof(nodesToOptimize));
      }

      var code = NativeInterpreter.Compile(tree, dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);

      for (int i = 0; i < code.Length; ++i) {
        code[i].Optimize = nodesToOptimize.Contains(nodes[i]);
      }

      if (options.Iterations > 0) {
        // Materialize the rows once: the same indices must be used for the target values and for
        // the native call, and a lazy sequence must not be enumerated twice.
        var rowsArray = rows.ToArray();
        var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
        var result = new double[rowsArray.Length];

        NativeWrapper.GetValues(code, rowsArray, options, result, target, out summary);
      }

      return Enumerable.Range(0, code.Length)
        .Where(i => nodes[i] is SymbolicExpressionTreeTerminalNode)
        .ToDictionary(i => nodes[i], i => code[i].Value);
    }

    /// <summary>
    /// Optimizes <paramref name="tree"/> with solver options built from the parameters of this instance.
    /// </summary>
    /// <param name="tree">The tree to optimize.</param>
    /// <param name="dataset">The dataset providing the input and target values.</param>
    /// <param name="rows">The row indices to use.</param>
    /// <param name="targetVariable">Name of the target variable in <paramref name="dataset"/>.</param>
    /// <param name="nodesToOptimize">
    /// The nodes to optimize; when null, all terminal nodes of <paramref name="tree"/> are used.
    /// </param>
    /// <returns>The dictionary produced by the static single-tree overload.</returns>
    public Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable,
      HashSet<ISymbolicExpressionTreeNode> nodesToOptimize = null) {
      var options = new SolverOptions {
        Iterations = OptimizerIterations,
        Minimizer = Minimizer,
        LinearSolver = LinearSolver,
        TrustRegionStrategy = TrustRegionStrategy,
        DogLeg = DogLeg,
        LineSearchDirection = LineSearchDirection,
        UseNonmonotonicSteps = UseNonmonotonicSteps ? 1 : 0
      };

      var summary = new OptimizationSummary();

      // if no nodes are specified, use all the terminal nodes
      if (nodesToOptimize == null) {
        nodesToOptimize = new HashSet<ISymbolicExpressionTreeNode>(tree.IterateNodesPrefix().Where(x => x is SymbolicExpressionTreeTerminalNode));
      }

      return OptimizeTree(tree, dataset, rows, targetVariable, nodesToOptimize, options, ref summary);
    }

    /// <summary>
    /// Compiles all <paramref name="terms"/> into one concatenated instruction array and passes it to
    /// <c>NativeWrapper.GetValuesVarPro</c> together with <paramref name="coeff"/>.
    /// </summary>
    /// <param name="terms">The trees of the individual terms.</param>
    /// <param name="dataset">The dataset used for compilation and for reading the target values.</param>
    /// <param name="rows">The row indices to use; enumerated exactly once.</param>
    /// <param name="targetVariable">Name of the target variable in <paramref name="dataset"/>.</param>
    /// <param name="nodesToOptimize">The nodes whose instructions get the Optimize flag.</param>
    /// <param name="options">
    /// The solver options. With a non-positive iteration count nothing is compiled or optimized,
    /// matching the single-tree overload, which also only optimizes for a positive count.
    /// </param>
    /// <param name="coeff">Coefficient array forwarded unchanged to the native call.</param>
    /// <param name="summary">Set by the native call; left untouched when the call is skipped.</param>
    /// <returns>
    /// A dictionary from every node flagged for optimization to the value stored in its instruction
    /// after the native call; an empty dictionary when the iteration count is not positive.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="terms"/> or <paramref name="nodesToOptimize"/> is null while work is required.
    /// </exception>
    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree[] terms, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, double[] coeff, ref OptimizationSummary summary) {
      if (options.Iterations <= 0) {
        // throw exception? set iterations to 100? return empty dictionary?
        return new Dictionary<ISymbolicExpressionTreeNode, double>();
      }

      // Validated only after the no-op check above, so the zero-iteration path keeps accepting
      // whatever arguments it accepted before.
      if (terms == null) {
        throw new ArgumentNullException(nameof(terms));
      }
      if (nodesToOptimize == null) {
        throw new ArgumentNullException(nameof(nodesToOptimize));
      }

      var termIndices = new int[terms.Length];
      var totalCode = new List<NativeInstruction>();
      var totalNodes = new List<ISymbolicExpressionTreeNode>();

      // internally the native wrapper takes a single array of NativeInstructions where the indices point to the individual terms
      for (int i = 0; i < terms.Length; ++i) {
        var code = NativeInterpreter.Compile(terms[i], dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
        for (int j = 0; j < code.Length; ++j) {
          code[j].Optimize = nodesToOptimize.Contains(nodes[j]);
        }
        totalCode.AddRange(code);
        totalNodes.AddRange(nodes);

        // index of the last instruction of term i within the concatenated code
        termIndices[i] = totalCode.Count - 1;
      }

      // Materialize the rows once so that target values and native call use the same indices.
      var rowsArray = rows.ToArray();
      var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
      var result = new double[rowsArray.Length];
      var codeArray = totalCode.ToArray();

      NativeWrapper.GetValuesVarPro(codeArray, termIndices, rowsArray, coeff, options, result, target, out summary);

      return Enumerable.Range(0, codeArray.Length)
        .Where(i => codeArray[i].Optimize)
        .ToDictionary(i => totalNodes[i], i => codeArray[i].Value);
    }

    // Op codes accepted by MapSupportedSymbols; every other symbol raises NotSupportedException.
    private static readonly HashSet<byte> supportedOpCodes = new HashSet<byte>() {
      (byte)OpCode.Constant,
      (byte)OpCode.Variable,
      (byte)OpCode.Add,
      (byte)OpCode.Sub,
      (byte)OpCode.Mul,
      (byte)OpCode.Div,
      (byte)OpCode.Exp,
      (byte)OpCode.Log,
      (byte)OpCode.Sin,
      (byte)OpCode.Cos,
      (byte)OpCode.Tan,
      (byte)OpCode.Tanh,
      // (byte)OpCode.Power, // these symbols are handled differently in the NativeInterpreter than in HL
      // (byte)OpCode.Root,
      (byte)OpCode.SquareRoot,
      (byte)OpCode.Square,
      (byte)OpCode.CubeRoot,
      (byte)OpCode.Cube,
      (byte)OpCode.Absolute,
      (byte)OpCode.AnalyticQuotient
    };
  }
}
Note: See TracBrowser for help on using the repository browser.