Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithCompiledExpressionBoundsEstimator.cs @ 18220

Last change on this file since 18220 was 18132, checked in by gkronber, 3 years ago

#3140: merged r18091:18131 from branch to trunk

File size: 11.1 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using System;
31using System.Collections.Generic;
32using System.Linq;
33using System.Linq.Expressions;
34
35namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
36  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
37  [Item("Interval Arithmetic Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
38  public class IntervalArithCompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
39    // interval method names
40    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
41      { OpCodes.Add, "Add" },
42      { OpCodes.Sub, "Subtract" },
43      { OpCodes.Mul, "Multiply" },
44      { OpCodes.Div, "Divide" },
45      { OpCodes.Sin, "Sine" },
46      { OpCodes.Cos, "Cosine" },
47      { OpCodes.Tan, "Tangens" },
48      { OpCodes.Tanh, "HyperbolicTangent" },
49      { OpCodes.Log, "Logarithm" },
50      { OpCodes.Exp, "Exponential" },
51      { OpCodes.Square, "Square" },
52      { OpCodes.Cube, "Cube" },
53      { OpCodes.SquareRoot, "SquareRoot" },
54      { OpCodes.CubeRoot, "CubicRoot" },
55      { OpCodes.Absolute, "Absolute" },
56      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
57    };
58
59    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
60    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
61      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
62    }
63    public int EvaluatedSolutions {
64      get => EvaluatedSolutionsParameter.Value.Value;
65      set => EvaluatedSolutionsParameter.Value.Value = value;
66    }
67
68    private readonly object syncRoot = new object();
69
70    public IntervalArithCompiledExpressionBoundsEstimator() : base("Interval Arithmetic Compiled Expression Bounds Estimator",
71      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
72      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
73        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
74    }
75
76    [StorableConstructor]
77    protected IntervalArithCompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }
78
79    protected IntervalArithCompiledExpressionBoundsEstimator(IntervalArithCompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }
80
81    public override IDeepCloneable Clone(Cloner cloner) {
82      return new IntervalArithCompiledExpressionBoundsEstimator(this, cloner);
83    }
84
85    public double GetConstraintViolation(ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
86      var modelBound = GetModelBound(tree, variableRanges);
87      if (constraint.Interval.Contains(modelBound)) return 0.0;
88      return Math.Abs(modelBound.LowerBound - constraint.Interval.LowerBound) +
89             Math.Abs(modelBound.UpperBound - constraint.Interval.UpperBound);
90    }
91
92    public void ClearState() {
93      EvaluatedSolutions = 0;
94    }
95
96    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
97      lock (syncRoot) { EvaluatedSolutions++; }
98      var resultInterval = EstimateBounds(tree, variableRanges.GetReadonlyDictionary());
99
100      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
101        return resultInterval;
102      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
103    }
104
105    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
106      throw new NotSupportedException("Model nodes bounds are not supported.");
107    }
108
109    public void InitializeState() {
110      EvaluatedSolutions = 0;
111    }
112
113    public bool IsCompatible(ISymbolicExpressionTree tree) {
114      var containsUnknownSymbols = (
115        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
116        where
117          !(n.Symbol is Variable) &&
118          !(n.Symbol is Number) &&
119          !(n.Symbol is Constant) &&
120          !(n.Symbol is StartSymbol) &&
121          !(n.Symbol is Addition) &&
122          !(n.Symbol is Subtraction) &&
123          !(n.Symbol is Multiplication) &&
124          !(n.Symbol is Division) &&
125          !(n.Symbol is Sine) &&
126          !(n.Symbol is Cosine) &&
127          !(n.Symbol is Tangent) &&
128          !(n.Symbol is HyperbolicTangent) &&
129          !(n.Symbol is Logarithm) &&
130          !(n.Symbol is Exponential) &&
131          !(n.Symbol is Square) &&
132          !(n.Symbol is SquareRoot) &&
133          !(n.Symbol is Cube) &&
134          !(n.Symbol is CubeRoot) &&
135          !(n.Symbol is Absolute) &&
136          !(n.Symbol is AnalyticQuotient)
137        select n).Any();
138      return !containsUnknownSymbols;
139    }
140
141    #region compile a tree into a IA arithmetic lambda and estimate bounds
142    static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
143      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableRanges, variableIndices, args);
144      var opCode = OpCodes.MapSymbolToOpCode(node);
145
146      switch (opCode) {
147        case OpCodes.Variable: {
148            var name = (node as VariableTreeNode).VariableName;
149            var weight = (node as VariableTreeNode).Weight;
150            var index = variableIndices[name];
151            return Expression.Multiply(
152              Expression.Constant(weight, typeof(double)),
153              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
154            );
155          }
156        case OpCodes.Constant: // fall through
157        case OpCodes.Number: {
158            var v = (node as INumericTreeNode).Value;
159            // we have to make an interval out of the number because this may be the root of the tree (and we are expected to return an Interval)
160            return Expression.Constant(new Interval(v, v), typeof(Interval));
161          }
162        case OpCodes.Add: {
163            var e = expr(node.GetSubtree(0));
164            foreach (var s in node.Subtrees.Skip(1)) {
165              e = Expression.Add(e, expr(s));
166            }
167            return e;
168          }
169        case OpCodes.Sub: {
170            var e = expr(node.GetSubtree(0));
171            if (node.SubtreeCount == 1) {
172              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
173            }
174            foreach (var s in node.Subtrees.Skip(1)) {
175              e = Expression.Subtract(e, expr(s));
176            }
177            return e;
178          }
179        case OpCodes.Mul: {
180            var e = expr(node.GetSubtree(0));
181            foreach (var s in node.Subtrees.Skip(1)) {
182              e = Expression.Multiply(e, expr(s));
183            }
184            return e;
185          }
186        case OpCodes.Div: {
187            var e1 = expr(node.GetSubtree(0));
188            if (node.SubtreeCount == 1) {
189              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
190            }
191            // division is more expensive than multiplication so we use this construct
192            var e2 = expr(node.GetSubtree(1));
193            foreach (var s in node.Subtrees.Skip(2)) {
194              e2 = Expression.Multiply(e2, expr(s));
195            }
196            return Expression.Divide(e1, e2);
197          }
198        case OpCodes.AnalyticQuotient: {
199            var a = expr(node.GetSubtree(0));
200            var b = expr(node.GetSubtree(1));
201            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { a.Type, b.Type });
202            return Expression.Call(fun, a, b);
203          }
204        // all these cases share the same code: get method info by name, emit call expression
205        case OpCodes.Exp:
206        case OpCodes.Log:
207        case OpCodes.Sin:
208        case OpCodes.Cos:
209        case OpCodes.Tan:
210        case OpCodes.Tanh:
211        case OpCodes.Square:
212        case OpCodes.Cube:
213        case OpCodes.SquareRoot:
214        case OpCodes.CubeRoot:
215        case OpCodes.Absolute: {
216            var arg = expr(node.GetSubtree(0));
217            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { arg.Type });
218            return Expression.Call(fun, arg);
219          }
220        default: {
221            throw new Exception($"Unsupported OpCode {opCode} encountered.");
222          }
223      }
224    }
225
226    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
227      var variableIndices = new Dictionary<string, int>();
228      var root = tree.Root;
229      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
230        root = root.GetSubtree(0);
231      }
232      inputIntervals = new Interval[variableIntervals.Count];
233      int count = 0;
234      foreach (var node in root.IterateNodesPrefix()) {
235        if (node is VariableTreeNode varNode) {
236          var name = varNode.VariableName;
237          if (!variableIndices.ContainsKey(name)) {
238            variableIndices[name] = count;
239            inputIntervals[count] = variableIntervals[name];
240            ++count;
241          }
242        }
243      }
244      Array.Resize(ref inputIntervals, count);
245      return variableIndices;
246    }
247
248    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
249      var root = tree.Root.GetSubtree(0).GetSubtree(0);
250      var args = Expression.Parameter(typeof(Interval[]));
251      var expr = MakeExpr(root, variableRanges, variableIndices, args);
252      return Expression.Lambda<Func<Interval[], Interval>>(expr, args).Compile();
253    }
254
255    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges) {
256      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
257      var f = Compile(tree, variableRanges, variableIndices);
258      return f(x);
259    }
260    #endregion
261  }
262}
Note: See TracBrowser for help on using the repository browser.