source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/IntervalEvaluator.cs @ 17295

Last change on this file since 17295 was 17295, checked in by gkronber, 3 years ago

#2994: refactoring: moved types into separate files

File size: 3.2 KB
Line 
1using System;
2using System.Collections.Generic;
3using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
4
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interval-arithmetic interpreter for symbolic expression trees. Every instruction carries an
  /// <see cref="AlgebraicInterval"/> whose lower and upper bounds are dual numbers that also track
  /// partial derivatives with respect to the tree's numeric parameters (constant values and variable
  /// weights), keyed by the tree node that owns the parameter.
  /// </summary>
  public sealed class IntervalEvaluator : Interpreter<AlgebraicInterval> {
    // Variable intervals of the Evaluate call in progress. They are stored before Compile so that
    // InitializeTerminalInstruction(VariableTreeNode) can look up each variable's input range.
    // This is per-instance state, so one instance must not serve concurrent Evaluate calls.
    // (A former [ThreadStatic] attribute was removed: it has no effect on instance fields.)
    private IDictionary<string, Interval> intervals;

    /// <summary>
    /// Evaluates <paramref name="tree"/> with interval arithmetic.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="intervals">Input interval of every variable in the tree, keyed by variable name.</param>
    /// <returns>The interval computed for the tree.</returns>
    /// <exception cref="InvalidProgramException">The computed lower bound is greater than the computed upper bound.</exception>
    public Interval Evaluate(ISymbolicExpressionTree tree, IDictionary<string, Interval> intervals) {
      var result = EvaluateRoot(tree, intervals);
      var lower = result.LowerBound.Value.Value;
      var upper = result.UpperBound.Value.Value;
      if (lower > upper) throw new InvalidProgramException($"lower: {lower} > upper: {upper}");
      return new Interval(lower, upper);
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> with interval arithmetic and also returns the partial
    /// derivatives of both bounds with respect to the parameters of <paramref name="paramNodes"/>.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="intervals">Input interval of every variable in the tree, keyed by variable name.</param>
    /// <param name="paramNodes">Nodes whose parameter derivatives are requested. <c>null</c> entries are skipped.</param>
    /// <param name="lowerGradient">
    /// Derivative of the lower bound w.r.t. the parameter of <c>paramNodes[i]</c>.
    /// It is 0 for <c>null</c> entries and for nodes the bound does not depend on.
    /// </param>
    /// <param name="upperGradient">Same as <paramref name="lowerGradient"/>, but for the upper bound.</param>
    /// <returns>The interval computed for the tree.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="paramNodes"/> is <c>null</c>.</exception>
    public Interval Evaluate(ISymbolicExpressionTree tree, IDictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] paramNodes, out double[] lowerGradient, out double[] upperGradient) {
      if (paramNodes == null) throw new ArgumentNullException(nameof(paramNodes));
      var result = EvaluateRoot(tree, intervals);
      var lower = result.LowerBound;
      var upper = result.UpperBound;
      // NOTE(review): unlike the overload without gradients, this one does not reject lower > upper; confirm that is intended.
      lowerGradient = new double[paramNodes.Length];
      upperGradient = new double[paramNodes.Length];
      for (int i = 0; i < paramNodes.Length; ++i) {
        var node = paramNodes[i];
        if (node == null) continue; // entry stays 0
        if (lower.Gradient.Elements.TryGetValue(node, out AlgebraicDouble dLower)) lowerGradient[i] = dLower;
        if (upper.Gradient.Elements.TryGetValue(node, out AlgebraicDouble dUpper)) upperGradient[i] = dUpper;
      }
      return new Interval(lower.Value.Value, upper.Value.Value);
    }

    // Stores the variable intervals, compiles and evaluates the tree, and returns the value of the
    // first compiled instruction. This assumes Compile emits the root first, so code[0] holds the whole tree's result.
    private AlgebraicInterval EvaluateRoot(ISymbolicExpressionTree tree, IDictionary<string, Interval> intervals) {
      // Must be set before Compile: InitializeTerminalInstruction(VariableTreeNode) reads it
      // (presumably invoked during Compile; confirm against the base Interpreter).
      this.intervals = intervals;
      var code = Compile(tree);
      Evaluate(code);
      return code[0].value;
    }

    // Non-terminal instructions are initialized with AlgebraicInterval(0, 0).
    protected override void InitializeInternalInstruction(ref Instruction instruction, ISymbolicExpressionTreeNode node) {
      instruction.value = new AlgebraicInterval(0, 0);
    }

    // Constant c: both bounds equal c, and the derivative of each bound w.r.t. c is 1.
    // The constant node itself is the gradient key.
    protected override void InitializeTerminalInstruction(ref Instruction instruction, ConstantTreeNode constant) {
      instruction.dblVal = constant.Value;
      instruction.value = new AlgebraicInterval(
        new MultivariateDual<AlgebraicDouble>(instruction.dblVal, constant, 1.0), // use node as key
        new MultivariateDual<AlgebraicDouble>(instruction.dblVal, constant, 1.0)
        );
    }

    // Variable node v(x, w) = w * x, where x ranges over intervals[VariableName] and w is the node's weight.
    // The gradient is taken w.r.t. the weight w and keyed by the variable node.
    protected override void InitializeTerminalInstruction(ref Instruction instruction, VariableTreeNode variable) {
      instruction.dblVal = variable.Weight;
      var range = intervals[variable.VariableName];
      var atLower = instruction.dblVal * range.LowerBound; // w * x_lower
      var atUpper = instruction.dblVal * range.UpperBound; // w * x_upper
      // For w >= 0, w * x is increasing in x, so the lower bound is w * x_lower (d/dw = x_lower).
      // For w < 0 the two products swap places, so the lower bound is w * x_upper (d/dw = x_upper).
      // The derivative of each bound must follow whichever product became that bound.
      var swapped = atLower > atUpper;
      instruction.value = new AlgebraicInterval(
        low: new MultivariateDual<AlgebraicDouble>(v: Math.Min(atLower, atUpper), key: variable, dv: swapped ? range.UpperBound : range.LowerBound),
        high: new MultivariateDual<AlgebraicDouble>(v: Math.Max(atLower, atUpper), key: variable, dv: swapped ? range.LowerBound : range.UpperBound)
        );
    }

    // All variable information is set at compile time in InitializeTerminalInstruction(VariableTreeNode),
    // so nothing needs to be loaded during evaluation.
    protected override void LoadVariable(Instruction a) {
      // nothing to do
    }
  }
}
Note: See TracBrowser for help on using the repository browser.