Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalArithBoundsEstimator.cs @ 18066

Last change on this file since 18066 was 17964, checked in by chaider, 4 years ago

#3073 Added sample on start page and default problem in shape constrained regression problem data

File size: 14.5 KB
RevLine 
[17896]1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23using System;
[17763]24using System.Collections.Generic;
25using System.Linq;
26using HEAL.Attic;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31using HeuristicLab.Parameters;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Bounds estimator that computes the output interval of a symbolic expression tree by evaluating
  /// the tree with interval arithmetic over given input-variable ranges.
  /// </summary>
  [StorableType("C8539434-6FB0-47D0-9F5A-2CAE5D8B8B4F")]
  [Item("Interval Arithmetic Bounds Estimator", "Interpreter for calculation of intervals of symbolic models.")]
  public sealed class IntervalArithBoundsEstimator : ParameterizedNamedItem, IBoundsEstimator {
    #region Parameters

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    /// <summary>Parameter that stores the evaluated-solutions counter.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter =>
      (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];

    /// <summary>
    /// Counter that is incremented by <see cref="GetModelBound"/> and reset by <see cref="InitializeState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get => EvaluatedSolutionsParameter.Value.Value;
      set => EvaluatedSolutionsParameter.Value.Value = value;
    }
    #endregion

    #region Constructors

    [StorableConstructor]
    private IntervalArithBoundsEstimator(StorableConstructorFlag _) : base(_) { }

    // Private because the class is sealed; only Clone uses this copy constructor.
    private IntervalArithBoundsEstimator(IntervalArithBoundsEstimator original, Cloner cloner) : base(original, cloner) { }

    public IntervalArithBoundsEstimator() : base("Interval Arithmetic Bounds Estimator",
      "Estimates the bounds of the model with interval arithmetic") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName,
        "A counter for the total number of solutions the estimator has evaluated.", new IntValue(0)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new IntervalArithBoundsEstimator(this, cloner);
    }

    #endregion

    #region IStatefulItem Members

    // Guards the increment of EvaluatedSolutions in GetModelBound.
    private readonly object syncRoot = new object();

    /// <summary>Resets <see cref="EvaluatedSolutions"/> to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>Intentionally empty; this estimator keeps no additional state to clear.</summary>
    public void ClearState() { }

    #endregion

    #region Evaluation

    /// <summary>
    /// Compiles <paramref name="tree"/> into a linear instruction sequence and stores, in the data of
    /// every variable instruction, the interval of that variable taken from <paramref name="variableRanges"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="variableRanges"/> is null.</exception>
    /// <exception cref="InvalidOperationException">A variable used in the tree has no range.</exception>
    private static Instruction[] PrepareInterpreterState(
      ISymbolicExpressionTree tree,
      IDictionary<string, Interval> variableRanges) {
      if (variableRanges == null) {
        // ArgumentNullException(string paramName, string message): the parameter name comes first.
        throw new ArgumentNullException(nameof(variableRanges), "No variable ranges are present!");
      }

      // Check that every variable used in the tree has an entry in variableRanges.
      foreach (var variable in tree.IterateNodesPrefix().OfType<VariableTreeNode>().Select(n => n.VariableName)
                                   .Distinct()) {
        if (!variableRanges.ContainsKey(variable))
          throw new InvalidOperationException($"No ranges for variable {variable} is present");
      }

      var code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      foreach (var instr in code.Where(i => i.opCode == OpCodes.Variable)) {
        var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
        instr.data = variableRanges[variableTreeNode.VariableName];
      }

      return code;
    }

    /// <summary>
    /// Evaluates, with interval arithmetic, the subtree whose root instruction is at
    /// <paramref name="instructionCounter"/> and returns the resulting interval. Instructions are consumed
    /// in prefix order: the operator first, then each of its arguments recursively.
    /// </summary>
    /// <param name="instructions">The compiled instruction sequence.</param>
    /// <param name="instructionCounter">
    /// Index of the next instruction. It is passed by ref so that each recursive call advances the shared
    /// position past the instructions it consumed. On return it points just behind the evaluated subtree.
    /// </param>
    /// <param name="nodeIntervals">
    /// Optional. When given, the interval of every evaluated node is added to it; a node that is already
    /// present keeps its existing value.
    /// </param>
    /// <param name="variableIntervals">
    /// Optional. Variable ranges that take precedence over the interval stored in a variable instruction's data.
    /// </param>
    /// <exception cref="NotSupportedException">
    /// The exponent of a power is not a single integer value representable as <see cref="int"/>,
    /// or the tree contains an unsupported symbol.
    /// </exception>
    public static Interval Evaluate(
      Instruction[] instructions, ref int instructionCounter,
      IDictionary<ISymbolicExpressionTreeNode, Interval> nodeIntervals = null,
      IDictionary<string, Interval> variableIntervals = null) {
      var currentInstr = instructions[instructionCounter];
      instructionCounter++;
      Interval result;

      switch (currentInstr.opCode) {
        case OpCodes.Variable: {
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            var weightInterval = new Interval(variableTreeNode.Weight, variableTreeNode.Weight);

            // Explicitly passed ranges win over the range attached during preparation.
            Interval variableInterval;
            if (variableIntervals == null ||
                !variableIntervals.TryGetValue(variableTreeNode.VariableName, out variableInterval))
              variableInterval = (Interval)currentInstr.data;

            result = Interval.Multiply(variableInterval, weightInterval);
            break;
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            result = new Interval(constTreeNode.Value, constTreeNode.Value);
            break;
          }
        case OpCodes.Add: {
            result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            for (var i = 1; i < currentInstr.nArguments; i++) {
              var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
              result = Interval.Add(result, argumentInterval);
            }

            break;
          }
        case OpCodes.Sub: {
            result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            // Unary minus.
            if (currentInstr.nArguments == 1)
              result = Interval.Multiply(new Interval(-1, -1), result);

            for (var i = 1; i < currentInstr.nArguments; i++) {
              var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
              result = Interval.Subtract(result, argumentInterval);
            }

            break;
          }
        case OpCodes.Mul: {
            result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            for (var i = 1; i < currentInstr.nArguments; i++) {
              var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
              result = Interval.Multiply(result, argumentInterval);
            }

            break;
          }
        case OpCodes.Div: {
            result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            // Unary division means reciprocal.
            if (currentInstr.nArguments == 1)
              result = Interval.Divide(new Interval(1, 1), result);

            for (var i = 1; i < currentInstr.nArguments; i++) {
              var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
              result = Interval.Divide(result, argumentInterval);
            }

            break;
          }
        case OpCodes.Sin: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Sine(argumentInterval);
            break;
          }
        case OpCodes.Cos: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Cosine(argumentInterval);
            break;
          }
        case OpCodes.Tan: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Tangens(argumentInterval);
            break;
          }
        case OpCodes.Tanh: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.HyperbolicTangent(argumentInterval);
            break;
          }
        case OpCodes.Log: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Logarithm(argumentInterval);
            break;
          }
        case OpCodes.Exp: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Exponential(argumentInterval);
            break;
          }
        case OpCodes.Square: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Square(argumentInterval);
            break;
          }
        case OpCodes.SquareRoot: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.SquareRoot(argumentInterval);
            break;
          }
        case OpCodes.Cube: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Cube(argumentInterval);
            break;
          }
        case OpCodes.CubeRoot: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.CubicRoot(argumentInterval);
            break;
          }
        case OpCodes.Power: {
            var a = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            var b = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            // Support only a single integer exponent. The range check rejects infinities and values
            // outside int, for which the unchecked cast below would yield a meaningless exponent.
            var exponent = b.LowerBound;
            var isSupportedExponent = b.LowerBound == b.UpperBound
                                      && Math.Truncate(exponent) == exponent
                                      && exponent >= int.MinValue && exponent <= int.MaxValue;
            if (!isSupportedExponent)
              throw new NotSupportedException("Interval is only supported for integer powers");

            result = Interval.Power(a, (int)exponent);
            break;
          }
        case OpCodes.Absolute: {
            var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            result = Interval.Absolute(argumentInterval);
            break;
          }
        case OpCodes.AnalyticQuotient: {
            result = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
            for (var i = 1; i < currentInstr.nArguments; i++) {
              var argumentInterval = Evaluate(instructions, ref instructionCounter, nodeIntervals, variableIntervals);
              result = Interval.AnalyticQuotient(result, argumentInterval);
            }

            break;
          }
        default:
          throw new NotSupportedException(
            $"The tree contains the unknown symbol {currentInstr.dynamicNode.Symbol}");
      }

      // Record the node's interval if requested; the first recorded value is kept.
      if (nodeIntervals != null && !nodeIntervals.ContainsKey(currentInstr.dynamicNode))
        nodeIntervals.Add(currentInstr.dynamicNode, result);

      return result;
    }

    #endregion

    #region Helpers

    /// <summary>
    /// Returns a dictionary mapping each distinct variable name used in <paramref name="tree"/> to its
    /// interval in <paramref name="variableRanges"/>.
    /// </summary>
    /// <exception cref="KeyNotFoundException">A used variable has no interval in the collection.</exception>
    private static IDictionary<string, Interval> GetOccurringVariableRanges(
      ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      var variables = tree.IterateNodesPrefix().OfType<VariableTreeNode>().Select(v => v.VariableName).Distinct()
                          .ToList();
      if (variables.Count == 0)
        return new Dictionary<string, Interval>();

      // Fetch the dictionary once instead of once per variable.
      var ranges = variableRanges.GetReadonlyDictionary();
      return variables.ToDictionary(x => x, x => ranges[x]);
    }

    #endregion

    /// <summary>
    /// Computes the interval of the output of <paramref name="tree"/> for the given input ranges and
    /// increments <see cref="EvaluatedSolutions"/>.
    /// </summary>
    public Interval GetModelBound(ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      lock (syncRoot) {
        EvaluatedSolutions++;
      }

      var occurringVariableRanges = GetOccurringVariableRanges(tree, variableRanges);
      var instructions = PrepareInterpreterState(tree, occurringVariableRanges);
      var instructionCounter = 0;
      var resultInterval = Evaluate(instructions, ref instructionCounter, variableIntervals: occurringVariableRanges);

      // Because of numerical errors the bounds might be in the wrong order; swap them if so.
      if (resultInterval.IsInfiniteOrUndefined || resultInterval.LowerBound <= resultInterval.UpperBound)
        return resultInterval;

      return new Interval(resultInterval.UpperBound, resultInterval.LowerBound);
    }

    /// <summary>Not implemented by this estimator.</summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public IDictionary<ISymbolicExpressionTreeNode, Interval> GetModelNodeBounds(
      ISymbolicExpressionTree tree, IntervalCollection variableRanges) {
      throw new NotImplementedException();
    }

    /// <summary>
    /// Returns 0 if the model's output interval lies within the constraint interval. Otherwise returns the
    /// sum, over the model's lower and upper bound that fall outside the constraint interval, of the absolute
    /// distance to the constraint's bound on the same side (lower to lower, upper to upper).
    /// Unlike <see cref="GetModelBound"/>, this does not increment <see cref="EvaluatedSolutions"/>.
    /// </summary>
    public double GetConstraintViolation(
      ISymbolicExpressionTree tree, IntervalCollection variableRanges, ShapeConstraint constraint) {
      var occurringVariableRanges = GetOccurringVariableRanges(tree, variableRanges);
      var instructions = PrepareInterpreterState(tree, occurringVariableRanges);
      var instructionCounter = 0;
      var modelBound = Evaluate(instructions, ref instructionCounter, variableIntervals: occurringVariableRanges);
      if (constraint.Interval.Contains(modelBound)) return 0.0;

      var error = 0.0;

      if (!constraint.Interval.Contains(modelBound.LowerBound)) {
        error += Math.Abs(modelBound.LowerBound - constraint.Interval.LowerBound);
      }

      if (!constraint.Interval.Contains(modelBound.UpperBound)) {
        error += Math.Abs(modelBound.UpperBound - constraint.Interval.UpperBound);
      }

      return error;
    }

    /// <summary>
    /// Returns true if every node below the tree's root consists only of symbols that
    /// <see cref="Evaluate"/> can handle.
    /// </summary>
    public bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbols = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is HyperbolicTangent) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Power) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient)
        select n).Any();
      return !containsUnknownSymbols;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.