source: branches/3073_IA_constraint_splitting_reintegration/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithCompiledExpressionBoundsEstimator.cs @ 17896

Last change on this file since 17896 was 17896, checked in by gkronber, 17 months ago

#3073: refactoring to prepare for trunk integration

File size: 10.8 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using System;
31using System.Collections.Generic;
32using System.Linq;
33using System.Linq.Expressions;
34
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Bounds estimator that compiles a symbolic expression tree into a LINQ expression lambda over
  /// <see cref="Interval"/> values and evaluates that lambda to obtain the interval of model outputs
  /// for the given variable ranges.
  /// </summary>
  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
  [Item("Interval Arithmetic Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
  public class IntervalArithCompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
    // Maps an opcode to the name of the static Interval method that implements it
    // (used for the operations that are compiled into method calls).
    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
      { OpCodes.Add, "Add" },
      { OpCodes.Sub, "Subtract" },
      { OpCodes.Mul, "Multiply" },
      { OpCodes.Div, "Divide" },
      { OpCodes.Sin, "Sine" },
      { OpCodes.Cos, "Cosine" },
      { OpCodes.Tan, "Tangens" },
      { OpCodes.Tanh, "HyperbolicTangent" },
      { OpCodes.Log, "Logarithm" },
      { OpCodes.Exp, "Exponential" },
      { OpCodes.Square, "Square" },
      { OpCodes.Cube, "Cube" },
      { OpCodes.SquareRoot, "SquareRoot" },
      { OpCodes.CubeRoot, "CubicRoot" },
      { OpCodes.Absolute, "Absolute" },
      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
    };

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    /// <summary>Parameter holding the number of model bounds computed by this estimator.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    }

    /// <summary>Number of model bounds computed since the last reset.</summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }

    // guards the increment of EvaluatedSolutions in GetModelBound
    private readonly object syncRoot = new object();

    public IntervalArithCompiledExpressionBoundsEstimator() : base("Interval Arithmetic Compiled Expression Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
    }

    [StorableConstructor]
    private IntervalArithCompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    private IntervalArithCompiledExpressionBoundsEstimator(IntervalArithCompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IntervalArithCompiledExpressionBoundsEstimator(this, cloner);
    }

    /// <summary>
    /// Measures how far the bound of <paramref name="tree"/> over <paramref name="variableRanges"/>
    /// lies outside <paramref name="constraint"/>'s interval.
    /// </summary>
    /// <returns>
    /// 0 if the model bound is contained in the constraint interval; otherwise the sum of the distance
    /// by which the model's lower bound falls below the constraint's lower bound and the distance by
    /// which the model's upper bound exceeds the constraint's upper bound. NaN if the model bound is undefined.
    /// </returns>
    public double GetConstraintViolation(ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
      var modelBound = GetModelBound(tree, variableRanges);
      var target = constraint.Interval;
      if (target.Contains(modelBound)) return 0.0;

      // an undefined model bound cannot be measured against the constraint
      if (double.IsNaN(modelBound.LowerBound) || double.IsNaN(modelBound.UpperBound)) return double.NaN;

      // Only the parts of the model bound that stick out of the constraint interval are penalized.
      // Comparing before subtracting keeps the violation finite for one-sided constraints such as
      // [0, +inf) and avoids inf - inf = NaN when both the model and the constraint bound are infinite.
      var belowLower = modelBound.LowerBound < target.LowerBound ? target.LowerBound - modelBound.LowerBound : 0.0;
      var aboveUpper = modelBound.UpperBound > target.UpperBound ? modelBound.UpperBound - target.UpperBound : 0.0;
      return belowLower + aboveUpper;
    }

    /// <summary>Resets the evaluated-solutions counter.</summary>
    public void ClearState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Compiles and evaluates <paramref name="tree"/> with interval arithmetic and increments
    /// <see cref="EvaluatedSolutions"/>. If the result has lower &gt; upper (and is finite and defined),
    /// the bounds are swapped.
    /// </summary>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) { EvaluatedSolutions++; }
      var resultInterval = EstimateBounds(tree, variableRanges.GetReadonlyDictionary());

      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;
      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not supported by this estimator: the compiled lambda only yields the overall model bound.</summary>
    /// <exception cref="NotSupportedException">Always.</exception>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotSupportedException("Model nodes bounds are not supported.");
    }

    /// <summary>Resets the evaluated-solutions counter.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>
    /// Returns true when every node below the program root uses one of the symbols this estimator can compile.
    /// </summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbols = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is HyperbolicTangent) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient)
        select n).Any();
      return !containsUnknownSymbols;
    }

    #region compile a tree into an IA arithmetic lambda and estimate bounds
    /// <summary>
    /// Recursively translates <paramref name="node"/> into a LINQ expression producing an <see cref="Interval"/>.
    /// Variables are read from the <paramref name="args"/> array at the position given by <paramref name="variableIndices"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The node's opcode (or its Interval method) is not supported.</exception>
    private static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableRanges, variableIndices, args);
      var opCode = OpCodes.MapSymbolToOpCode(node);

      switch (opCode) {
        case OpCodes.Variable: {
            var variableNode = (VariableTreeNode)node;
            var index = variableIndices[variableNode.VariableName];
            // scalar weight * input interval (relies on Interval's double * Interval operator)
            return Expression.Multiply(
              Expression.Constant(variableNode.Weight, typeof(double)),
              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
            );
          }
        case OpCodes.Constant: {
            var v = ((ConstantTreeNode)node).Value;
            // we have to make an interval out of the constant because this may be the root of the tree (and we are expected to return an Interval)
            return Expression.Constant(new Interval(v, v), typeof(Interval));
          }
        case OpCodes.Add: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Add(e, expr(s));
            }
            return e;
          }
        case OpCodes.Sub: {
            var e = expr(node.GetSubtree(0));
            if (node.SubtreeCount == 1) {
              // unary minus
              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
            }
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Subtract(e, expr(s));
            }
            return e;
          }
        case OpCodes.Mul: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Multiply(e, expr(s));
            }
            return e;
          }
        case OpCodes.Div: {
            var e1 = expr(node.GetSubtree(0));
            if (node.SubtreeCount == 1) {
              // unary division is the reciprocal
              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
            }
            // division is more expensive than multiplication, so a / b / c is computed as a / (b * c)
            var e2 = expr(node.GetSubtree(1));
            foreach (var s in node.Subtrees.Skip(2)) {
              e2 = Expression.Multiply(e2, expr(s));
            }
            return Expression.Divide(e1, e2);
          }
        case OpCodes.AnalyticQuotient: {
            // The analytic quotient has two arguments (dividend, divisor), so both subtrees are
            // compiled and passed to the two-argument Interval method.
            var dividend = expr(node.GetSubtree(0));
            var divisor = expr(node.GetSubtree(1));
            return CallIntervalMethod(opCode, dividend, divisor);
          }
        // unary functions share the same code: look up the Interval method by name and emit a call
        case OpCodes.Exp:
        case OpCodes.Log:
        case OpCodes.Sin:
        case OpCodes.Cos:
        case OpCodes.Tan:
        case OpCodes.Tanh:
        case OpCodes.Square:
        case OpCodes.Cube:
        case OpCodes.SquareRoot:
        case OpCodes.CubeRoot:
        case OpCodes.Absolute: {
            return CallIntervalMethod(opCode, expr(node.GetSubtree(0)));
          }
        default:
          throw new NotSupportedException($"Unsupported OpCode {opCode} encountered.");
      }
    }

    // Emits a call to the static Interval method registered for opCode, resolved by the argument types.
    private static Expression CallIntervalMethod(byte opCode, params Expression[] arguments) {
      var name = methodName[opCode];
      var argumentTypes = arguments.Select(a => a.Type).ToArray();
      var method = typeof(Interval).GetMethod(name, argumentTypes);
      if (method == null) {
        throw new NotSupportedException($"No Interval method {name} with {argumentTypes.Length} argument(s) found for OpCode {opCode}.");
      }
      return Expression.Call(method, arguments);
    }

    /// <summary>
    /// Assigns each distinct variable of <paramref name="tree"/> (in prefix order) an index into the input array.
    /// </summary>
    /// <param name="inputIntervals">Receives the interval of each variable, positioned at its assigned index.</param>
    /// <returns>Map from variable name to index into <paramref name="inputIntervals"/>.</returns>
    /// <exception cref="KeyNotFoundException">A variable of the tree has no interval in <paramref name="variableIntervals"/>.</exception>
    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
      var variableIndices = new Dictionary<string, int>();
      var root = tree.Root;
      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
        root = root.GetSubtree(0);
      }

      var intervals = new List<Interval>();
      foreach (var node in root.IterateNodesPrefix()) {
        if (!(node is VariableTreeNode varNode)) continue;
        var name = varNode.VariableName;
        if (variableIndices.ContainsKey(name)) continue;
        if (!variableIntervals.TryGetValue(name, out var interval)) {
          throw new KeyNotFoundException($"No interval for variable {name} is present.");
        }
        variableIndices[name] = intervals.Count;
        intervals.Add(interval);
      }
      inputIntervals = intervals.ToArray();
      return variableIndices;
    }

    /// <summary>
    /// Compiles the expression below the start symbol of <paramref name="tree"/> into a delegate that maps the
    /// input intervals (ordered as in <paramref name="variableIndices"/>) to the interval of the model output.
    /// <paramref name="variableRanges"/> is currently not used during compilation.
    /// </summary>
    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var args = Expression.Parameter(typeof(Interval[]));
      var expr = MakeExpr(root, variableRanges, variableIndices, args);
      return Expression.Lambda<Func<Interval[], Interval>>(expr, args).Compile();
    }

    /// <summary>Compiles <paramref name="tree"/> and evaluates it once on <paramref name="variableRanges"/>.</summary>
    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return f(x);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.