Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/ParameterOptimizer.cs @ 17853

Last change on this file since 17853 was 17853, checked in by bburlacu, 3 years ago

#3087: Add accidentally omitted files.

File size: 10.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HEAL.Attic;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Parameters;
10
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Optimizes the numeric values of terminal nodes in symbolic expression trees by calling into the
  /// native nonlinear least squares optimizer exposed through <c>NativeWrapper</c>.
  /// The solver configuration (minimizer, linear solver, trust region strategy, dogleg type and
  /// line search direction) is exposed as hidden constrained string parameters.
  /// </summary>
  [StorableType("A624630B-0CEB-4D06-9B26-708987A7AE8F")]
  [Item("ParameterOptimizer", "Operator calling into native C++ code for tree parameter optimization.")]
  public class ParameterOptimizer : NativeInterpreter {
    private const string UseNonmonotonicStepsParameterName = "UseNonmonotonicSteps";
    private const string OptimizerIterationsParameterName = "OptimizerIterations";

    private const string MinimizerTypeParameterName = "MinimizerType";
    private const string LinearSolverTypeParameterName = "LinearSolverType";
    private const string TrustRegionStrategyTypeParameterName = "TrustRegionStrategyType";
    private const string DogLegTypeParameterName = "DogLegType";
    private const string LineSearchDirectionTypeParameterName = "LineSearchDirectionType";

    // The position of a name inside these arrays is the integer handed to the native solver
    // (see the Minimizer/LinearSolver/... properties below), so the order of the entries matters.
    // NOTE(review): presumably mirrors the order of the corresponding native (Ceres) enums — verify
    // against the native side before reordering or extending any of these arrays.
    private static readonly string[] MinimizerType = new[] { "LineSearch", "TrustRegion" };
    private static readonly string[] LinearSolverType = new[]
    {
      "DenseNormalCholesky",
      "DenseQR",
      "SparseNormalCholesky",
      "DenseSchur",
      "SparseSchur",
      "IterativeSchur",
      "ConjugateGradients"
    };
    private static readonly string[] TrustRegionStrategyType = new[]
    {
      "LevenbergMarquardt",
      "Dogleg"
    };
    private static readonly string[] DoglegType = new[]
    {
      "Traditional",
      "Subspace"
    };
    private static readonly string[] LineSearchDirectionType = new[]
    {
      "SteepestDescent",
      "NonlinearConjugateGradient",
      "LBFGS",
      "BFGS"
    };

    #region parameters
    /// <summary>Number of iterations for the nonlinear least squares optimizer.</summary>
    public IFixedValueParameter<IntValue> OptimizerIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[OptimizerIterationsParameterName]; }
    }
    /// <summary>Whether the optimizer may take steps that do not decrease the error.</summary>
    public IFixedValueParameter<BoolValue> UseNonmonotonicStepsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseNonmonotonicStepsParameterName]; }
    }
    /// <summary>Selected minimizer (one of <c>LineSearch</c>, <c>TrustRegion</c>).</summary>
    public IConstrainedValueParameter<StringValue> MinimizerTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[MinimizerTypeParameterName]; }
    }
    /// <summary>Selected linear solver.</summary>
    public IConstrainedValueParameter<StringValue> LinearSolverTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[LinearSolverTypeParameterName]; }
    }
    /// <summary>Selected trust region strategy.</summary>
    public IConstrainedValueParameter<StringValue> TrustRegionStrategyTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[TrustRegionStrategyTypeParameterName]; }
    }
    /// <summary>Selected dogleg variant.</summary>
    public IConstrainedValueParameter<StringValue> DogLegTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[DogLegTypeParameterName]; }
    }
    /// <summary>Selected line search direction.</summary>
    public IConstrainedValueParameter<StringValue> LineSearchDirectionTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[LineSearchDirectionTypeParameterName]; }
    }
    #endregion

    #region parameter properties
    public int OptimizerIterations {
      get { return OptimizerIterationsParameter.Value.Value; }
      set { OptimizerIterationsParameter.Value.Value = value; }
    }
    public bool UseNonmonotonicSteps {
      get { return UseNonmonotonicStepsParameter.Value.Value; }
      set { UseNonmonotonicStepsParameter.Value.Value = value; }
    }
    // The following properties translate the selected option name into its index in the
    // corresponding array; that index is what is stored in SolverOptions.
    private int Minimizer {
      get { return Array.IndexOf(MinimizerType, MinimizerTypeParameter.Value.Value); }
    }
    private int LinearSolver {
      get { return Array.IndexOf(LinearSolverType, LinearSolverTypeParameter.Value.Value); }
    }
    private int TrustRegionStrategy {
      get { return Array.IndexOf(TrustRegionStrategyType, TrustRegionStrategyTypeParameter.Value.Value); }
    }
    private int Dogleg {
      get { return Array.IndexOf(DoglegType, DogLegTypeParameter.Value.Value); }
    }
    private int LineSearchDirection {
      get { return Array.IndexOf(LineSearchDirectionType, LineSearchDirectionTypeParameter.Value.Value); }
    }
    #endregion

    /// <summary>
    /// Creates a constrained string parameter whose valid values are <paramref name="validValues"/>
    /// and whose current value is the entry equal to <paramref name="value"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">If <paramref name="value"/> is not exactly one of the valid values.</exception>
    private static IConstrainedValueParameter<StringValue> InitializeParameter(string name, string[] validValues, string value, bool hidden = true) {
      var parameter = new ConstrainedValueParameter<StringValue>(name, new ItemSet<StringValue>(validValues.Select(x => new StringValue(x))));
      parameter.Value = parameter.ValidValues.Single(x => x.Value == value);
      parameter.Hidden = hidden;
      return parameter;
    }

    [StorableConstructor]
    protected ParameterOptimizer(StorableConstructorFlag _) : base(_) { }

    /// <summary>
    /// Creates the optimizer with defaults: TrustRegion / DenseQR / LevenbergMarquardt / Traditional /
    /// SteepestDescent, 10 iterations, and non-monotonic steps disabled.
    /// </summary>
    public ParameterOptimizer() {
      var minimizerTypeParameter = InitializeParameter(MinimizerTypeParameterName, MinimizerType, "TrustRegion");
      var linearSolverTypeParameter = InitializeParameter(LinearSolverTypeParameterName, LinearSolverType, "DenseQR");
      var trustRegionStrategyTypeParameter = InitializeParameter(TrustRegionStrategyTypeParameterName, TrustRegionStrategyType, "LevenbergMarquardt");
      var dogLegTypeParameter = InitializeParameter(DogLegTypeParameterName, DoglegType, "Traditional");
      var lineSearchDirectionTypeParameter = InitializeParameter(LineSearchDirectionTypeParameterName, LineSearchDirectionType, "SteepestDescent");

      Parameters.Add(new FixedValueParameter<IntValue>(OptimizerIterationsParameterName, "The number of iterations for the nonlinear least squares optimizer.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseNonmonotonicStepsParameterName, "Allow the non linear least squares optimizer to make steps in parameter space that don't necessarily decrease the error, but might improve overall convergence.", new BoolValue(false)));
      Parameters.AddRange(new[] { minimizerTypeParameter, linearSolverTypeParameter, trustRegionStrategyTypeParameter, dogLegTypeParameter, lineSearchDirectionTypeParameter });
    }

    public ParameterOptimizer(ParameterOptimizer original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ParameterOptimizer(this, cloner);
    }

    /// <summary>
    /// Maps a tree node to its native op code.
    /// </summary>
    /// <exception cref="NotSupportedException">If the op code is not in <c>supportedOpCodes</c>.</exception>
    private static byte MapSupportedSymbols(ISymbolicExpressionTreeNode node) {
      var opCode = OpCodes.MapSymbolToOpCode(node);
      if (supportedOpCodes.Contains(opCode)) return opCode;
      else throw new NotSupportedException($"The native interpreter does not support {node.Symbol.Name}");
    }

    // Compiles a tree to native instructions and flags every instruction whose source node is in nodesToOptimize.
    private static NativeInstruction[] CompileForOptimization(ISymbolicExpressionTree tree, IDataset dataset, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, out List<ISymbolicExpressionTreeNode> nodes) {
      var code = Compile(tree, dataset, MapSupportedSymbols, out nodes);
      for (int i = 0; i < code.Length; ++i) {
        code[i].Optimize = nodesToOptimize.Contains(nodes[i]);
      }
      return code;
    }

    /// <summary>
    /// Fits the values of the nodes in <paramref name="nodesToOptimize"/> so that the tree approximates
    /// <paramref name="targetVariable"/> on <paramref name="rows"/>.
    /// </summary>
    /// <param name="tree">The tree to optimize.</param>
    /// <param name="dataset">Dataset providing the input and target values.</param>
    /// <param name="rows">Rows used for fitting; the sequence is enumerated exactly once.</param>
    /// <param name="targetVariable">Name of the target column.</param>
    /// <param name="nodesToOptimize">Nodes whose values the native optimizer may change.</param>
    /// <param name="options">Native solver options; when <c>Iterations &lt;= 0</c> no optimization is run.</param>
    /// <param name="summary">Receives the optimization summary from the native solver.</param>
    /// <returns>The (possibly optimized) value of every terminal node of the tree.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="tree"/>, <paramref name="dataset"/>, <paramref name="rows"/> or <paramref name="nodesToOptimize"/> is null.</exception>
    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, ref OptimizationSummary summary) {
      if (tree == null) throw new ArgumentNullException(nameof(tree));
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));
      if (rows == null) throw new ArgumentNullException(nameof(rows));
      if (nodesToOptimize == null) throw new ArgumentNullException(nameof(nodesToOptimize));

      var code = CompileForOptimization(tree, dataset, nodesToOptimize, out var nodes);

      if (options.Iterations > 0) {
        // Materialize the rows once so that the target values and the rows handed to the native
        // code are guaranteed to come from the same enumeration.
        var rowsArray = rows.ToArray();
        var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
        var result = new double[rowsArray.Length];

        NativeWrapper.GetValues(code, rowsArray, result, target, options, ref summary);
      }
      // the native call writes the optimized values back into the instruction array
      return Enumerable.Range(0, code.Length)
        .Where(i => nodes[i] is SymbolicExpressionTreeTerminalNode)
        .ToDictionary(i => nodes[i], i => code[i].Value);
    }

    /// <summary>
    /// Fits the tree's node values using the solver settings configured on this operator.
    /// </summary>
    /// <param name="nodesToOptimize">Nodes to optimize; when null, all terminal nodes of the tree are optimized.</param>
    /// <returns>The (possibly optimized) value of every terminal node of the tree.</returns>
    public Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize = null) {
      var options = new SolverOptions {
        Iterations = OptimizerIterations,
        Minimizer = Minimizer,
        LinearSolver = LinearSolver,
        TrustRegionStrategy = TrustRegionStrategy,
        Dogleg = Dogleg,
        LineSearchDirection = LineSearchDirection,
        UseNonmonotonicSteps = UseNonmonotonicSteps ? 1 : 0
      };

      // the summary is required by the static overload but not exposed by this convenience method
      var summary = new OptimizationSummary();

      // if no nodes are specified, use all the terminal nodes
      if (nodesToOptimize == null) {
        nodesToOptimize = new HashSet<ISymbolicExpressionTreeNode>(tree.IterateNodesPrefix().Where(x => x is SymbolicExpressionTreeTerminalNode));
      }

      return OptimizeTree(tree, dataset, rows, targetVariable, nodesToOptimize, options, ref summary);
    }

    /// <summary>
    /// Variable-projection variant: fits the node values of several term trees together with the
    /// coefficients in <paramref name="coeff"/> via <c>NativeWrapper.GetValuesVarPro</c>.
    /// </summary>
    /// <param name="terms">The term trees; they are compiled into one contiguous instruction array.</param>
    /// <param name="dataset">Dataset providing the input and target values.</param>
    /// <param name="rows">Rows used for fitting; the sequence is enumerated exactly once.</param>
    /// <param name="targetVariable">Name of the target column.</param>
    /// <param name="nodesToOptimize">Nodes whose values the native optimizer may change.</param>
    /// <param name="options">Native solver options; when <c>Iterations &lt;= 0</c> an empty dictionary is returned.</param>
    /// <param name="coeff">Coefficient buffer passed to the native code. NOTE(review): presumably one entry per term and written by the solver — confirm against NativeWrapper.</param>
    /// <param name="summary">Receives the optimization summary from the native solver.</param>
    /// <returns>The value of every instruction flagged for optimization, keyed by its node.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="terms"/>, <paramref name="dataset"/>, <paramref name="rows"/> or <paramref name="nodesToOptimize"/> is null (checked only when optimization actually runs).</exception>
    public static Dictionary<ISymbolicExpressionTreeNode, double> OptimizeTree(ISymbolicExpressionTree[] terms, IDataset dataset, IEnumerable<int> rows, string targetVariable, HashSet<ISymbolicExpressionTreeNode> nodesToOptimize, SolverOptions options, double[] coeff, ref OptimizationSummary summary) {
      // Consistent with the single-tree overload, a non-positive iteration count means "do not optimize".
      // NOTE(review): returning an empty dictionary (rather than the current values) is kept from the original.
      if (options.Iterations <= 0) {
        return new Dictionary<ISymbolicExpressionTreeNode, double>();
      }
      if (terms == null) throw new ArgumentNullException(nameof(terms));
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));
      if (rows == null) throw new ArgumentNullException(nameof(rows));
      if (nodesToOptimize == null) throw new ArgumentNullException(nameof(nodesToOptimize));

      var termIndices = new int[terms.Length];
      var totalCode = new List<NativeInstruction>();
      var totalNodes = new List<ISymbolicExpressionTreeNode>();

      // Internally the native wrapper takes a single array of NativeInstructions; termIndices[i]
      // holds the position of the last instruction of term i inside that array.
      for (int i = 0; i < terms.Length; ++i) {
        var code = CompileForOptimization(terms[i], dataset, nodesToOptimize, out var nodes);
        totalCode.AddRange(code);
        totalNodes.AddRange(nodes);
        termIndices[i] = totalCode.Count - 1;
      }

      // materialize the rows once so target values and evaluated rows agree
      var rowsArray = rows.ToArray();
      var target = dataset.GetDoubleValues(targetVariable, rowsArray).ToArray();
      var result = new double[rowsArray.Length];
      var codeArray = totalCode.ToArray();

      NativeWrapper.GetValuesVarPro(codeArray, termIndices, rowsArray, coeff, result, target, options, ref summary);
      return Enumerable.Range(0, codeArray.Length)
        .Where(i => codeArray[i].Optimize)
        .ToDictionary(i => totalNodes[i], i => codeArray[i].Value);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.