Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/ParameterOptimizer.cs @ 17989

Last change on this file since 17989 was 17989, checked in by gkronber, 3 years ago

#3087: updated native DLLs for NativeInterpreter to a version that runs on the Hive infrastructure. Made some minor changes to account for deviations between the independently developed implementations (in particular, the enum types).

File size: 10.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HEAL.Attic;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.NativeInterpreter;
10using HeuristicLab.Parameters;
11
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Native interpreter that exposes the settings of the native nonlinear least-squares solver as
  /// operator parameters and uses them to optimize the values of selected tree nodes against a
  /// target variable (via <c>NativeWrapper.GetValues</c> / <c>NativeWrapper.GetValuesVarPro</c>).
  /// </summary>
  [StorableType("A624630B-0CEB-4D06-9B26-708987A7AE8F")]
  [Item("ParameterOptimizer", "Operator calling into native C++ code for tree interpretation.")]
  public class ParameterOptimizer : NativeInterpreter {
    private const string UseNonmonotonicStepsParameterName = "UseNonmonotonicSteps";
    private const string OptimizerIterationsParameterName = "OptimizerIterations";

    private const string MinimizerTypeParameterName = "MinimizerType";
    private const string LinearSolverTypeParameterName = "LinearSolverType";
    private const string TrustRegionStrategyTypeParameterName = "TrustRegionStrategyType";
    private const string DogLegTypeParameterName = "DogLegType";
    private const string LineSearchDirectionTypeParameterName = "LineSearchDirectionType";

    // Valid values of the constrained string parameters. Each string must be the name of a member of the
    // corresponding CeresTypes enum, because the selected value is converted back with Enum.Parse
    // in the private solver-setting properties below.
    private static readonly string[] MinimizerType = new[] { "LineSearch", "TrustRegion" };
    private static readonly string[] LinearSolverType = new[]
    {
      "DenseNormalCholesky",
      "DenseQR",
      "SparseNormalCholesky",
      "DenseSchur",
      "SparseSchur",
      "IterativeSchur",
      "ConjugateGradients"
    };
    private static readonly string[] TrustRegionStrategyType = new[]
    {
      "LevenbergMarquardt",
      "Dogleg"
    };
    private static readonly string[] DoglegType = new[]
    {
      "Traditional",
      "Subspace"
    };
    private static readonly string[] LineSearchDirectionType = new[]
    {
      "SteepestDescent",
      "NonlinearConjugateGradient",
      "LBFGS",
      "BFGS"
    };

    #region parameters
    /// <summary>Number of iterations for the nonlinear least squares optimizer.</summary>
    public IFixedValueParameter<IntValue> OptimizerIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[OptimizerIterationsParameterName]; }
    }
    /// <summary>Whether the optimizer may take steps that do not necessarily decrease the error.</summary>
    public IFixedValueParameter<BoolValue> UseNonmonotonicStepsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseNonmonotonicStepsParameterName]; }
    }
    /// <summary>Selected minimizer; the value is a member name of <c>CeresTypes.MinimizerType</c>.</summary>
    public IConstrainedValueParameter<StringValue> MinimizerTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[MinimizerTypeParameterName]; }
    }
    /// <summary>Selected linear solver; the value is a member name of <c>CeresTypes.LinearSolverType</c>.</summary>
    public IConstrainedValueParameter<StringValue> LinearSolverTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[LinearSolverTypeParameterName]; }
    }
    /// <summary>Selected trust-region strategy; the value is a member name of <c>CeresTypes.TrustRegionStrategyType</c>.</summary>
    public IConstrainedValueParameter<StringValue> TrustRegionStrategyTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[TrustRegionStrategyTypeParameterName]; }
    }
    /// <summary>Selected dogleg variant; the value is a member name of <c>CeresTypes.DoglegType</c>.</summary>
    public IConstrainedValueParameter<StringValue> DogLegTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[DogLegTypeParameterName]; }
    }
    /// <summary>Selected line-search direction; the value is a member name of <c>CeresTypes.LineSearchDirectionType</c>.</summary>
    public IConstrainedValueParameter<StringValue> LineSearchDirectionTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[LineSearchDirectionTypeParameterName]; }
    }
    #endregion

    #region parameter properties
    /// <summary>Convenience accessor for <see cref="OptimizerIterationsParameter"/>.</summary>
    public int OptimizerIterations {
      get { return OptimizerIterationsParameter.Value.Value; }
      set { OptimizerIterationsParameter.Value.Value = value; }
    }
    /// <summary>Convenience accessor for <see cref="UseNonmonotonicStepsParameter"/>.</summary>
    public bool UseNonmonotonicSteps {
      get { return UseNonmonotonicStepsParameter.Value.Value; }
      set { UseNonmonotonicStepsParameter.Value.Value = value; }
    }
    // The following properties translate the selected string of each constrained parameter into the
    // enum value expected by SolverOptions.
    private CeresTypes.MinimizerType Minimizer {
      get { return (CeresTypes.MinimizerType)Enum.Parse(typeof(CeresTypes.MinimizerType), MinimizerTypeParameter.Value.Value); }
    }
    private CeresTypes.LinearSolverType LinearSolver {
      get { return (CeresTypes.LinearSolverType)Enum.Parse(typeof(CeresTypes.LinearSolverType), LinearSolverTypeParameter.Value.Value); }
    }
    private CeresTypes.TrustRegionStrategyType TrustRegionStrategy {
      get { return (CeresTypes.TrustRegionStrategyType)Enum.Parse(typeof(CeresTypes.TrustRegionStrategyType), TrustRegionStrategyTypeParameter.Value.Value); }
    }
    private CeresTypes.DoglegType Dogleg {
      get { return (CeresTypes.DoglegType)Enum.Parse(typeof(CeresTypes.DoglegType), DogLegTypeParameter.Value.Value); }
    }
    private CeresTypes.LineSearchDirectionType LineSearchDirection {
      get { return (CeresTypes.LineSearchDirectionType)Enum.Parse(typeof(CeresTypes.LineSearchDirectionType), LineSearchDirectionTypeParameter.Value.Value); }
    }
    #endregion

    /// <summary>
    /// Creates a constrained string parameter whose valid values are <paramref name="validValues"/> and
    /// whose current value is the entry equal to <paramref name="value"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="value"/> does not occur exactly once in <paramref name="validValues"/>.
    /// </exception>
    private static IConstrainedValueParameter<StringValue> InitializeParameter(string name, string[] validValues, string value, bool hidden = true) {
      var parameter = new ConstrainedValueParameter<StringValue>(name, new ItemSet<StringValue>(validValues.Select(x => new StringValue(x))));
      parameter.Value = parameter.ValidValues.Single(x => x.Value == value);
      parameter.Hidden = hidden;
      return parameter;
    }

    [StorableConstructor]
    protected ParameterOptimizer(StorableConstructorFlag _) : base(_) { }

    /// <summary>
    /// Creates the operator with 10 optimizer iterations, monotonic steps, and a TrustRegion /
    /// DenseQR / LevenbergMarquardt configuration (the remaining solver settings are also hidden parameters).
    /// </summary>
    public ParameterOptimizer() {
      var minimizerTypeParameter = InitializeParameter(MinimizerTypeParameterName, MinimizerType, "TrustRegion");
      var linearSolverTypeParameter = InitializeParameter(LinearSolverTypeParameterName, LinearSolverType, "DenseQR");
      var trustRegionStrategyTypeParameter = InitializeParameter(TrustRegionStrategyTypeParameterName, TrustRegionStrategyType, "LevenbergMarquardt");
      var dogLegTypeParameter = InitializeParameter(DogLegTypeParameterName, DoglegType, "Traditional");
      var lineSearchDirectionTypeParameter = InitializeParameter(LineSearchDirectionTypeParameterName, LineSearchDirectionType, "SteepestDescent");

      Parameters.Add(new FixedValueParameter<IntValue>(OptimizerIterationsParameterName, "The number of iterations for the nonlinear least squares optimizer.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseNonmonotonicStepsParameterName, "Allow the non linear least squares optimizer to make steps in parameter space that don't necessarily decrease the error, but might improve overall convergence.", new BoolValue(false)));
      Parameters.AddRange(new[] { minimizerTypeParameter, linearSolverTypeParameter, trustRegionStrategyTypeParameter, dogLegTypeParameter, lineSearchDirectionTypeParameter });
    }

    public ParameterOptimizer(ParameterOptimizer original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ParameterOptimizer(this, cloner);
    }

    /// <summary>
    /// Maps a tree node to its native op code.
    /// </summary>
    /// <exception cref="NotSupportedException">The op code is not in <c>supportedOpCodes</c>.</exception>
    private static byte MapSupportedSymbols(ISymbolicExpressionTreeNode node) {
      var opCode = OpCodes.MapSymbolToOpCode(node);
      if (supportedOpCodes.Contains(opCode)) return opCode;
      else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
    }

    /// <summary>
    /// Returns the set of all terminal nodes of the given trees; used as the default set of nodes
    /// to optimize when the caller passes <c>null</c>.
    /// </summary>
    private static HashSet<ISymbolicExpressionTreeNode> CollectTerminalNodes(IEnumerable<ISymbolicExpressionTree> trees) {
      return new HashSet<ISymbolicExpressionTreeNode>(
        trees.SelectMany(t => t.IterateNodesPrefix()).Where(x => x is SymbolicExpressionTreeTerminalNode));
    }

    /// <summary>
    /// Compiles <paramref name="tree"/>, flags the instructions belonging to <paramref name="nodesToOptimize"/>
    /// as optimizable and, if <c>options.Iterations &gt; 0</c>, runs the native optimizer against the
    /// values of <paramref name="targetVariable"/> on <paramref name="rows"/>.
    /// </summary>
    /// <param name="nodesToOptimize">Nodes whose values may be changed; <c>null</c> means all terminal nodes of the tree.</param>
    /// <param name="summary">Only written when the native optimizer is actually invoked.</param>
    /// <returns>A mapping from every terminal node of the tree to its (possibly optimized) value.</returns>
    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, ref OptimizationSummary summary) {
      if (nodesToOptimize == null) {
        nodesToOptimize = CollectTerminalNodes(new[] { tree });
      }

      var code = Compile(tree, dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);

      for (int i = 0; i < code.Length; ++i) {
        code[i].Optimize = nodesToOptimize.Contains(nodes[i]);
      }

      if (options.Iterations > 0) {
        // Materialize the rows once so that row indices and target values are taken from the same sequence.
        var rowsArray = rows.ToArray();
        var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
        var result = new double[rowsArray.Length];

        NativeWrapper.GetValues(code, rowsArray, options, result, target, out summary);
      }
      return Enumerable.Range(0, code.Length).Where(i => nodes[i] is SymbolicExpressionTreeTerminalNode).ToDictionary(i => nodes[i], i => code[i].Value);
    }

    /// <summary>
    /// Optimizes <paramref name="tree"/> using the solver settings configured in this operator's parameters.
    /// </summary>
    /// <param name="nodesToOptimize">Nodes whose values may be changed; <c>null</c> (default) means all terminal nodes.</param>
    /// <returns>A mapping from every terminal node of the tree to its (possibly optimized) value.</returns>
    public Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize = null) {
      var options = new SolverOptions {
        Iterations = OptimizerIterations,
        Minimizer = Minimizer,
        LinearSolver = LinearSolver,
        TrustRegionStrategy = TrustRegionStrategy,
        Dogleg = Dogleg,
        LineSearchDirection = LineSearchDirection,
        UseNonmonotonicSteps = UseNonmonotonicSteps ? 1 : 0
      };

      var summary = new OptimizationSummary();

      // A null node set is expanded to all terminal nodes by the static overload.
      return OptimizeTree(tree, dataset, rows, targetVariable, nodesToOptimize, options, ref summary);
    }

    /// <summary>
    /// Compiles each of <paramref name="terms"/> separately, concatenates the instructions into a single
    /// array, and passes it to <c>NativeWrapper.GetValuesVarPro</c> together with the index of the last
    /// instruction of every term.
    /// </summary>
    /// <param name="nodesToOptimize">Nodes whose values may be changed; <c>null</c> means all terminal nodes of all terms.</param>
    /// <param name="coeff">Passed through to the native call; presumably receives the per-term coefficients — TODO confirm.</param>
    /// <param name="summary">Only written when the native optimizer is actually invoked.</param>
    /// <returns>
    /// A mapping from every node flagged as optimizable to its value after the native call, or an empty
    /// dictionary when <c>options.Iterations &lt;= 0</c>.
    /// </returns>
    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree[] terms, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, double[] coeff, ref OptimizationSummary summary) {
      if (options.Iterations <= 0) {
        // Nothing to optimize (same condition under which the single-tree overload skips the native call).
        return new Dictionary<ISymbolicExpressionTreeNode, double>();
      }
      if (nodesToOptimize == null) {
        nodesToOptimize = CollectTerminalNodes(terms);
      }

      var termIndices = new int[terms.Length];
      var totalCode = new List<NativeInstruction>();
      var totalNodes = new List<ISymbolicExpressionTreeNode>();

      // Internally the native wrapper takes a single array of NativeInstructions; termIndices points into it.
      for (int i = 0; i < terms.Length; ++i) {
        var code = Compile(terms[i], dataset, MapSupportedSymbols, out List<ISymbolicExpressionTreeNode> nodes);
        for (int j = 0; j < code.Length; ++j) {
          code[j].Optimize = nodesToOptimize.Contains(nodes[j]);
        }
        totalCode.AddRange(code);
        totalNodes.AddRange(nodes);

        // Index of the last instruction of term i in the concatenated array
        // (presumably the term's root, assuming Compile emits nodes in postfix order — confirm).
        termIndices[i] = totalCode.Count - 1;
      }

      // Materialize the rows once so that row indices and target values are taken from the same sequence.
      var rowsArray = rows.ToArray();
      var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
      var result = new double[rowsArray.Length];
      var codeArray = totalCode.ToArray();

      NativeWrapper.GetValuesVarPro(codeArray, termIndices, rowsArray, coeff, options, result, target, out summary);
      return Enumerable.Range(0, codeArray.Length).Where(i => codeArray[i].Optimize).ToDictionary(i => totalNodes[i], i => codeArray[i].Value);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.