source: branches/3073_IA_constraint_splitting/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IACompiledExpressionBoundsEstimator.cs @ 17772

Last change on this file since 17772 was 17772, checked in by bburlacu, 14 months ago

#3073: Add IACompiledExpressionBoundsEstimator

File size: 16.3 KB
Line 
1using HEAL.Attic;
2
3using HeuristicLab.Common;
4using HeuristicLab.Core;
5using HeuristicLab.Data;
6using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
7using HeuristicLab.Parameters;
8
9using System;
10using System.Collections.Generic;
11using System.Linq;
12using System.Linq.Expressions;
13
namespace HeuristicLab.Problems.DataAnalysis.Symbolic
{
  /// <summary>
  /// Bounds estimator that compiles a symbolic expression tree into a lambda over <see cref="Interval"/> arguments
  /// and evaluates it with interval arithmetic. When <see cref="UseIntervalSplitting"/> is set, the estimate is
  /// refined by a branch-and-bound search that repeatedly bisects the input box.
  /// </summary>
  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
  [Item("IA Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
  public class IACompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator
  {
    // Maps opcodes to the names of the static Interval methods that are invoked (via reflection) in MakeExpr.
    // The Add/Sub/Mul/Div entries are not looked up by MakeExpr, which emits operator expressions for them instead.
    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
      { OpCodes.Add, "Add" },
      { OpCodes.Sub, "Subtract" },
      { OpCodes.Mul, "Multiply" },
      { OpCodes.Div, "Divide" },
      { OpCodes.Sin, "Sine" },
      { OpCodes.Cos, "Cosine" },
      { OpCodes.Tan, "Tangens" },
      { OpCodes.Tanh, "HyperbolicTangent" },
      { OpCodes.Log, "Logarithm" },
      { OpCodes.Exp, "Exponential" },
      { OpCodes.Square, "Square" },
      { OpCodes.Cube, "Cube" },
      { OpCodes.SquareRoot, "SquareRoot" },
      { OpCodes.CubeRoot, "CubicRoot" },
      { OpCodes.Absolute, "Absolute" },
      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
    };

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    private const string UseIntervalSplittingParameterName = "Use Interval splitting";
    private const string MaxSplitParameterName = "MaxSplit";
    private const string MinWidthParameterName = "MinWidth";

    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    }

    public IFixedValueParameter<BoolValue> UseIntervalSplittingParameter {
      get => (IFixedValueParameter<BoolValue>)Parameters[UseIntervalSplittingParameterName];
    }

    public IFixedValueParameter<IntValue> MaxSplitParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[MaxSplitParameterName];
    }

    public IFixedValueParameter<DoubleValue> MinWidthParameter {
      get => (IFixedValueParameter<DoubleValue>)Parameters[MinWidthParameterName];
    }

    /// <summary>Maximum number of boxes taken from the branch-and-bound queue per bound estimate.</summary>
    public int MaxSplit {
      get => MaxSplitParameter.Value.Value;
      set => MaxSplitParameter.Value.Value = value;
    }

    /// <summary>Boxes whose intervals are all at most this wide are no longer split.</summary>
    public double MinWidth {
      get => MinWidthParameter.Value.Value;
      set => MinWidthParameter.Value.Value = value;
    }

    /// <summary>Counter of evaluated solutions; incremented under <c>syncRoot</c>.</summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }

    /// <summary>Whether bounds are refined by interval splitting (branch and bound).</summary>
    public bool UseIntervalSplitting {
      get => UseIntervalSplittingParameter.Value.Value;
      set => UseIntervalSplittingParameter.Value.Value = value;
    }

    private readonly object syncRoot = new object();

