Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3076_IA_evaluators_analyzers_reintegration/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/MultiObjective/NMSEConstraintsEvaluator.cs @ 17899

Last change on this file since 17899 was 17899, checked in by gkronber, 3 years ago

#3076 refactoring for branch reintegration

File size: 8.4 KB
Line 
#region License Information

/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */

#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.MultiObjective {
  /// <summary>
  /// Multi-objective evaluator for symbolic regression.
  /// The first objective is the NMSE of the model, clamped to at most 1.
  /// Each remaining objective is one violation value reported for the problem's enabled shape constraints.
  /// All objectives are minimized.
  /// </summary>
  [Item("NMSE Evaluator (multi-objective, with shape-constraints)",
    "Calculates the NMSE and the constraints violations of a symbolic regression solution as objectives.")]
  [StorableType("8E9D76B7-ED9C-43E7-9898-01FBD3633880")]
  public class NMSEConstraintsEvaluator : SymbolicRegressionMultiObjectiveEvaluator {
    public const string NumObjectivesParameterName = "NumObjectives";
    private const string BoundsEstimatorParameterName = "BoundsEstimator";

    public IFixedValueParameter<IntValue> NumObjectivesParameter =>
      (IFixedValueParameter<IntValue>)Parameters[NumObjectivesParameterName];

    public IValueParameter<IBoundsEstimator> BoundsEstimatorParameter =>
      (IValueParameter<IBoundsEstimator>)Parameters[BoundsEstimatorParameterName];

    // Cached maximization flags. They are all false because every objective is minimized.
    // The field stays [Storable] so previously persisted instances keep loading.
    // Its length is re-synchronized with NumObjectives whenever Maximization is read, because the
    // NumObjectives parameter value can be edited directly, bypassing the NumObjectives setter.
    [Storable]
    private bool[] maximization;

    /// <summary>
    /// Number of objectives: the NMSE followed by one objective per constraint violation returned by
    /// <see cref="Calculate"/>. Presumably this must equal 1 + the number of enabled shape constraints
    /// (TODO confirm against IntervalUtil.GetConstraintViolations).
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">The value is smaller than 1.</exception>
    public int NumObjectives {
      get => NumObjectivesParameter.Value.Value;
      set {
        // Calculate always emits at least the NMSE, so fewer than one objective can never match.
        // Validate before touching the parameter so a bad value leaves the state unchanged.
        if (value < 1)
          throw new ArgumentOutOfRangeException(nameof(value), value, "At least one objective (the NMSE) is required.");
        NumObjectivesParameter.Value.Value = value;
        SynchronizeMaximization();
      }
    }

    public IBoundsEstimator BoundsEstimator {
      get => BoundsEstimatorParameter.Value;
      set => BoundsEstimatorParameter.Value = value;
    }

    #region Constructors

    public NMSEConstraintsEvaluator() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumObjectivesParameterName, new IntValue(2)));
      Parameters.Add(new ValueParameter<IBoundsEstimator>(BoundsEstimatorParameterName, new IntervalArithBoundsEstimator()));
      SynchronizeMaximization();
    }

    [StorableConstructor]
    protected NMSEConstraintsEvaluator(StorableConstructorFlag _) : base(_) { }

    protected NMSEConstraintsEvaluator(NMSEConstraintsEvaluator original, Cloner cloner) : base(original, cloner) {
      // Defensive: tolerate a null field. Maximization re-creates it on demand.
      maximization = (bool[])original.maximization?.Clone();
    }

    #endregion

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NMSEConstraintsEvaluator(this, cloner);
    }

    /// <summary>
    /// Prepares the tree for evaluation, then writes the objective vector of the current tree
    /// to the Qualities parameter.
    /// Preparation is either constant optimization or, if enabled, fitting the linear-scaling constants.
    /// </summary>
    public override IOperation InstrumentedApply() {
      var rows = GenerateRowsToEvaluate();
      var tree = SymbolicExpressionTreeParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;

      if (UseConstantOptimization) {
        // applyLinearScaling is passed as false. Presumably the tree's own offset/scaling constants
        // are optimized together with all other constants (TODO confirm).
        SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, rows,
          false,
          ConstantOptimizationIterations,
          ConstantOptimizationUpdateVariableWeights,
          estimationLimits.Lower,
          estimationLimits.Upper);
      } else if (applyLinearScaling) {
        UpdateLinearScalingParameters(interpreter, tree, problemData, rows);
      }

      var qualities = Calculate(interpreter, tree, estimationLimits.Lower, estimationLimits.Upper, problemData,
        rows, BoundsEstimator);
      QualitiesParameter.ActualValue = new DoubleArray(qualities);
      return base.InstrumentedApply();
    }

    /// <summary>
    /// Fits offset (alpha) and scale (beta) for a tree built with the IntervalArithmeticGrammar.
    /// The values come from <c>OnlineLinearScalingParameterCalculator</c> and are written into the tree in place.
    /// Expected layout, by child index:
    ///   Root[0] = Start; Start[0] = offset node;
    ///   offset[0] = scaling node; offset[1] = alpha constant;
    ///   scaling[0] = model; scaling[1] = beta constant.
    /// If the calculator reports an error, the existing constants are left unchanged.
    /// </summary>
    /// <exception cref="ArgumentException">The tree was not built with the IntervalArithmeticGrammar.</exception>
    /// <exception cref="InvalidCastException">The alpha/beta positions do not hold constant nodes.</exception>
    private void UpdateLinearScalingParameters(
      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      ISymbolicExpressionTree tree,
      IRegressionProblemData problemData,
      IEnumerable<int> rows) {
      //Check for interval arithmetic grammar
      if (!(tree.Root.Grammar is IntervalArithmeticGrammar))
        throw new ArgumentException($"{ItemName} can only be used with IntervalArithmeticGrammar.");

      var offset = tree.Root.GetSubtree(0) //Start
                            .GetSubtree(0); //Offset
      var scaling = offset.GetSubtree(0);

      // Evaluate only the unscaled model. A copy of it is wrapped in a fresh ProgramRoot/Start tree.
      var rootNode = new ProgramRootSymbol().CreateTreeNode();
      var startNode = new StartSymbol().CreateTreeNode();
      rootNode.AddSubtree(startNode);
      startNode.AddSubtree((ISymbolicExpressionTreeNode)scaling.GetSubtree(0).Clone());
      var modelTree = new SymbolicExpressionTree(rootNode);

      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(modelTree, problemData.Dataset, rows);
      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
      OnlineLinearScalingParameterCalculator.Calculate(estimatedValues, targetValues, out var alpha, out var beta,
        out var errorState);
      if (errorState != OnlineCalculatorError.None) return;

      // Direct casts: a malformed tree fails with a descriptive InvalidCastException
      // instead of a NullReferenceException from an unchecked 'as'.
      ((ConstantTreeNode)offset.GetSubtree(1)).Value = alpha;
      ((ConstantTreeNode)scaling.GetSubtree(1)).Value = beta;
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> using parameter values resolved through <paramref name="context"/>.
    /// Unlike InstrumentedApply, it applies neither constant optimization nor linear scaling.
    /// </summary>
    public override double[] Evaluate(
      IExecutionContext context, ISymbolicExpressionTree tree,
      IRegressionProblemData problemData,
      IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      try {
        return Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree,
          EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper,
          problemData, rows, BoundsEstimator);
      } finally {
        // Always detach the temporary context, even when Calculate throws,
        // so the parameters do not keep a reference to a foreign context.
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
      }
    }

    /// <summary>
    /// Computes the objective vector for <paramref name="solution"/>.
    /// </summary>
    /// <returns>
    /// Element 0 is the NMSE of the predictions, after they are limited to
    /// [<paramref name="lowerEstimationLimit"/>, <paramref name="upperEstimationLimit"/>].
    /// It is clamped to at most 1 and set to 1 when the NMSE calculator reports an error.
    /// The following elements are the violations returned by IntervalUtil.GetConstraintViolations
    /// for the enabled shape constraints. NaN or infinite violations are replaced by double.MaxValue.
    /// </returns>
    public static double[] Calculate(
      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      ISymbolicExpressionTree solution, double lowerEstimationLimit,
      double upperEstimationLimit,
      IRegressionProblemData problemData, IEnumerable<int> rows, IBoundsEstimator estimator) {
      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, problemData.Dataset, rows);
      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
      var constraints = problemData.ShapeConstraints.EnabledConstraints;
      var intervalCollection = problemData.VariableRanges;

      var boundedEstimatedValues = estimatedValues.LimitToRange(lowerEstimationLimit, upperEstimationLimit);
      var nmse = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(targetValues, boundedEstimatedValues,
        out var errorState);
      // Failed calculations and models worse than predicting the mean share the worst NMSE value, 1.
      if (errorState != OnlineCalculatorError.None || nmse > 1)
        nmse = 1.0;

      var objectives = new List<double> { nmse };
      var violations = IntervalUtil.GetConstraintViolations(constraints, estimator, intervalCollection, solution);
      foreach (var violation in violations) {
        // NaN compares false with everything. Presumably non-finite values are mapped to the
        // largest finite double to keep downstream dominance comparisons well-defined.
        objectives.Add(double.IsNaN(violation) || double.IsInfinity(violation) ? double.MaxValue : violation);
      }

      return objectives.ToArray();
    }

    /// <summary>
    /// One flag per objective, all false because every objective is minimized.
    /// The length follows the current NumObjectives value.
    /// </summary>
    public override IEnumerable<bool> Maximization {
      get {
        SynchronizeMaximization();
        return maximization;
      }
    }

    // Re-creates the maximization flags when their count no longer matches NumObjectives.
    // Non-positive parameter values fall back to a single objective (the NMSE) instead of
    // throwing from a property getter.
    private void SynchronizeMaximization() {
      var numObjectives = Math.Max(1, NumObjectives);
      if (maximization == null || maximization.Length != numObjectives)
        maximization = new bool[numObjectives];
    }
  }
}
Note: See TracBrowser for help on using the repository browser.