Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3140_NumberSymbol/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithCompiledExpressionBoundsEstimator.cs @ 18093

Last change on this file since 18093 was 18093, checked in by chaider, 2 years ago

#3041

  • Renamed the Constant symbol to Num; it behaves as before
  • Added a new symbol RealConstant (Constant); this symbol now behaves like a real constant and is not changed by parameter optimization or manipulators
  • Refactored classes (part 1)
File size: 11.1 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using System;
31using System.Collections.Generic;
32using System.Linq;
33using System.Linq.Expressions;
34
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Estimates the output bounds of a symbolic model with interval arithmetic by compiling the expression
  /// tree into a LINQ expression lambda over <see cref="Interval"/> inputs and evaluating that lambda.
  /// </summary>
  [StorableType("60015D64-5D8B-408A-90A1-E4111BC114D4")]
  [Item("Interval Arithmetic Compiled Expression Bounds Estimator", "Compile a symbolic model into a lambda and use it to evaluate model bounds.")]
  public class IntervalArithCompiledExpressionBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
    // Names of the static Interval methods that MakeExpr emits calls to for function-like op codes.
    // Add/Sub/Mul/Div are emitted as operator expressions and never look up this table.
    private static readonly Dictionary<byte, string> methodName = new Dictionary<byte, string>() {
      { OpCodes.Add, "Add" },
      { OpCodes.Sub, "Subtract" },
      { OpCodes.Mul, "Multiply" },
      { OpCodes.Div, "Divide" },
      { OpCodes.Sin, "Sine" },
      { OpCodes.Cos, "Cosine" },
      { OpCodes.Tan, "Tangens" },
      { OpCodes.Tanh, "HyperbolicTangent" },
      { OpCodes.Log, "Logarithm" },
      { OpCodes.Exp, "Exponential" },
      { OpCodes.Square, "Square" },
      { OpCodes.Cube, "Cube" },
      { OpCodes.SquareRoot, "SquareRoot" },
      { OpCodes.CubeRoot, "CubicRoot" },
      { OpCodes.Absolute, "Absolute" },
      { OpCodes.AnalyticQuotient, "AnalyticalQuotient" },
    };

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    /// <summary>Parameter holding the counter of model bound evaluations.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get => (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    }

    /// <summary>
    /// Number of <see cref="GetModelBound"/> calls (including those made by <see cref="GetConstraintViolation"/>)
    /// since the last <see cref="ClearState"/> or <see cref="InitializeState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }

    // Guards read-modify-write access to the EvaluatedSolutions counter.
    private readonly object syncRoot = new object();

    public IntervalArithCompiledExpressionBoundsEstimator() : base("Interval Arithmetic Compiled Expression Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic, by first compiling the model into a lambda.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
    }

    [StorableConstructor]
    protected IntervalArithCompiledExpressionBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    protected IntervalArithCompiledExpressionBoundsEstimator(IntervalArithCompiledExpressionBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IntervalArithCompiledExpressionBoundsEstimator(this, cloner);
    }

    /// <summary>
    /// Measures how far the estimated model bounds extend beyond the interval of <paramref name="constraint"/>.
    /// </summary>
    /// <param name="tree">Model to evaluate; it is evaluated as given.</param>
    /// <param name="variableRanges">Input intervals of the variables used by the model.</param>
    /// <param name="constraint">Constraint whose <c>Interval</c> is compared with the model bounds; no other member of it is read.</param>
    /// <returns>
    /// 0 if the model bounds lie within the constraint interval. Otherwise, the amount by which the lower model
    /// bound falls below the constraint's lower bound plus the amount by which the upper model bound exceeds
    /// the constraint's upper bound. NaN if a model bound is NaN.
    /// </returns>
    /// <remarks>
    /// Only the violated side(s) contribute. Summing the absolute distances on both sides would penalize
    /// slack on a satisfied side. It would also yield infinity or NaN for half-open constraints such as
    /// [0, +inf] (|5 - inf| = inf, |inf - inf| = NaN).
    /// </remarks>
    public double GetConstraintViolation(ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
      var modelBound = GetModelBound(tree, variableRanges);
      var target = constraint.Interval;
      if (target.Contains(modelBound)) return 0.0;

      // Undefined model bounds give an undefined violation.
      if (double.IsNaN(modelBound.LowerBound) || double.IsNaN(modelBound.UpperBound)) return double.NaN;

      // Compare before subtracting, so that equal infinite bounds (e.g. -inf vs. -inf) contribute 0 instead of NaN.
      var lowerExcess = modelBound.LowerBound < target.LowerBound ? target.LowerBound - modelBound.LowerBound : 0.0;
      var upperExcess = modelBound.UpperBound > target.UpperBound ? modelBound.UpperBound - target.UpperBound : 0.0;
      return lowerExcess + upperExcess;
    }

    /// <summary>Resets the evaluation counter to zero.</summary>
    public void ClearState() {
      lock (syncRoot) { EvaluatedSolutions = 0; }
    }

    /// <summary>
    /// Compiles and evaluates <paramref name="tree"/> over <paramref name="variableRanges"/> and increments the evaluation counter.
    /// </summary>
    /// <returns>The estimated output interval, with its bounds swapped if the lower bound exceeds the upper bound.</returns>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) { EvaluatedSolutions++; }
      var resultInterval = EstimateBounds(tree, variableRanges.GetReadonlyDictionary());

      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;
      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not supported: the compiled lambda only yields the bound of the whole model.</summary>
    /// <exception cref="NotSupportedException">Always.</exception>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotSupportedException("Model nodes bounds are not supported.");
    }

    /// <summary>Resets the evaluation counter to zero.</summary>
    public void InitializeState() {
      lock (syncRoot) { EvaluatedSolutions = 0; }
    }

    /// <summary>
    /// Returns true if every node below the program root uses one of the symbols supported by this estimator.
    /// </summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      return tree.Root.GetSubtree(0).IterateNodesPrefix().All(n =>
        n.Symbol is Variable ||
        n.Symbol is Num ||
        n.Symbol is RealConstant ||
        n.Symbol is StartSymbol ||
        n.Symbol is Addition ||
        n.Symbol is Subtraction ||
        n.Symbol is Multiplication ||
        n.Symbol is Division ||
        n.Symbol is Sine ||
        n.Symbol is Cosine ||
        n.Symbol is Tangent ||
        n.Symbol is HyperbolicTangent ||
        n.Symbol is Logarithm ||
        n.Symbol is Exponential ||
        n.Symbol is Square ||
        n.Symbol is SquareRoot ||
        n.Symbol is Cube ||
        n.Symbol is CubeRoot ||
        n.Symbol is Absolute ||
        n.Symbol is AnalyticQuotient);
    }

    #region compile a tree into a IA arithmetic lambda and estimate bounds
    // Skips the ProgramRoot/Start wrapper nodes and returns the first node that belongs to the model itself.
    private static ISymbolicExpressionTreeNode GetModelRoot(ISymbolicExpressionTree tree) {
      var root = tree.Root;
      while (root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) {
        root = root.GetSubtree(0);
      }
      return root;
    }

    /// <summary>
    /// Recursively translates <paramref name="node"/> into a LINQ expression over the interval array <paramref name="args"/>.
    /// </summary>
    /// <param name="variableRanges">Passed along the recursion; not read by this method.</param>
    /// <param name="variableIndices">Position of each variable's interval within <paramref name="args"/>.</param>
    /// <param name="args">Parameter expression of type <c>Interval[]</c>.</param>
    /// <exception cref="NotSupportedException">The node's op code has no interval translation.</exception>
    static Expression MakeExpr(ISymbolicExpressionTreeNode node, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices, Expression args) {
      Expression expr(ISymbolicExpressionTreeNode n) => MakeExpr(n, variableRanges, variableIndices, args);
      var opCode = OpCodes.MapSymbolToOpCode(node);

      switch (opCode) {
        case OpCodes.Variable: {
            var varNode = (VariableTreeNode)node;
            var index = variableIndices[varNode.VariableName];
            // weight * args[index]; Expression.Multiply binds to the user-defined (double, Interval) operator
            return Expression.Multiply(
              Expression.Constant(varNode.Weight, typeof(double)),
              Expression.ArrayIndex(args, Expression.Constant(index, typeof(int)))
            );
          }
        case OpCodes.Constant: {
            // NOTE(review): IsCompatible also accepts RealConstant symbols. Confirm that they map to this op code
            // and are NumTreeNode instances; otherwise their translation fails here or in the default branch.
            var v = ((NumTreeNode)node).Value;
            // we have to make an interval out of the constant because this may be the root of the tree (and we are expected to return an Interval)
            return Expression.Constant(new Interval(v, v), typeof(Interval));
          }
        case OpCodes.Add: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Add(e, expr(s));
            }
            return e;
          }
        case OpCodes.Sub: {
            var e = expr(node.GetSubtree(0));
            // unary minus: 0 - x
            if (node.SubtreeCount == 1) {
              return Expression.Subtract(Expression.Constant(0.0, typeof(double)), e);
            }
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Subtract(e, expr(s));
            }
            return e;
          }
        case OpCodes.Mul: {
            var e = expr(node.GetSubtree(0));
            foreach (var s in node.Subtrees.Skip(1)) {
              e = Expression.Multiply(e, expr(s));
            }
            return e;
          }
        case OpCodes.Div: {
            var e1 = expr(node.GetSubtree(0));
            // unary division: 1 / x
            if (node.SubtreeCount == 1) {
              return Expression.Divide(Expression.Constant(1.0, typeof(double)), e1);
            }
            // a / b / c ... is emitted as a / (b * c * ...): one division instead of several
            var e2 = expr(node.GetSubtree(1));
            foreach (var s in node.Subtrees.Skip(2)) {
              e2 = Expression.Multiply(e2, expr(s));
            }
            return Expression.Divide(e1, e2);
          }
        case OpCodes.AnalyticQuotient: {
            var a = expr(node.GetSubtree(0));
            var b = expr(node.GetSubtree(1));
            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { a.Type, b.Type });
            return Expression.Call(fun, a, b);
          }
        // all these cases share the same code: get method info by name, emit call expression
        case OpCodes.Exp:
        case OpCodes.Log:
        case OpCodes.Sin:
        case OpCodes.Cos:
        case OpCodes.Tan:
        case OpCodes.Tanh:
        case OpCodes.Square:
        case OpCodes.Cube:
        case OpCodes.SquareRoot:
        case OpCodes.CubeRoot:
        case OpCodes.Absolute: {
            var arg = expr(node.GetSubtree(0));
            var fun = typeof(Interval).GetMethod(methodName[opCode], new[] { arg.Type });
            return Expression.Call(fun, arg);
          }
        default: {
            throw new NotSupportedException($"Unsupported OpCode {opCode} encountered.");
          }
      }
    }

    /// <summary>
    /// Assigns each distinct variable of the model a consecutive index, in prefix order of first occurrence.
    /// </summary>
    /// <param name="variableIntervals">Interval of every variable that may occur in the model.</param>
    /// <param name="inputIntervals">Receives the variables' intervals, ordered by the returned indices.</param>
    /// <returns>Map from variable name to its index in <paramref name="inputIntervals"/>.</returns>
    /// <exception cref="KeyNotFoundException">A variable of the model has no entry in <paramref name="variableIntervals"/>.</exception>
    public static IReadOnlyDictionary<string, int> GetVariableIndices(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableIntervals, out Interval[] inputIntervals) {
      var variableIndices = new Dictionary<string, int>();
      var intervals = new List<Interval>();
      foreach (var varNode in GetModelRoot(tree).IterateNodesPrefix().OfType<VariableTreeNode>()) {
        var name = varNode.VariableName;
        if (variableIndices.ContainsKey(name)) continue;
        if (!variableIntervals.TryGetValue(name, out var interval))
          throw new KeyNotFoundException($"No interval is defined for variable '{name}'.");
        variableIndices[name] = intervals.Count;
        intervals.Add(interval);
      }
      inputIntervals = intervals.ToArray();
      return variableIndices;
    }

    /// <summary>
    /// Compiles the model of <paramref name="tree"/> (without the ProgramRoot/Start wrappers) into a delegate that maps
    /// the input intervals, ordered by <paramref name="variableIndices"/>, to the output interval.
    /// </summary>
    public static Func<Interval[], Interval> Compile(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges, IReadOnlyDictionary<string, int> variableIndices) {
      var root = GetModelRoot(tree);
      var args = Expression.Parameter(typeof(Interval[]));
      var expr = MakeExpr(root, variableRanges, variableIndices, args);
      return Expression.Lambda<Func<Interval[], Interval>>(expr, args).Compile();
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and evaluates it once on the intervals given in <paramref name="variableRanges"/>.
    /// </summary>
    public static Interval EstimateBounds(ISymbolicExpressionTree tree, IReadOnlyDictionary<string, Interval> variableRanges) {
      var variableIndices = GetVariableIndices(tree, variableRanges, out Interval[] x);
      var f = Compile(tree, variableRanges, variableIndices);
      return f(x);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.