    public IACompiledExpressionBoundsEstimator() : base("IA Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseIntervalSplittingParameterName,
        "Defines whether interval splitting is activated or not.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxSplitParameterName,
        "Defines the number of iterations of splitting.", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MinWidthParameterName,
        "Width of interval, after the splitting should stop.", new DoubleValue(0.0)));
    }

    [StorableConstructor]
    private IACompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    private IACompiledExpressionBoundsEstimator(IACompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IACompiledExpressionBoundsEstimator(this, cloner);
    }

    /// <summary>
    /// Returns by how much the estimated output range of <paramref name="tree"/> exceeds the interval of
    /// <paramref name="constraint"/>; 0 when the constraint is satisfied. Only the finite sides of the
    /// constraint interval contribute, and each violated side adds its absolute distance.
    /// </summary>
    public double CheckConstraint(ISymbolicExpressionTree tree, IntervalCollection variableRanges, IntervalConstraint constraint) {
      var allowed = constraint.Interval;

      if (!UseIntervalSplitting) {
        // GetModelBound increments EvaluatedSolutions itself
        var modelBound = GetModelBound(tree, variableRanges);
        return CalculateViolation(modelBound.LowerBound, modelBound.UpperBound, allowed);
      }

      var checkLower = !double.IsNegativeInfinity(allowed.LowerBound);
      var checkUpper = !double.IsPositiveInfinity(allowed.UpperBound);
      // (-inf, +inf) cannot be violated
      if (!checkLower && !checkUpper) return 0.0;

      lock (syncRoot) { EvaluatedSolutions++; }

      var ranges = variableRanges.GetReadonlyDictionary();
      // each side is a separate branch-and-bound search, so only estimate the sides the constraint restricts
      var lowerBound = checkLower ? EstimateLowerBound(tree, ranges, MaxSplit, MinWidth) : double.NegativeInfinity;
      var upperBound = checkUpper ? EstimateUpperBound(tree, ranges, MaxSplit, MinWidth) : double.PositiveInfinity;
      return CalculateViolation(lowerBound, upperBound, allowed);
    }

    /// <summary>
    /// Sums the distances by which [<paramref name="lower"/>, <paramref name="upper"/>] lies outside
    /// <paramref name="allowed"/>. Infinite sides of <paramref name="allowed"/> are never violated.
    /// The negated comparisons make NaN estimates produce a NaN violation instead of silently passing.
    /// </summary>
    private static double CalculateViolation(double lower, double upper, Interval allowed) {
      var violation = 0.0;
      if (!double.IsNegativeInfinity(allowed.LowerBound) && !(lower >= allowed.LowerBound))
        violation += Math.Abs(lower - allowed.LowerBound);
      if (!double.IsPositiveInfinity(allowed.UpperBound) && !(upper <= allowed.UpperBound))
        violation += Math.Abs(upper - allowed.UpperBound);
      return violation;
    }

    public void ClearState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Estimates the output interval of <paramref name="tree"/> over <paramref name="variableRanges"/>, with
    /// splitting if <see cref="UseIntervalSplitting"/> is set. A finite result with lower &gt; upper is swapped.
    /// </summary>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) { EvaluatedSolutions++; }
      var ranges = variableRanges.GetReadonlyDictionary();
      var resultInterval = UseIntervalSplitting
        ? EstimateBounds(tree, ranges, MaxSplit, MinWidth)
        : EstimateBounds(tree, ranges);

      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;
      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not supported: the compiled lambda only yields the bound of the root.</summary>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodesBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotSupportedException("Model nodes bounds are not supported.");
    }

    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>Returns true if every node below the program root uses a symbol this estimator can compile.</summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbols = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is HyperbolicTangent) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient)
        select n).Any();
      return !containsUnknownSymbols;
    }

    #region compile a tree into a IA arithmetic lambda and estimate bounds
    /// <summary>
    /// Recursively translates <paramref name="node"/> into an expression over the <c>Interval[]</c> parameter
    /// <paramref name="args"/>. Variables read <c>args[variableIndices[name]]</c> scaled by their weight;
    /// constants become degenerate intervals. The root expression must be of type Interval (Compile's lambda
    /// signature requires it). <paramref name="variableRanges"/> is currently unused.
    /// </summary>
    /// <exception cref="NotSupportedException">The node's opcode (or the Interval method it maps to) is not supported.</exception>
    static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableRanges, variableIndices, args);
      var opCode = OpCodes.MapSymbolToOpCode(node);

      switch (opCode) {
        case OpCodes.Variable: {
            var variableNode = (VariableTreeNode)node;
            var index = variableIndices[variableNode.VariableName];
            return Expression.Multiply(
              Expression.Constant(variableNode.Weight, typeof(double)),
              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
            );
          }
        case OpCodes.Constant: {
            var v = ((ConstantTreeNode)node).Value;
            // we have to make an interval out of the constant because this may be the root of the tree (and we are expected to return an Interval)
            return Expression.Constant(new Interval(v, v), typeof(Interval));
          }
        case OpCodes.Add: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Add(e, expr(s));
            }
            return e;
          }
        case OpCodes.Sub: {
            var e = expr(node.GetSubtree(0));
            // a single argument means negation
            if (node.SubtreeCount == 1) {
              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
            }
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Subtract(e, expr(s));
            }
            return e;
          }
        case OpCodes.Mul: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Multiply(e, expr(s));
            }
            return e;
          }
        case OpCodes.Div: {
            var e1 = expr(node.GetSubtree(0));
            // a single argument means the reciprocal
            if (node.SubtreeCount == 1) {
              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
            }
            // x / y / z is emitted as x / (y * z) so that only one division is performed
            var e2 = expr(node.GetSubtree(1));
            foreach (var s in node.Subtrees.Skip(2)) {
              e2 = Expression.Multiply(e2, expr(s));
            }
            return Expression.Divide(e1, e2);
          }
        // unary functions: look up the Interval method by name and emit a static call
        case OpCodes.Exp:
        case OpCodes.Log:
        case OpCodes.Sin:
        case OpCodes.Cos:
        case OpCodes.Tan:
        case OpCodes.Tanh:
        case OpCodes.Square:
        case OpCodes.Cube:
        case OpCodes.SquareRoot:
        case OpCodes.CubeRoot:
        case OpCodes.Absolute:
          return CallIntervalMethod(opCode, expr(node.GetSubtree(0)));
        // the analytic quotient is binary: aq(a, b) needs both operands
        case OpCodes.AnalyticQuotient:
          return CallIntervalMethod(opCode, expr(node.GetSubtree(0)), expr(node.GetSubtree(1)));
        default:
          throw new NotSupportedException($"Unsupported OpCode {opCode} encountered.");
      }
    }

    /// <summary>
    /// Emits a call to the static Interval method registered for <paramref name="opCode"/> whose parameter types
    /// match the types of <paramref name="arguments"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">No matching Interval method exists.</exception>
    private static Expression CallIntervalMethod(byte opCode, params Expression[] arguments) {
      var name = methodName[opCode];
      var method = typeof(Interval).GetMethod(name, arguments.Select(a => a.Type).ToArray());
      if (method == null) {
        throw new NotSupportedException(
          $"Interval does not provide a method {name}({string.Join(", ", arguments.Select(a => a.Type.Name))}).");
      }
      return Expression.Call(method, arguments);
    }

    /// <summary>Returns the first node below any ProgramRoot/Start symbols, i.e. the actual model root.</summary>
    private static ISymbolicExpressionTreeNode GetActualRoot(ISymbolicExpressionTree tree) {
      var root = tree.Root;
      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
        root = root.GetSubtree(0);
      }
      return root;
    }

    /// <summary>
    /// Assigns each distinct variable of <paramref name="tree"/> an index, in order of first appearance in a
    /// prefix traversal, and returns the mapping. <paramref name="inputIntervals"/>[i] receives the range of the
    /// variable with index i.
    /// </summary>
    /// <exception cref="KeyNotFoundException">A variable of the tree has no entry in <paramref name="variableIntervals"/>.</exception>
    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
      var variableIndices = new Dictionary<string, int>();
      var intervals = new List<Interval>();
      foreach (var node in GetActualRoot(tree).IterateNodesPrefix()) {
        if (node is VariableTreeNode varNode) {
          var name = varNode.VariableName;
          if (variableIndices.ContainsKey(name)) continue;
          if (!variableIntervals.TryGetValue(name, out var range)) {
            throw new KeyNotFoundException($"No interval is defined for variable '{name}'.");
          }
          variableIndices[name] = intervals.Count;
          intervals.Add(range);
        }
      }
      inputIntervals = intervals.ToArray();
      return variableIndices;
    }

    /// <summary>
    /// Compiles the model of <paramref name="tree"/> into a function mapping the input box (indexed by
    /// <paramref name="variableIndices"/>) to the model's output interval.
    /// </summary>
    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
      var root = GetActualRoot(tree);
      var args = Expression.Parameter(typeof(Interval[]));
      var expr = MakeExpr(root, variableRanges, variableIndices, args);
      return Expression.Lambda<Func<Interval[], Interval>>(expr, args).Compile();
    }

    /// <summary>
    /// Estimates the output interval of <paramref name="tree"/>. With <paramref name="n"/> == 0 the compiled
    /// model is evaluated once on the full box; otherwise infimum and supremum are refined separately by
    /// branch and bound with at most <paramref name="n"/> iterations and minimum width <paramref name="w"/>.
    /// </summary>
    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, int n = 0, double w = 1e-5) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      if (n == 0) return f(x);
      var inf = EstimateBound(x, f, true, n, w);
      var sup = EstimateBound(x, f, false, n, w);
      return inf < sup ? new Interval(inf, sup) : new Interval(sup, inf);
    }

    /// <summary>Estimates the infimum of the model output by branch and bound (see <see cref="EstimateBounds"/>).</summary>
    public double EstimateLowerBound(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, int n = 1000, double w = 1e-5) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return EstimateBound(x, f, true, n, w);
    }

    /// <summary>Estimates the supremum of the model output by branch and bound (see <see cref="EstimateBounds"/>).</summary>
    public double EstimateUpperBound(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, int n = 1000, double w = 1e-5) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return EstimateBound(x, f, false, n, w);
    }

    /// <summary>
    /// Branch-and-bound estimate of the infimum (<paramref name="m"/> = true) or supremum (<paramref name="m"/> = false)
    /// of <paramref name="f"/> over the box <paramref name="x"/>. The box with the best bound is repeatedly bisected
    /// along every dimension wider than <paramref name="w"/>, for at most <paramref name="n"/> iterations.
    /// </summary>
    static double EstimateBound(Interval[] x, Func<Interval[], Interval> f, bool m = false, int n = 1000, double w = 1e-4) {
      // the search always minimizes: for the supremum it minimizes -UpperBound and negates the result at the end
      double getBound(Interval iv) => m ? iv.LowerBound : -iv.UpperBound;
      double getVolume(Interval[] box) => box.Aggregate(1.0, (acc, iv) => acc * iv.Width);

      var splits = Enumerable.Range(0, x.Length).Select(_ => new List<Interval>()).ToArray();
      var newbox = new Interval[x.Length];

      // Order: bound ascending, then volume descending (larger boxes first), then insertion order.
      // The insertion sequence number is essential: SortedSet treats entries that compare equal as duplicates and
      // rejects them, which would silently discard parts of the domain and make the estimate unsound.
      int compare(Tuple<double, double, long, Interval[]> a, Tuple<double, double, long, Interval[]> b) {
        var res = a.Item1.CompareTo(b.Item1);
        if (res == 0) res = b.Item2.CompareTo(a.Item2);
        if (res == 0) res = a.Item3.CompareTo(b.Item3);
        return res;
      }

      var q = new SortedSet<Tuple<double, double, long, Interval[]>>(
        Comparer<Tuple<double, double, long, Interval[]>>.Create(compare));
      long sequence = 0;
      void enqueue(Interval[] box) => q.Add(Tuple.Create(getBound(f(box)), getVolume(box), sequence++, box));

      enqueue(x);

      // +inf (rather than double.MaxValue) so that infinite bounds are not clipped by Math.Min
      var bestBound = double.PositiveInfinity;

      // enqueue every combination (cartesian product) of the per-dimension splits
      void nextPair(int i) {
        if (i == splits.Length) {
          enqueue(newbox.ToArray()); // copy: newbox is reused as scratch space
          return;
        }

        foreach (var iv in splits[i]) {
          newbox[i] = iv;
          nextPair(i + 1);
        }
      }

      while (q.Count > 0 && n-- > 0) {
        var current = q.Min;
        q.Remove(current);
        var bound = current.Item1;
        var box = current.Item4;

        if (!box.Any(b => b.Width > w)) {
          // the box is too narrow to split; its bound is final
          bestBound = Math.Min(bestBound, bound);
          continue;
        }

        for (int i = 0; i < box.Length; ++i) {
          splits[i].Clear();
          var iv = box[i];
          if (iv.Width > w) {
            var t = iv.Split();
            splits[i].Add(t.Item1);
            splits[i].Add(t.Item2);
          } else {
            splits[i].Add(iv);
          }
        }
        nextPair(0);
      }
      // unrefined boxes left in the queue still cover part of the domain; the smallest of their bounds counts
      if (q.Count > 0) {
        bestBound = Math.Min(bestBound, q.Min.Item1);
      }
      return m ? bestBound : -bestBound;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.