Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3140_NumberSymbol/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithCompiledExpressionBoundsEstimator.cs @ 18112

Last change on this file since 18112 was 18112, checked in by chaider, 2 years ago

#3140

  • Adding INumericSymbol and INumericTreeNode
  • Using the new interfaces inside of interpreters and formatters
  • Renaming Num to Number, RealConstant to Constant
  • More classes refactored
File size: 11.4 KB
Line 
#region License Information

/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */

#endregion
23
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using System;
31using System.Collections.Generic;
32using System.Linq;
33using System.Linq.Expressions;
34
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Bounds estimator that compiles a symbolic expression tree into a lambda over <see cref="Interval"/> values
  /// (built with System.Linq.Expressions) and invokes that lambda once to obtain the bounds of the model output
  /// for the given variable ranges.
  /// </summary>
  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
  [Item("Interval Arithmetic Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
  public class IntervalArithCompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
    // Maps op codes to the names of the static Interval methods implementing them.
    // MakeExpr resolves these names via reflection, so they must match the method names on Interval exactly.
    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
      { OpCodes.Add, "Add" },
      { OpCodes.Sub, "Subtract" },
      { OpCodes.Mul, "Multiply" },
      { OpCodes.Div, "Divide" },
      { OpCodes.Sin, "Sine" },
      { OpCodes.Cos, "Cosine" },
      { OpCodes.Tan, "Tangens" },
      { OpCodes.Tanh, "HyperbolicTangent" },
      { OpCodes.Log, "Logarithm" },
      { OpCodes.Exp, "Exponential" },
      { OpCodes.Square, "Square" },
      { OpCodes.Cube, "Cube" },
      { OpCodes.SquareRoot, "SquareRoot" },
      { OpCodes.CubeRoot, "CubicRoot" },
      { OpCodes.Absolute, "Absolute" },
      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
    };

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    /// <summary>Parameter holding the counter of models evaluated by this estimator.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    }

    /// <summary>
    /// Number of calls to <see cref="GetModelBound"/> (including those made by <see cref="GetConstraintViolation"/>)
    /// since the counter was last reset.
    /// </summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }

    // Guards every update of the EvaluatedSolutions counter.
    private readonly object syncRoot = new object();

    public IntervalArithCompiledExpressionBoundsEstimator() : base("Interval Arithmetic Compiled Expression Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
    }

    [StorableConstructor]
    protected IntervalArithCompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    protected IntervalArithCompiledExpressionBoundsEstimator(IntervalArithCompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IntervalArithCompiledExpressionBoundsEstimator(this, cloner);
    }

    /// <summary>
    /// Measures how far the estimated output bounds of <paramref name="tree"/> lie outside the interval
    /// required by <paramref name="constraint"/>.
    /// </summary>
    /// <returns>
    /// 0 if the model bounds are contained in the constraint interval; otherwise the sum, over each model bound
    /// that lies outside the constraint interval, of its distance to the corresponding constraint bound.
    /// </returns>
    public double GetConstraintViolation(ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
      // NOTE(review): constraint regions and derivative constraints are not considered here; the constraint
      // interval is always compared against the bounds of the model output over the full variable ranges.
      var modelBound = GetModelBound(tree, variableRanges);
      var target = constraint.Interval;
      if (target.Contains(modelBound)) return 0.0;

      // Only the bound(s) that actually violate the constraint contribute. A bound that already satisfies the
      // constraint must not be penalized (otherwise e.g. an infinite constraint bound yields an infinite error).
      var violation = 0.0;
      if (!target.Contains(modelBound.LowerBound))
        violation += Math.Abs(modelBound.LowerBound - target.LowerBound);
      if (!target.Contains(modelBound.UpperBound))
        violation += Math.Abs(modelBound.UpperBound - target.UpperBound);
      return violation;
    }

    /// <summary>Resets the evaluated-solutions counter.</summary>
    public void ClearState() {
      lock (syncRoot) { EvaluatedSolutions = 0; }
    }

    /// <summary>
    /// Estimates the bounds of the model output for the given variable ranges and increments the
    /// evaluated-solutions counter. Reversed bounds (lower &gt; upper) are swapped before returning;
    /// infinite or undefined results are returned unchanged.
    /// </summary>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) { EvaluatedSolutions++; }
      var resultInterval = EstimateBounds(tree, variableRanges.GetReadonlyDictionary());

      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;
      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not supported: the compiled lambda only yields the bounds of the whole model.</summary>
    /// <exception cref="NotSupportedException">Always thrown.</exception>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotSupportedException("Model nodes bounds are not supported.");
    }

    /// <summary>Resets the evaluated-solutions counter.</summary>
    public void InitializeState() {
      lock (syncRoot) { EvaluatedSolutions = 0; }
    }

    /// <summary>
    /// Returns true if every node below the program root (including the start symbol) uses one of the
    /// symbols this estimator can compile.
    /// </summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbols = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Number) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is HyperbolicTangent) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient)
        select n).Any();
      return !containsUnknownSymbols;
    }

    #region compile a tree into a IA arithmetic lambda and estimate bounds
    /// <summary>
    /// Recursively translates <paramref name="node"/> into an expression that evaluates to an <see cref="Interval"/>.
    /// Variables read their interval from <paramref name="args"/> at the position given by <paramref name="variableIndices"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The node's op code has no interval-arithmetic translation.</exception>
    static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableRanges, variableIndices, args);
      var opCode = OpCodes.MapSymbolToOpCode(node);

      switch (opCode) {
        case OpCodes.Variable: {
            var varNode = (VariableTreeNode)node;
            var index = variableIndices[varNode.VariableName];
            // weight * args[index]
            return Expression.Multiply(
              Expression.Constant(varNode.Weight, typeof(double)),
              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
            );
          }
        // Constant and Number nodes both carry a numeric value; read it through the shared interface
        // instead of assuming a particular concrete node type.
        case OpCodes.Constant:
        case OpCodes.Number: {
            var v = ((INumericTreeNode)node).Value;
            // we have to make an interval out of the value because this may be the root of the tree (and we are expected to return an Interval)
            return Expression.Constant(new Interval(v, v), typeof(Interval));
          }
        case OpCodes.Add: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Add(e, expr(s));
            }
            return e;
          }
        case OpCodes.Sub: {
            var e = expr(node.GetSubtree(0));
            // unary minus: 0 - x
            if (node.SubtreeCount == 1) {
              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
            }
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Subtract(e, expr(s));
            }
            return e;
          }
        case OpCodes.Mul: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Multiply(e, expr(s));
            }
            return e;
          }
        case OpCodes.Div: {
            var e1 = expr(node.GetSubtree(0));
            // unary division: 1 / x
            if (node.SubtreeCount == 1) {
              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
            }
            // a / b / c is computed as a / (b * c): a single division, since division is more expensive than multiplication
            var e2 = expr(node.GetSubtree(1));
            foreach (var s in node.Subtrees.Skip(2)) {
              e2 = Expression.Multiply(e2, expr(s));
            }
            return Expression.Divide(e1, e2);
          }
        case OpCodes.AnalyticQuotient: {
            var a = expr(node.GetSubtree(0));
            var b = expr(node.GetSubtree(1));
            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { a.Type, b.Type });
            return Expression.Call(fun, a, b);
          }
        // all these cases share the same code: get method info by name, emit call expression
        case OpCodes.Exp:
        case OpCodes.Log:
        case OpCodes.Sin:
        case OpCodes.Cos:
        case OpCodes.Tan:
        case OpCodes.Tanh:
        case OpCodes.Square:
        case OpCodes.Cube:
        case OpCodes.SquareRoot:
        case OpCodes.CubeRoot:
        case OpCodes.Absolute: {
            var arg = expr(node.GetSubtree(0));
            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { arg.Type });
            return Expression.Call(fun, arg);
          }
        default:
          throw new NotSupportedException($"Unsupported OpCode {opCode} encountered.");
      }
    }

    /// <summary>
    /// Assigns a consecutive index (in prefix order of first occurrence) to every distinct variable in the tree
    /// and collects the corresponding input intervals in that order.
    /// </summary>
    /// <param name="tree">The model; leading ProgramRoot/Start symbols are skipped.</param>
    /// <param name="variableIntervals">The interval of every variable referenced by the tree.</param>
    /// <param name="inputIntervals">Receives one interval per distinct variable, ordered by the returned indices.</param>
    /// <returns>A map from variable name to its position in <paramref name="inputIntervals"/>.</returns>
    /// <exception cref="KeyNotFoundException">A variable in the tree has no entry in <paramref name="variableIntervals"/>.</exception>
    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
      var variableIndices = new Dictionary<string, int>();
      var root = tree.Root;
      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
        root = root.GetSubtree(0);
      }
      var intervals = new List<Interval>();
      foreach (var node in root.IterateNodesPrefix()) {
        if (node is VariableTreeNode varNode) {
          var name = varNode.VariableName;
          if (variableIndices.ContainsKey(name)) continue;
          if (!variableIntervals.TryGetValue(name, out var interval))
            throw new KeyNotFoundException($"No interval is given for variable '{name}'.");
          variableIndices[name] = intervals.Count;
          intervals.Add(interval);
        }
      }
      inputIntervals = intervals.ToArray();
      return variableIndices;
    }

    /// <summary>
    /// Compiles the model below the program root and start symbol (<c>Root.GetSubtree(0).GetSubtree(0)</c>)
    /// into a delegate mapping the input intervals (ordered by <paramref name="variableIndices"/>) to the output interval.
    /// </summary>
    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var args = Expression.Parameter(typeof(Interval[]));
      var expr = MakeExpr(root, variableRanges, variableIndices, args);
      return Expression.Lambda<Func<Interval[], Interval>>(expr, args).Compile();
    }

    /// <summary>Compiles <paramref name="tree"/> and evaluates it once on <paramref name="variableRanges"/>.</summary>
    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return f(x);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.