Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3073_IA_constraint_splitting_reintegration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithCompiledExpressionBoundsEstimator.cs @ 17891

Last change to this file: revision 17891, checked in by gkronber 3 years ago.

#3073 refactoring to prepare for trunk reintegration

File size: 9.9 KB
Line 
1using HEAL.Attic;
2using HeuristicLab.Common;
3using HeuristicLab.Core;
4using HeuristicLab.Data;
5using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
6using HeuristicLab.Parameters;
7using System;
8using System.Collections.Generic;
9using System.Linq;
10using System.Linq.Expressions;
11
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Bounds estimator that translates a symbolic expression tree into a LINQ expression over
  /// <see cref="Interval"/> values, compiles it into a delegate and evaluates that delegate with the
  /// given variable ranges to obtain interval-arithmetic bounds of the model output.
  /// </summary>
  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
  [Item("Interval Arithmetic Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
  public class IntervalArithCompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
    // Names of the static Interval methods per opcode. Only the function entries (Sin ... AnalyticQuotient)
    // are resolved via reflection in MakeExpr; Add/Sub/Mul/Div nodes are compiled with
    // Expression.Add/Subtract/Multiply/Divide, which bind to Interval's operator overloads.
    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
      { OpCodes.Add, "Add" },
      { OpCodes.Sub, "Subtract" },
      { OpCodes.Mul, "Multiply" },
      { OpCodes.Div, "Divide" },
      { OpCodes.Sin, "Sine" },
      { OpCodes.Cos, "Cosine" },
      { OpCodes.Tan, "Tangens" },
      { OpCodes.Tanh, "HyperbolicTangent" },
      { OpCodes.Log, "Logarithm" },
      { OpCodes.Exp, "Exponential" },
      { OpCodes.Square, "Square" },
      { OpCodes.Cube, "Cube" },
      { OpCodes.SquareRoot, "SquareRoot" },
      { OpCodes.CubeRoot, "CubicRoot" },
      { OpCodes.Absolute, "Absolute" },
      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
    };

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    /// <summary>Parameter holding the number of bound estimations performed so far.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    }

    /// <summary>Number of bound estimations performed so far; incremented by <see cref="GetModelBound"/>.</summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }

    // Guards the increment of EvaluatedSolutions in GetModelBound.
    private readonly object syncRoot = new object();

    public IntervalArithCompiledExpressionBoundsEstimator() : base("Interval Arith Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
    }

    [StorableConstructor]
    private IntervalArithCompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    private IntervalArithCompiledExpressionBoundsEstimator(IntervalArithCompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IntervalArithCompiledExpressionBoundsEstimator(this, cloner);
    }

    /// <summary>
    /// Measures how far the estimated model bounds lie outside the interval of <paramref name="constraint"/>.
    /// </summary>
    /// <returns>
    /// 0 if the model bounds are contained in the constraint interval; otherwise the sum of the distances of
    /// the model's lower and upper bound to the constraint interval (a bound inside the interval contributes 0).
    /// NaN if a model bound is undefined (NaN).
    /// </returns>
    public double GetConstraintViolation(ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
      var modelBound = GetModelBound(tree, variableRanges);
      var target = constraint.Interval;
      if (target.Contains(modelBound)) return 0.0;

      // Only the part of the model bounds that lies outside the target counts. Measuring a satisfied bound
      // against its constraint bound would over-penalize partially feasible models and produce +Inf/NaN for
      // one-sided constraints such as [0, +Inf].
      return DistanceToInterval(modelBound.LowerBound, target) +
             DistanceToInterval(modelBound.UpperBound, target);
    }

    // Distance of x to the closed interval; 0 if x lies inside, NaN if x is NaN.
    private static double DistanceToInterval(double x, Interval interval) {
      if (double.IsNaN(x)) return double.NaN;
      if (x < interval.LowerBound) return interval.LowerBound - x;
      if (x > interval.UpperBound) return x - interval.UpperBound;
      return 0.0;
    }

    /// <summary>Resets the evaluation counter.</summary>
    public void ClearState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Compiles and evaluates <paramref name="tree"/> on the given variable ranges and returns the resulting
    /// interval. If the computed lower bound exceeds the upper bound, the bounds are swapped (unless the
    /// interval is infinite or undefined).
    /// </summary>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) { EvaluatedSolutions++; }
      var resultInterval = EstimateBounds(tree, variableRanges.GetReadonlyDictionary());

      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;
      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not supported: the compiled lambda only yields the bounds of the whole model.</summary>
    /// <exception cref="NotSupportedException">Always.</exception>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotSupportedException("Model nodes bounds are not supported.");
    }

    /// <summary>Resets the evaluation counter.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Returns true if every node below the program root uses one of the symbols this estimator can compile.
    /// </summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbols = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is HyperbolicTangent) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient)
        select n).Any();
      return !containsUnknownSymbols;
    }

    #region compile a tree into a IA arithmetic lambda and estimate bounds
    /// <summary>
    /// Recursively translates <paramref name="node"/> into a LINQ expression that computes the node's
    /// interval from the input intervals referenced by <paramref name="args"/>.
    /// </summary>
    /// <param name="node">The tree node to translate.</param>
    /// <param name="variableRanges">Variable ranges; only forwarded to the recursive calls.</param>
    /// <param name="variableIndices">Maps each variable name to its index in the input array.</param>
    /// <param name="args">Expression for the <c>Interval[]</c> lambda parameter holding the input intervals.</param>
    /// <exception cref="NotSupportedException">The node's opcode, or the Interval method it requires, is not supported.</exception>
    private static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableRanges, variableIndices, args);
      var opCode = OpCodes.MapSymbolToOpCode(node);

      switch (opCode) {
        case OpCodes.Variable: {
            var variableNode = (VariableTreeNode)node;
            var index = variableIndices[variableNode.VariableName];
            // weight * args[index]
            return Expression.Multiply(
              Expression.Constant(variableNode.Weight, typeof(double)),
              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
            );
          }
        case OpCodes.Constant: {
            var v = ((ConstantTreeNode)node).Value;
            // we have to make an interval out of the constant because this may be the root of the tree (and we are expected to return an Interval)
            return Expression.Constant(new Interval(v, v), typeof(Interval));
          }
        case OpCodes.Add: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Add(e, expr(s));
            }
            return e;
          }
        case OpCodes.Sub: {
            var e = expr(node.GetSubtree(0));
            if (node.SubtreeCount == 1) {
              // unary minus: 0 - x
              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
            }
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Subtract(e, expr(s));
            }
            return e;
          }
        case OpCodes.Mul: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Multiply(e, expr(s));
            }
            return e;
          }
        case OpCodes.Div: {
            var e1 = expr(node.GetSubtree(0));
            if (node.SubtreeCount == 1) {
              // unary division: 1 / x
              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
            }
            // division is more expensive than multiplication so we use this construct: x1 / (x2 * x3 * ...)
            var e2 = expr(node.GetSubtree(1));
            foreach (var s in node.Subtrees.Skip(2)) {
              e2 = Expression.Multiply(e2, expr(s));
            }
            return Expression.Divide(e1, e2);
          }
        // all these cases share the same code: resolve the static Interval method by name and by the types
        // of ALL compiled subtrees (AnalyticQuotient is binary, the others are unary), then emit a call
        case OpCodes.Exp:
        case OpCodes.Log:
        case OpCodes.Sin:
        case OpCodes.Cos:
        case OpCodes.Tan:
        case OpCodes.Tanh:
        case OpCodes.Square:
        case OpCodes.Cube:
        case OpCodes.SquareRoot:
        case OpCodes.CubeRoot:
        case OpCodes.Absolute:
        case OpCodes.AnalyticQuotient: {
            var arguments = node.Subtrees.Select(s => expr(s)).ToArray();
            var argumentTypes = arguments.Select(a => a.Type).ToArray();
            var name = methodName[opCode];
            var fun = typeof(Interval).GetMethod(name, argumentTypes);
            if (fun == null)
              throw new NotSupportedException($"Interval does not provide a method {name} taking {argumentTypes.Length} argument(s) (OpCode {opCode}).");
            return Expression.Call(fun, arguments);
          }
        default: {
            throw new NotSupportedException($"Unsupported OpCode {opCode} encountered.");
          }
      }
    }

    /// <summary>
    /// Assigns consecutive indices (0, 1, ...) to the distinct variables of the model in prefix order of their
    /// first occurrence and collects their ranges in the same order.
    /// </summary>
    /// <param name="tree">The tree; leading ProgramRoot/Start nodes are skipped.</param>
    /// <param name="variableIntervals">Range of each variable.</param>
    /// <param name="inputIntervals">Receives the ranges of the occurring variables, ordered by index.</param>
    /// <returns>The mapping from variable name to index into <paramref name="inputIntervals"/>.</returns>
    /// <exception cref="KeyNotFoundException">A variable of the tree has no range in <paramref name="variableIntervals"/>.</exception>
    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
      var variableIndices = new Dictionary<string, int>();
      var intervals = new List<Interval>();

      // skip the ProgramRoot and Start nodes
      var root = tree.Root;
      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
        root = root.GetSubtree(0);
      }

      foreach (var varNode in root.IterateNodesPrefix().OfType<VariableTreeNode>()) {
        var name = varNode.VariableName;
        if (variableIndices.ContainsKey(name)) continue;
        if (!variableIntervals.TryGetValue(name, out var interval))
          throw new KeyNotFoundException($"No range is defined for variable {name}.");
        variableIndices.Add(name, intervals.Count);
        intervals.Add(interval);
      }

      inputIntervals = intervals.ToArray();
      return variableIndices;
    }

    /// <summary>
    /// Compiles the model node <c>tree.Root.GetSubtree(0).GetSubtree(0)</c> (presumably the node below the
    /// ProgramRoot and Start nodes) into a delegate mapping the input intervals, ordered as given by
    /// <paramref name="variableIndices"/>, to the output interval.
    /// </summary>
    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var args = Expression.Parameter(typeof(Interval[]));
      var expr = MakeExpr(root, variableRanges, variableIndices, args);
      return Expression.Lambda<Func<Interval[], Interval>>(expr, args).Compile();
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and evaluates it once on <paramref name="variableRanges"/>.
    /// The result is returned as computed, i.e. without ordering the bounds.
    /// </summary>
    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return f(x);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.