Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithCompiledExpressionBoundsEstimator.cs

Last change on this file was 18132, checked in by gkronber, 3 years ago

#3140: merged r18091:18131 from branch to trunk

File size: 11.1 KB
Rev / Line
[17896]1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using HEAL.Attic;
[17772]25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using System;
31using System.Collections.Generic;
32using System.Linq;
33using System.Linq.Expressions;
34
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Estimates model bounds with interval arithmetic by translating a symbolic expression tree into a
  /// LINQ expression over <see cref="Interval"/> values, compiling it into a delegate and invoking that
  /// delegate with the intervals of the variables occurring in the tree.
  /// </summary>
  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
  [Item("Interval Arithmetic Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
  public class IntervalArithCompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
    // Names of the Interval methods implementing each op code. MakeExpr resolves entries via reflection
    // (typeof(Interval).GetMethod), so each name must match an Interval method exactly. Only the
    // function-style op codes (AnalyticQuotient and the unary functions) are looked up this way;
    // Add/Sub/Mul/Div are emitted as operator expressions instead.
    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
      { OpCodes.Add, "Add" },
      { OpCodes.Sub, "Subtract" },
      { OpCodes.Mul, "Multiply" },
      { OpCodes.Div, "Divide" },
      { OpCodes.Sin, "Sine" },
      { OpCodes.Cos, "Cosine" },
      { OpCodes.Tan, "Tangens" },
      { OpCodes.Tanh, "HyperbolicTangent" },
      { OpCodes.Log, "Logarithm" },
      { OpCodes.Exp, "Exponential" },
      { OpCodes.Square, "Square" },
      { OpCodes.Cube, "Cube" },
      { OpCodes.SquareRoot, "SquareRoot" },
      { OpCodes.CubeRoot, "CubicRoot" },
      { OpCodes.Absolute, "Absolute" },
      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
    };

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    /// <summary>
    /// Counter parameter that <see cref="GetModelBound"/> increments on every call;
    /// reset to 0 by <see cref="ClearState"/> and <see cref="InitializeState"/>.
    /// </summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    }

    /// <summary>Convenience accessor for the value of <see cref="EvaluatedSolutionsParameter"/>.</summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }

    // Guards the increment of EvaluatedSolutions in GetModelBound.
    private readonly object syncRoot = new object();

    public IntervalArithCompiledExpressionBoundsEstimator() : base("Interval Arithmetic Compiled Expression Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
    }

    [StorableConstructor]
    protected IntervalArithCompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    protected IntervalArithCompiledExpressionBoundsEstimator(IntervalArithCompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IntervalArithCompiledExpressionBoundsEstimator(this, cloner);
    }

    /// <summary>
    /// Returns how far the estimated output bounds of <paramref name="tree"/> lie outside
    /// <c>constraint.Interval</c>, or 0 when they are fully contained.
    /// </summary>
    /// <remarks>
    /// Each model bound that falls outside the constraint interval contributes its absolute distance to
    /// the corresponding constraint bound (lower to lower, upper to upper). A model bound that already
    /// lies inside the constraint interval contributes nothing.
    /// This avoids penalizing the satisfied side, and it avoids spurious +inf / NaN (inf - inf) penalties
    /// for one-sided constraints such as [0, +inf].
    /// NaN model bounds are never considered inside, so they still yield a NaN violation.
    /// NOTE(review): only <c>constraint.Interval</c> is read here. Any other information a
    /// <see cref="ShapeConstraint"/> may carry (e.g. derivative or region settings) is ignored; confirm
    /// that this is intended.
    /// </remarks>
    /// <param name="tree">Model whose output bounds are estimated.</param>
    /// <param name="variableRanges">Input intervals for the variables used by the model.</param>
    /// <param name="constraint">Constraint whose <c>Interval</c> the model bounds must lie in.</param>
    /// <returns>Non-negative violation (0 if satisfied), +inf for unbounded violations, or NaN for undefined bounds.</returns>
    public double GetConstraintViolation(ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
      var modelBound = GetModelBound(tree, variableRanges);
      var target = constraint.Interval;
      if (target.Contains(modelBound)) return 0.0;

      var violation = 0.0;
      // The containment tests are negated on purpose: every comparison with NaN is false,
      // so a NaN bound counts as "outside" and propagates NaN into the result.
      if (!(target.LowerBound <= modelBound.LowerBound && modelBound.LowerBound <= target.UpperBound)) {
        violation += Math.Abs(modelBound.LowerBound - target.LowerBound);
      }
      if (!(target.LowerBound <= modelBound.UpperBound && modelBound.UpperBound <= target.UpperBound)) {
        violation += Math.Abs(modelBound.UpperBound - target.UpperBound);
      }
      return violation;
    }

    /// <summary>Resets <see cref="EvaluatedSolutions"/> to 0.</summary>
    public void ClearState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Estimates the output interval of <paramref name="tree"/> for the given variable ranges and
    /// increments <see cref="EvaluatedSolutions"/>.
    /// </summary>
    /// <remarks>
    /// If the computed interval is not flagged <c>IsInfiniteOrUndefined</c> and its lower bound exceeds
    /// its upper bound, the bounds are swapped before returning.
    /// </remarks>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) { EvaluatedSolutions++; }
      var resultInterval = EstimateBounds(tree, variableRanges.GetReadonlyDictionary());

      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;
      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not supported: the compiled lambda only yields the bounds of the whole model.</summary>
    /// <exception cref="NotSupportedException">Always thrown.</exception>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotSupportedException("Model nodes bounds are not supported.");
    }

    /// <summary>Resets <see cref="EvaluatedSolutions"/> to 0.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Returns true when every node below the program root uses one of the symbols that
    /// <see cref="MakeExpr"/> can translate.
    /// </summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbols = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Number) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is HyperbolicTangent) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient)
        select n).Any();
      return !containsUnknownSymbols;
    }

    #region compile a tree into a IA arithmetic lambda and estimate bounds
    /// <summary>
    /// Recursively translates <paramref name="node"/> into a LINQ expression over <see cref="Interval"/>.
    /// </summary>
    /// <param name="node">Subtree to translate.</param>
    /// <param name="variableRanges">Forwarded to recursive calls only; not read by the generated code.</param>
    /// <param name="variableIndices">Position of each variable name in the <paramref name="args"/> array.</param>
    /// <param name="args">Parameter expression of type <c>Interval[]</c> holding the variable intervals.</param>
    /// <exception cref="NotSupportedException">The node maps to an op code that is not handled.</exception>
    static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableRanges, variableIndices, args);
      var opCode = OpCodes.MapSymbolToOpCode(node);

      switch (opCode) {
        case OpCodes.Variable: {
            var varNode = (VariableTreeNode)node;
            var weight = varNode.Weight;
            var index = variableIndices[varNode.VariableName];
            // weight * args[index]: requires a (double, Interval) multiplication operator on Interval.
            return Expression.Multiply(
              Expression.Constant(weight, typeof(double)),
              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
            );
          }
        case OpCodes.Constant: // fall through
        case OpCodes.Number: {
            var v = ((INumericTreeNode)node).Value;
            // we have to make an interval out of the number because this may be the root of the tree (and we are expected to return an Interval)
            return Expression.Constant(new Interval(v, v), typeof(Interval));
          }
        case OpCodes.Add: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Add(e, expr(s));
            }
            return e;
          }
        case OpCodes.Sub: {
            var e = expr(node.GetSubtree(0));
            if (node.SubtreeCount == 1) {
              // unary minus: 0 - x
              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
            }
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Subtract(e, expr(s));
            }
            return e;
          }
        case OpCodes.Mul: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Multiply(e, expr(s));
            }
            return e;
          }
        case OpCodes.Div: {
            var e1 = expr(node.GetSubtree(0));
            if (node.SubtreeCount == 1) {
              // unary division: 1 / x
              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
            }
            // division is more expensive than multiplication so we use this construct:
            // a / b / c ... is emitted as a / (b * c * ...)
            var e2 = expr(node.GetSubtree(1));
            foreach (var s in node.Subtrees.Skip(2)) {
              e2 = Expression.Multiply(e2, expr(s));
            }
            return Expression.Divide(e1, e2);
          }
        case OpCodes.AnalyticQuotient: {
            var a = expr(node.GetSubtree(0));
            var b = expr(node.GetSubtree(1));
            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { a.Type, b.Type });
            return Expression.Call(fun, a, b);
          }
        // all these cases share the same code: get method info by name, emit call expression
        case OpCodes.Exp:
        case OpCodes.Log:
        case OpCodes.Sin:
        case OpCodes.Cos:
        case OpCodes.Tan:
        case OpCodes.Tanh:
        case OpCodes.Square:
        case OpCodes.Cube:
        case OpCodes.SquareRoot:
        case OpCodes.CubeRoot:
        case OpCodes.Absolute: {
            var arg = expr(node.GetSubtree(0));
            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { arg.Type });
            return Expression.Call(fun, arg);
          }
        default: {
            throw new NotSupportedException($"Unsupported OpCode {opCode} encountered.");
          }
      }
    }

    /// <summary>
    /// Assigns consecutive indices (in prefix order of first occurrence) to the distinct variables of
    /// <paramref name="tree"/>. Also returns their intervals in the same order.
    /// </summary>
    /// <param name="tree">Tree whose variables are indexed; leading program-root/start nodes are skipped.</param>
    /// <param name="variableIntervals">Interval for each variable name.</param>
    /// <param name="inputIntervals">Receives the intervals of the indexed variables, one per index.</param>
    /// <returns>Map from variable name to its index in <paramref name="inputIntervals"/>.</returns>
    /// <remarks>
    /// Every variable in the tree is looked up with the indexer of <paramref name="variableIntervals"/>.
    /// A missing name therefore presumably throws (KeyNotFoundException for standard dictionaries).
    /// </remarks>
    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
      var variableIndices = new Dictionary<string, int>();
      var root = tree.Root;
      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
        root = root.GetSubtree(0);
      }
      // upper bound on the number of distinct variables; trimmed to the actual count below
      inputIntervals = new Interval[variableIntervals.Count];
      int count = 0;
      foreach (var node in root.IterateNodesPrefix()) {
        if (node is VariableTreeNode varNode) {
          var name = varNode.VariableName;
          if (!variableIndices.ContainsKey(name)) {
            variableIndices[name] = count;
            inputIntervals[count] = variableIntervals[name];
            ++count;
          }
        }
      }
      Array.Resize(ref inputIntervals, count);
      return variableIndices;
    }

    /// <summary>
    /// Compiles the model below the program root and start symbol into a delegate that maps the variable
    /// intervals (ordered by <paramref name="variableIndices"/>) to the output interval.
    /// </summary>
    /// <param name="tree">Model to compile.</param>
    /// <param name="variableRanges">Forwarded to <see cref="MakeExpr"/>; not read by the generated code.</param>
    /// <param name="variableIndices">Index of each variable in the delegate's input array.</param>
    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var args = Expression.Parameter(typeof(Interval[]));
      var expr = MakeExpr(root, variableRanges, variableIndices, args);
      return Expression.Lambda<Func<Interval[], Interval>>(expr, args).Compile();
    }

    /// <summary>
    /// Indexes the variables of <paramref name="tree"/>, compiles the tree and evaluates it once on
    /// <paramref name="variableRanges"/>.
    /// </summary>
    /// <remarks>The tree is compiled anew on every call; the delegate is not cached.</remarks>
    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return f(x);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.