Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3026_IntegrationIntoSymSpace/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithCompiledExpressionBoundsEstimator.cs @ 18027

Last change on this file since 18027 was 17928, checked in by dpiringe, 4 years ago

#3026

  • merged trunk into branch
File size: 11.0 KB
RevLine 
#region License Information

/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */

#endregion
23
24using HEAL.Attic;
[17772]25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using System;
31using System.Collections.Generic;
32using System.Linq;
33using System.Linq.Expressions;
34
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Bounds estimator that compiles a symbolic expression tree into a
  /// <c>Func&lt;Interval[], Interval&gt;</c> and evaluates it once with interval arithmetic
  /// to obtain the bounds of the model output.
  /// </summary>
  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
  [Item("Interval Arithmetic Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
  public class IntervalArithCompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
    // Names of the static Interval methods used for function symbols (resolved via reflection in MakeExpr).
    // Add/Sub/Mul/Div are emitted through the Interval operator overloads instead, so MakeExpr does not
    // look up their entries.
    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
      { OpCodes.Add, "Add" },
      { OpCodes.Sub, "Subtract" },
      { OpCodes.Mul, "Multiply" },
      { OpCodes.Div, "Divide" },
      { OpCodes.Sin, "Sine" },
      { OpCodes.Cos, "Cosine" },
      { OpCodes.Tan, "Tangens" },
      { OpCodes.Tanh, "HyperbolicTangent" },
      { OpCodes.Log, "Logarithm" },
      { OpCodes.Exp, "Exponential" },
      { OpCodes.Square, "Square" },
      { OpCodes.Cube, "Cube" },
      { OpCodes.SquareRoot, "SquareRoot" },
      { OpCodes.CubeRoot, "CubicRoot" },
      { OpCodes.Absolute, "Absolute" },
      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
    };

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    /// <summary>Parameter holding the evaluated-solutions counter.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    }

    /// <summary>
    /// Number of <see cref="GetModelBound"/> calls (including those made by <see cref="GetConstraintViolation"/>)
    /// since the counter was last reset by <see cref="ClearState"/> or <see cref="InitializeState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }

    // Guards the increment of EvaluatedSolutions in GetModelBound.
    private readonly object syncRoot = new object();

    public IntervalArithCompiledExpressionBoundsEstimator() : base("Interval Arithmetic Compiled Expression Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
    }

    [StorableConstructor]
    protected IntervalArithCompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    protected IntervalArithCompiledExpressionBoundsEstimator(IntervalArithCompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IntervalArithCompiledExpressionBoundsEstimator(this, cloner);
    }

    /// <summary>
    /// Measures how far the estimated model bounds fall outside the interval of <paramref name="constraint"/>.
    /// </summary>
    /// <returns>
    /// 0 if the model bounds lie completely inside the constraint interval; otherwise the sum of the
    /// distances of each model bound that lies outside the constraint interval to the corresponding
    /// constraint bound. NaN model bounds yield NaN.
    /// </returns>
    public double GetConstraintViolation(ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
      // NOTE(review): only constraint.Interval is consulted here; any further information carried by the
      // constraint is not taken into account by this estimator — confirm this is intended.
      var modelBound = GetModelBound(tree, variableRanges);
      var target = constraint.Interval;
      if (target.Contains(modelBound)) return 0.0;

      // Penalize only the side(s) that actually leave the target interval; a bound that is still inside
      // the target must not contribute its distance to the target bound.
      var violation = 0.0;
      if (!IsWithin(target, modelBound.LowerBound))
        violation += Math.Abs(modelBound.LowerBound - target.LowerBound);
      if (!IsWithin(target, modelBound.UpperBound))
        violation += Math.Abs(modelBound.UpperBound - target.UpperBound);
      return violation;
    }

    // True if value lies in the closed interval. All comparisons with NaN are false, so NaN is reported
    // as "not within" and therefore propagates into the violation.
    private static bool IsWithin(Interval interval, double value) {
      return interval.LowerBound <= value && value <= interval.UpperBound;
    }

    /// <summary>Resets the evaluated-solutions counter to 0.</summary>
    public void ClearState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and evaluates it on <paramref name="variableRanges"/>.
    /// If the result is finite and defined but has LowerBound &gt; UpperBound, the bounds are swapped.
    /// Increments <see cref="EvaluatedSolutions"/>.
    /// </summary>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) { EvaluatedSolutions++; }
      var resultInterval = EstimateBounds(tree, variableRanges.GetReadonlyDictionary());

      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;
      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not supported by this estimator.</summary>
    /// <exception cref="NotSupportedException">Always.</exception>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotSupportedException("Model nodes bounds are not supported.");
    }

    /// <summary>Resets the evaluated-solutions counter to 0.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Returns true if every node below the program root uses a symbol that this estimator can compile.
    /// </summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      return tree.Root.GetSubtree(0).IterateNodesPrefix().All(n =>
        n.Symbol is Variable ||
        n.Symbol is Constant ||
        n.Symbol is StartSymbol ||
        n.Symbol is Addition ||
        n.Symbol is Subtraction ||
        n.Symbol is Multiplication ||
        n.Symbol is Division ||
        n.Symbol is Sine ||
        n.Symbol is Cosine ||
        n.Symbol is Tangent ||
        n.Symbol is HyperbolicTangent ||
        n.Symbol is Logarithm ||
        n.Symbol is Exponential ||
        n.Symbol is Square ||
        n.Symbol is SquareRoot ||
        n.Symbol is Cube ||
        n.Symbol is CubeRoot ||
        n.Symbol is Absolute ||
        n.Symbol is AnalyticQuotient);
    }

    #region compile a tree into an interval arithmetic lambda and estimate bounds
    // Recursively translates the subtree rooted at node into an expression of type Interval.
    // args is the Interval[] lambda parameter; variableIndices maps each variable name to its position in args.
    private static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableIndices, args);
      var opCode = OpCodes.MapSymbolToOpCode(node);

      switch (opCode) {
        case OpCodes.Variable: {
            var variableNode = (VariableTreeNode)node;
            var index = variableIndices[variableNode.VariableName];
            // weight * args[index], evaluated through the Interval operator overloads
            return Expression.Multiply(
              Expression.Constant(variableNode.Weight, typeof(double)),
              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
            );
          }
        case OpCodes.Constant: {
            var v = ((ConstantTreeNode)node).Value;
            // we have to make an interval out of the constant because this may be the root of the tree (and we are expected to return an Interval)
            return Expression.Constant(new Interval(v, v), typeof(Interval));
          }
        case OpCodes.Add: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Add(e, expr(s));
            }
            return e;
          }
        case OpCodes.Sub: {
            var e = expr(node.GetSubtree(0));
            // a single operand means negation: 0 - x
            if (node.SubtreeCount == 1) {
              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
            }
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Subtract(e, expr(s));
            }
            return e;
          }
        case OpCodes.Mul: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Multiply(e, expr(s));
            }
            return e;
          }
        case OpCodes.Div: {
            var e1 = expr(node.GetSubtree(0));
            // a single operand means reciprocal: 1 / x
            if (node.SubtreeCount == 1) {
              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
            }
            // a / b / c / ... is emitted as a / (b * c * ...): division is more expensive than multiplication
            var e2 = expr(node.GetSubtree(1));
            foreach (var s in node.Subtrees.Skip(2)) {
              e2 = Expression.Multiply(e2, expr(s));
            }
            return Expression.Divide(e1, e2);
          }
        case OpCodes.AnalyticQuotient: {
            var a = expr(node.GetSubtree(0));
            var b = expr(node.GetSubtree(1));
            return CallIntervalMethod(opCode, a, b);
          }
        // all these cases share the same code: get method info by name, emit call expression
        case OpCodes.Exp:
        case OpCodes.Log:
        case OpCodes.Sin:
        case OpCodes.Cos:
        case OpCodes.Tan:
        case OpCodes.Tanh:
        case OpCodes.Square:
        case OpCodes.Cube:
        case OpCodes.SquareRoot:
        case OpCodes.CubeRoot:
        case OpCodes.Absolute: {
            var arg = expr(node.GetSubtree(0));
            return CallIntervalMethod(opCode, arg);
          }
        default:
          throw new NotSupportedException($"Unsupported OpCode {opCode} encountered.");
      }
    }

    // Emits a call to the static Interval method registered for opCode in methodName, selecting the
    // overload by the static types of the argument expressions.
    private static Expression CallIntervalMethod(byte opCode, params Expression[] arguments) {
      var name = methodName[opCode];
      var method = typeof(Interval).GetMethod(name, arguments.Select(a => a.Type).ToArray());
      if (method == null)
        throw new InvalidOperationException($"{nameof(Interval)} does not provide a method '{name}' matching the arguments of OpCode {opCode}.");
      return Expression.Call(method, arguments);
    }

    // Skips the ProgramRootSymbol/StartSymbol wrapper nodes and returns the root node of the model expression.
    private static ISymbolicExpressionTreeNode GetExpressionRoot(ISymbolicExpressionTree tree) {
      var root = tree.Root;
      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
        root = root.GetSubtree(0);
      }
      return root;
    }

    /// <summary>
    /// Assigns each distinct variable of <paramref name="tree"/> (in prefix order of first occurrence) an index
    /// into the argument array of the compiled function and collects the matching ranges.
    /// </summary>
    /// <param name="tree">The model tree.</param>
    /// <param name="variableIntervals">Range of each variable, keyed by variable name.</param>
    /// <param name="inputIntervals">Receives the ranges of the tree's variables, ordered by their assigned index.</param>
    /// <returns>Mapping from variable name to index in <paramref name="inputIntervals"/>.</returns>
    /// <exception cref="KeyNotFoundException">A variable of the tree has no entry in <paramref name="variableIntervals"/>.</exception>
    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
      var variableIndices = new Dictionary<string, int>();
      var root = GetExpressionRoot(tree);
      // every distinct tree variable must have an entry in variableIntervals, so this size is an upper bound
      inputIntervals = new Interval[variableIntervals.Count];
      int count = 0;
      foreach (var node in root.IterateNodesPrefix()) {
        if (!(node is VariableTreeNode varNode)) continue;
        var name = varNode.VariableName;
        if (variableIndices.ContainsKey(name)) continue;
        if (!variableIntervals.TryGetValue(name, out var range))
          throw new KeyNotFoundException($"No interval is given for variable '{name}'.");
        variableIndices[name] = count;
        inputIntervals[count] = range;
        ++count;
      }
      Array.Resize(ref inputIntervals, count);
      return variableIndices;
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> into a function that maps the ranges of its variables (ordered as in
    /// <paramref name="variableIndices"/>) to the interval of the model output.
    /// </summary>
    /// <param name="variableRanges">Not used for compilation; the ranges are passed as the argument of the returned function.</param>
    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
      var root = GetExpressionRoot(tree);
      var args = Expression.Parameter(typeof(Interval[]));
      var body = MakeExpr(root, variableIndices, args);
      return Expression.Lambda<Func<Interval[], Interval>>(body, args).Compile();
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and evaluates it once on the ranges given in <paramref name="variableRanges"/>.
    /// </summary>
    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return f(x);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.