source: branches/3073_IA_constraint_splitting/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IABoundsEstimator.cs @ 17763

Last change on this file since 17763 was 17763, checked in by chaider, 14 months ago

#3073 Added interface for bound estimators and added an IABoundEstimator

File size: 19.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Collections.ObjectModel;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HEAL.Attic;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
11using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
12using HeuristicLab.Parameters;
13
14namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
15  [StorableType("C8539434-6FB0-47D0-9F5A-2CAE5D8B8B4F")]
16  [Item("IA Bounds Estimator", "Interpreter for calculation of intervals of symbolic models.")]
17  public sealed class IABoundsEstimator : ParameterizedNamedItem, IBoundsEstimator{
18    #region Parameters
19
20    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
21    private const string UseIntervalSplittingParameterName = "Use Interval splitting";
22    private const string SplittingIterationsParameterName = "Splitting Iterations";
23    private const string SplittingWidthParameterName = "Splitting width";
24
25    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter =>
26      (IFixedValueParameter<IntValue>) Parameters[EvaluatedSolutionsParameterName];
27
28    public IFixedValueParameter<BoolValue> UseIntervalSplittingParameter =>
29      (IFixedValueParameter<BoolValue>) Parameters[UseIntervalSplittingParameterName];
30
31    public IFixedValueParameter<IntValue> SplittingIterationsParameter =>
32      (IFixedValueParameter<IntValue>) Parameters[SplittingIterationsParameterName];
33
34    public IFixedValueParameter<DoubleValue> SplittingWidthParameter =>
35      (IFixedValueParameter<DoubleValue>) Parameters[SplittingWidthParameterName];
36
37    public int EvaluatedSolutions {
38      get => EvaluatedSolutionsParameter.Value.Value;
39      set => EvaluatedSolutionsParameter.Value.Value = value;
40    }
41
42    public bool UseIntervalSplitting {
43      get => UseIntervalSplittingParameter.Value.Value;
44      set => UseIntervalSplittingParameter.Value.Value = value;
45    }
46
47    public int SplittingIterations {
48      get => SplittingIterationsParameter.Value.Value;
49      set => SplittingIterationsParameter.Value.Value = value;
50    }
51
52    public double SplittingWidth {
53      get => SplittingWidthParameter.Value.Value;
54      set => SplittingWidthParameter.Value.Value = value;
55    }
56    #endregion
57
58    #region Constructors
59
60    [StorableConstructor]
61    private IABoundsEstimator(StorableConstructorFlag _) : base(_) { }
62       
63    private IABoundsEstimator(IABoundsEstimator original, Cloner cloner) : base(original, cloner) { }
64
65    public IABoundsEstimator() : base("IA Bounds Estimator", "Estimates the bounds of the model with interval arithmetic") {
66      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
67      Parameters.Add(new FixedValueParameter<BoolValue>(UseIntervalSplittingParameterName, "Defines whether interval splitting is activated or not.", new BoolValue(false)));
68      Parameters.Add(new FixedValueParameter<IntValue>(SplittingIterationsParameterName, "Defines the number of iterations of splitting.", new IntValue(200)));
69      Parameters.Add(new FixedValueParameter<DoubleValue>(SplittingWidthParameterName, "Width of interval, after the splitting should stop.", new DoubleValue(0.0)));
70    }
71
72    public override IDeepCloneable Clone(Cloner cloner) {
73      return new IABoundsEstimator(this, cloner);
74    }
75
76        #endregion
77
78    #region IStatefulItem Members
79
80    private readonly object syncRoot = new object();
81
82    public void InitializeState() {
83      EvaluatedSolutions = 0;
84    }
85
86    public void ClearState() { }
87
88        #endregion
89
90    #region Evaluation
91
92    private static Instruction[] PrepareInterpreterState(
93      ISymbolicExpressionTree tree,
94      IDictionary<string, Interval> variableRanges) {
95      if (variableRanges == null)
96        throw new ArgumentNullException("No variablew ranges are present!", nameof(variableRanges));
97
98      //Check if all variables used in the tree are present in the dataset
99      foreach (var variable in tree.IterateNodesPrefix().OfType<VariableTreeNode>().Select(n => n.VariableName)
100                                   .Distinct())
101        if (!variableRanges.ContainsKey(variable))
102          throw new InvalidOperationException($"No ranges for variable {variable} is present");
103
104      var code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
105      foreach (var instr in code.Where(i => i.opCode == OpCodes.Variable)) {
106        var variableTreeNode = (VariableTreeNode) instr.dynamicNode;
107        instr.data = variableRanges[variableTreeNode.VariableName];
108      }
109
110      return code;
111    }
112
113    public static Interval Evaluate(
114      Instruction[] instructions, ref int instructionCounter,
115      IDictionary<ISymbolicExpressionTreeNode, Interval> nodeIntervals = null,
116      IDictionary<string, Interval> variableIntervals = null) {
117      var currentInstr = instructions[instructionCounter];
118      //Use ref parameter, because the tree will be iterated through recursively from the left-side branch to the right side
119      //Update instructionCounter, whenever Evaluate is called
120      instructionCounter++;
121      Interval result = null;
122
123      switch (currentInstr.opCode) {
124        //Variables, Constants, ...
125        case OpCodes.Variable: {
126          var variableTreeNode = (VariableTreeNode) currentInstr.dynamicNode;
127          var weightInterval = new Interval(variableTreeNode.Weight, variableTreeNode.Weight);
128
129          Interval variableInterval;
130          if (variableIntervals != null && variableIntervals.ContainsKey(variableTreeNode.VariableName))
131            variableInterval = variableIntervals[variableTreeNode.VariableName];
132          else
133            variableInterval = (Interval) currentInstr.data;
134
135          result = Interval.Multiply(variableInterval, weightInterval);
136          break;
137        }
138        case OpCodes.Constant: {
139          var constTreeNode = (ConstantTreeNode) currentInstr.dynamicNode;
140          result = new Interval(constTreeNode.Value, constTreeNode.Value);
141          break;
142        }
143        //Elementary arithmetic rules
144        case OpCodes.Add: {
145          result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
146          for (var i = 1; i < currentInstr.nArguments; i++) {
147            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
148            result = Interval.Add(result, argumentInterval);
149          }
150
151          break;
152        }
153        case OpCodes.Sub: {
154          result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
155          if (currentInstr.nArguments == 1)
156            result = Interval.Multiply(new Interval(-1, -1), result);
157
158          for (var i = 1; i < currentInstr.nArguments; i++) {
159            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
160            result = Interval.Subtract(result, argumentInterval);
161          }
162
163          break;
164        }
165        case OpCodes.Mul: {
166          result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
167          for (var i = 1; i < currentInstr.nArguments; i++) {
168            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
169            result = Interval.Multiply(result, argumentInterval);
170          }
171
172          break;
173        }
174        case OpCodes.Div: {
175          result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
176          if (currentInstr.nArguments == 1)
177            result = Interval.Divide(new Interval(1, 1), result);
178
179          for (var i = 1; i < currentInstr.nArguments; i++) {
180            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
181            result = Interval.Divide(result, argumentInterval);
182          }
183
184          break;
185        }
186        //Trigonometric functions
187        case OpCodes.Sin: {
188          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
189          result = Interval.Sine(argumentInterval);
190          break;
191        }
192        case OpCodes.Cos: {
193          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
194          result = Interval.Cosine(argumentInterval);
195          break;
196        }
197        case OpCodes.Tan: {
198          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
199          result = Interval.Tangens(argumentInterval);
200          break;
201        }
202        case OpCodes.Tanh: {
203          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
204          result = Interval.HyperbolicTangent(argumentInterval);
205          break;
206        }
207        //Exponential functions
208        case OpCodes.Log: {
209          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
210          result = Interval.Logarithm(argumentInterval);
211          break;
212        }
213        case OpCodes.Exp: {
214          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
215          result = Interval.Exponential(argumentInterval);
216          break;
217        }
218        case OpCodes.Square: {
219          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
220          result = Interval.Square(argumentInterval);
221          break;
222        }
223        case OpCodes.SquareRoot: {
224          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
225          result = Interval.SquareRoot(argumentInterval);
226          break;
227        }
228        case OpCodes.Cube: {
229          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
230          result = Interval.Cube(argumentInterval);
231          break;
232        }
233        case OpCodes.CubeRoot: {
234          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
235          result = Interval.CubicRoot(argumentInterval);
236          break;
237        }
238        case OpCodes.Absolute: {
239          var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
240          result = Interval.Absolute(argumentInterval);
241          break;
242        }
243        case OpCodes.AnalyticQuotient: {
244          result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
245          for (var i = 1; i < currentInstr.nArguments; i++) {
246            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
247            result = Interval.AnalyticalQuotient(result, argumentInterval);
248          }
249
250          break;
251        }
252        default:
253          throw new NotSupportedException(
254            $"The tree contains the unknown symbol {currentInstr.dynamicNode.Symbol}");
255      }
256
257      if (!(nodeIntervals == null || nodeIntervals.ContainsKey(currentInstr.dynamicNode)))
258        nodeIntervals.Add(currentInstr.dynamicNode, result);
259
260      return result;
261    }
262
263        #endregion
264
265    #region Helpers
266
267    private static IDictionary<string, Interval> GetOccurringVariableRanges(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
268      var variables = tree.IterateNodesPrefix().OfType<VariableTreeNode>().Select(v => v.VariableName).Distinct()
269                          .ToList();
270
271      return variables.ToDictionary(x => x, x => variableRanges.GetReadonlyDictionary()[x]);
272    }
273
274    private static bool ContainsVariableMultipleTimes(ISymbolicExpressionTree tree, out List<String> variables) {
275      variables = new List<string>();
276      var varlist = tree.IterateNodesPrefix().OfType<VariableTreeNode>().GroupBy(x => x.VariableName);
277      foreach (var group in varlist) {
278        if (group.Count() > 1) {
279          variables.Add(group.Key);
280        }
281      }
282
283      return varlist.Any(group => group.Count() > 1);
284    }
285
286    // a multi-dimensional box with an associated bound
287    // boxbounds are ordered first by bound (smaller first), then by size of box (larger first) then by distance of bottom left corner to origin
288    private class BoxBound : IComparable<BoxBound> {
289      public List<Interval> box;
290      public double bound;
291
292      public BoxBound(IEnumerable<Interval> box, double bound) {
293        this.box = new List<Interval>(box);
294        this.bound = bound;
295      }
296
297      public int CompareTo(BoxBound other) {
298        if (bound != other.bound) return bound.CompareTo(other.bound);
299
300        var thisSize = box.Aggregate(1.0, (current, dimExtent) => current * dimExtent.Width);
301        var otherSize = other.box.Aggregate(1.0, (current, dimExtent) => current * dimExtent.Width);
302        if (thisSize != otherSize) return -thisSize.CompareTo(otherSize);
303
304        var thisDist = box.Sum(dimExtent => dimExtent.LowerBound * dimExtent.LowerBound);
305        var otherDist = other.box.Sum(dimExtent => dimExtent.LowerBound * dimExtent.LowerBound);
306        if (thisDist != otherDist) return thisDist.CompareTo(otherDist);
307
308        // which is smaller first along the dimensions?
309        for (int i = 0; i < box.Count; i++) {
310          if (box[i].LowerBound != other.box[i].LowerBound) return box[i].LowerBound.CompareTo(other.box[i].LowerBound);
311        }
312
313        return 0;
314      }
315    }
316
317    #endregion
318
319    #region Splitting
320
321    public static Interval EvaluateWithSplitting(Instruction[] instructions,
322                                                 IDictionary<string, Interval> variableIntervals,
323                                                 List<string> multipleOccurenceVariables, int splittingIterations, double splittingWidth, IDictionary<ISymbolicExpressionTreeNode, Interval> nodeIntervals = null) {
324      var savedIntervals = variableIntervals.ToDictionary(entry => entry.Key, entry => entry.Value);
325      var min = FindBound(instructions, variableIntervals, multipleOccurenceVariables, splittingIterations, splittingWidth, nodeIntervals,
326        minimization: true);
327      var max = FindBound(instructions, savedIntervals,  multipleOccurenceVariables, splittingIterations, splittingWidth, nodeIntervals,
328        minimization: false);
329
330      return new Interval(min, max);
331    }
332
333    private static double FindBound(Instruction[] instructions,
334                                    IDictionary<string, Interval> variableIntervals,
335                                    List<string> multipleOccurenceVariables, int splittingIterations, double splittingWidth, IDictionary<ISymbolicExpressionTreeNode, Interval> nodeIntervals = null, bool minimization = true) {
336      SortedSet<BoxBound> prioQ = new SortedSet<BoxBound>();
337
338      var ic = 0;
339      //Calculate full box
340      //IReadOnlyDictionary<string, Interval> readonlyRanges =
341      //  variableIntervals.ToDictionary(k => k.Key, k => k.Value);
342      var interval = Evaluate(instructions, ref ic, nodeIntervals, variableIntervals:variableIntervals);
343      // the order of keys in a dictionary is guaranteed to be the same order as values in a dictionary
344      // https://docs.microsoft.com/en-us/dotnet/api/system.collections.idictionary.keys?view=netcore-3.1#remarks
345      //var box = variableIntervals.Values;
346      //Box only contains intervals from multiple occurence variables
347      var box = multipleOccurenceVariables.Select(k => variableIntervals[k]);
348      if (minimization) {
349        prioQ.Add(new BoxBound(box, interval.LowerBound));
350      } else {
351        prioQ.Add(new BoxBound(box, -interval.UpperBound));
352      }
353
354      var discardedBound = double.MaxValue;
355      var runningBound = double.MaxValue;
356      for (var depth = 0; depth < splittingIterations && prioQ.Count > 0; ++depth) {
357        var currentBound = prioQ.Min;
358        prioQ.Remove(currentBound);
359
360        if (currentBound.box.All(x => x.Width < splittingWidth)) {
361          discardedBound = Math.Min(discardedBound, currentBound.bound);
362          continue;
363        }
364
365        var newBoxes = Split(currentBound.box, splittingWidth);
366
367        var innerBound = double.MaxValue;
368        foreach (var newBox in newBoxes) {
369          //var intervalEnum = newBox.GetEnumerator();
370          //var keyEnum = readonlyRanges.Keys.GetEnumerator();
371          //while (intervalEnum.MoveNext() & keyEnum.MoveNext()) {
372          //  variableIntervals[keyEnum.Current] = intervalEnum.Current;
373          //}
374          //Set the splitted variables
375          var intervalEnum = newBox.GetEnumerator();
376          foreach (var key in multipleOccurenceVariables) {
377            intervalEnum.MoveNext();
378            variableIntervals[key] = intervalEnum.Current;
379          }
380
381          ic = 0;
382          var res = Evaluate(instructions, ref ic, nodeIntervals,
383            new ReadOnlyDictionary<string, Interval>(variableIntervals));
384          //if (minimization) {
385          //  prioQ.Add(new BoxBound(newBox, res.LowerBound));
386          //} else {
387          //  prioQ.Add(new BoxBound(newBox, -res.UpperBound));
388          //}
389          var boxBound = new BoxBound(newBox, minimization ? res.LowerBound : -res.UpperBound);
390          prioQ.Add(boxBound);
391          innerBound = Math.Min(innerBound, boxBound.bound);
392        }
393
394        runningBound = innerBound;
395
396      }
397
398      var bound = Math.Min(runningBound, discardedBound);
399      if (bound == double.MaxValue)
400        return minimization ? interval.LowerBound : interval.UpperBound;
401
402      return minimization ? bound : -bound;
403      //return minimization ? prioQ.First().bound : -prioQ.First().bound;
404    }
405
406    private static IEnumerable<IEnumerable<Interval>> Split(List<Interval> box) {
407      var boxes = box.Select(region => region.Split())
408                     .Select(split => new List<Interval> {split.Item1, split.Item2})
409                     .ToList();
410
411      return boxes.CartesianProduct();
412    }
413
414    private static IEnumerable<IEnumerable<Interval>> Split(List<Interval> box, double minWidth) {
415      List<Interval> toList(Tuple<Interval, Interval> t) => new List<Interval>{t.Item1, t.Item2};
416      var boxes = box.Select(region => region.Width > minWidth ? toList(region.Split()) : new List<Interval> {region})
417                     .ToList();
418
419      return boxes.CartesianProduct();
420    }
421
422    #endregion
423
424    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
425      lock (syncRoot) {
426        EvaluatedSolutions++;
427      }
428
429      var occuringVariableRanges = GetOccurringVariableRanges(tree, variableRanges);
430      var instructions = PrepareInterpreterState(tree, occuringVariableRanges);
431      Interval resultInterval;
432      if (!UseIntervalSplitting) {
433        var instructionCounter = 0;
434        resultInterval = Evaluate(instructions, ref instructionCounter, variableIntervals: occuringVariableRanges);
435      } else {
436        var vars = ContainsVariableMultipleTimes(tree, out var variables);
437        resultInterval = EvaluateWithSplitting(instructions, occuringVariableRanges, variables, SplittingIterations, SplittingWidth);
438      }
439
440      // because of numerical errors the bounds might be incorrect
441      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
442        return resultInterval;
443
444      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
445    }
446
447    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodesBounds(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
448      throw new NotImplementedException();
449    }
450
451   
452  }
453}
Note: See TracBrowser for help on using the repository browser.