Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3140_NumberSymbol/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/MultiObjective/NMSEMultiObjectiveConstraintsEvaluator.cs @ 18100

Last change on this file since 18100 was 18100, checked in by chaider, 2 years ago

#3140

  • some more refactoring
  • added possibility to set value of num nodes in infix parser
  • changed displaying style of number
File size: 8.5 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HEAL.Attic;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Parameters;
33
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Multi-objective symbolic regression evaluator with shape constraints.
  /// The first objective is the NMSE of the model on the evaluated rows (rounded to
  /// <c>DecimalPlaces</c> and capped at 1); every further objective is one constraint violation
  /// computed for the enabled shape constraints of the problem data. All objectives are minimized.
  /// </summary>
  [Item("NMSE Evaluator with shape constraints (multi-objective)",
    "Calculates the NMSE and constraint violations for a symbolic regression model.")]
  [StorableType("8E9D76B7-ED9C-43E7-9898-01FBD3633880")]
  public class NMSEMultiObjectiveConstraintsEvaluator : SymbolicRegressionMultiObjectiveEvaluator, IMultiObjectiveConstraintsEvaluator {
    private const string NumConstraintsParameterName = "NumConstraints";
    private const string BoundsEstimatorParameterName = "BoundsEstimator";

    /// <summary>Largest digit count accepted by <see cref="Math.Round(double, int)"/>.</summary>
    private const int MaxRoundingDigits = 15;

    /// <summary>Parameter holding the number of constraint objectives (see <see cref="Maximization"/>).</summary>
    public IFixedValueParameter<IntValue> NumConstraintsParameter =>
      (IFixedValueParameter<IntValue>)Parameters[NumConstraintsParameterName];

    /// <summary>
    /// Parameter holding the bounds estimator that is passed to <c>IntervalUtil.GetConstraintViolations</c>
    /// (interval arithmetic by default).
    /// </summary>
    public IValueParameter<IBoundsEstimator> BoundsEstimatorParameter =>
      (IValueParameter<IBoundsEstimator>)Parameters[BoundsEstimatorParameterName];

    /// <summary>Number of constraint objectives reported in addition to the NMSE.</summary>
    public int NumConstraints {
      get => NumConstraintsParameter.Value.Value;
      set => NumConstraintsParameter.Value.Value = value;
    }

    /// <summary>Bounds estimator used when computing constraint violations.</summary>
    public IBoundsEstimator BoundsEstimator {
      get => BoundsEstimatorParameter.Value;
      set => BoundsEstimatorParameter.Value = value;
    }

    /// <summary>
    /// One entry for the NMSE plus one per constraint; all entries are <c>false</c>, i.e. every objective is minimized.
    /// NOTE(review): <see cref="Calculate"/> emits one objective per violation returned by
    /// <c>IntervalUtil.GetConstraintViolations</c>; presumably callers keep <see cref="NumConstraints"/> in sync with that count.
    /// </summary>
    public override IEnumerable<bool> Maximization => new bool[1 + NumConstraints];

    #region Constructors

    public NMSEMultiObjectiveConstraintsEvaluator() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumConstraintsParameterName, new IntValue(0)));
      Parameters.Add(new ValueParameter<IBoundsEstimator>(BoundsEstimatorParameterName, new IntervalArithBoundsEstimator()));
    }

    [StorableConstructor]
    protected NMSEMultiObjectiveConstraintsEvaluator(StorableConstructorFlag _) : base(_) { }

    protected NMSEMultiObjectiveConstraintsEvaluator(NMSEMultiObjectiveConstraintsEvaluator original, Cloner cloner) : base(original, cloner) { }

    #endregion

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NMSEMultiObjectiveConstraintsEvaluator(this, cloner);
    }

    /// <summary>
    /// Evaluates the tree of the current scope and stores the objectives in <c>QualitiesParameter</c>.
    /// Before evaluation the tree's numeric parameters may be modified in place: by constant optimization
    /// when it is enabled, otherwise (with linear scaling enabled) by fitting the offset/scale number nodes
    /// of the linear scaling grammar. Without either, offset and scale are left as evolved.
    /// </summary>
    public override IOperation InstrumentedApply() {
      // Materialize once: the rows are enumerated several times below (optimization/scaling and Calculate).
      var rows = GenerateRowsToEvaluate().ToArray();
      var tree = SymbolicExpressionTreeParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;

      if (UseConstantOptimization) {
        SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, rows,
          false, // NOTE(review): presumably the apply-linear-scaling flag of OptimizeConstants; confirm
          ConstantOptimizationIterations,
          ConstantOptimizationUpdateVariableWeights,
          estimationLimits.Lower,
          estimationLimits.Upper);
      } else if (applyLinearScaling) {
        FitLinearScalingParameters(interpreter, tree, problemData, rows);
      } // else alpha and beta are evolved

      var qualities = Calculate(interpreter, tree, estimationLimits.Lower, estimationLimits.Upper, problemData,
        rows, BoundsEstimator, DecimalPlaces);
      QualitiesParameter.ActualValue = new DoubleArray(qualities);
      return base.InstrumentedApply();
    }

    /// <summary>
    /// Fits offset (alpha) and scale (beta) for a tree built with the linear scaling grammar, i.e.
    /// <c>Start -> Addition(Multiplication(model, scale), offset)</c> with number nodes for offset and scale,
    /// and writes the fitted values into those nodes. If the fit reports an error the nodes keep their values.
    /// </summary>
    /// <exception cref="ArgumentException">The tree does not have the linear scaling structure.</exception>
    private void FitLinearScalingParameters(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      ArgumentException GrammarMismatch() =>
        new ArgumentException($"{ItemName} can only be used with LinearScalingGrammar.");

      var offset = tree.Root.GetSubtree(0) //Start
                            .GetSubtree(0); //Offset
      // Validate each level before descending further, so that a tree from another grammar yields the
      // informative ArgumentException instead of an index or null-reference error.
      if (!(offset.Symbol is Addition) || offset.SubtreeCount < 2)
        throw GrammarMismatch();
      var scaling = offset.GetSubtree(0);
      if (!(scaling.Symbol is Multiplication) || scaling.SubtreeCount < 2)
        throw GrammarMismatch();
      if (!(offset.GetSubtree(1) is NumberTreeNode offsetParameter) ||
          !(scaling.GetSubtree(1) is NumberTreeNode scalingParameter))
        throw GrammarMismatch();

      // Evaluate the unscaled model (first operand of the multiplication) as a standalone tree.
      var rootNode = new ProgramRootSymbol().CreateTreeNode();
      var startNode = new StartSymbol().CreateTreeNode();
      rootNode.AddSubtree(startNode);
      startNode.AddSubtree((ISymbolicExpressionTreeNode)scaling.GetSubtree(0).Clone());
      var unscaledTree = new SymbolicExpressionTree(rootNode);

      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(unscaledTree, problemData.Dataset, rows);
      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
      OnlineLinearScalingParameterCalculator.Calculate(estimatedValues, targetValues, out var alpha, out var beta,
        out var errorState);
      if (errorState != OnlineCalculatorError.None) return; // keep the current parameter values

      offsetParameter.Value = alpha;
      scalingParameter.Value = beta;
    }

    /// <summary>
    /// Computes the objectives of <paramref name="tree"/> using the interpreter and estimation limits
    /// resolved from <paramref name="context"/>. The parameters' execution contexts are always reset afterwards.
    /// </summary>
    public override double[] Evaluate(
      IExecutionContext context, ISymbolicExpressionTree tree,
      IRegressionProblemData problemData,
      IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      try {
        var estimationLimits = EstimationLimitsParameter.ActualValue;
        return Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree,
          estimationLimits.Lower, estimationLimits.Upper,
          problemData, rows, BoundsEstimator, DecimalPlaces);
      } finally {
        // Reset even when Calculate throws, so the evaluator is not left bound to a stale context.
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
      }
    }

    /// <summary>
    /// Computes the objective vector of <paramref name="solution"/>.
    /// </summary>
    /// <param name="interpreter">Interpreter used to evaluate the tree on <paramref name="rows"/>.</param>
    /// <param name="solution">The model to evaluate.</param>
    /// <param name="lowerEstimationLimit">Lower bound applied to the estimated values before computing the NMSE.</param>
    /// <param name="upperEstimationLimit">Upper bound applied to the estimated values before computing the NMSE.</param>
    /// <param name="problemData">Dataset, target variable and variable ranges; shape constraints are read only from
    /// <c>ShapeConstrainedRegressionProblemData</c>.</param>
    /// <param name="rows">Rows on which the NMSE is computed.</param>
    /// <param name="estimator">Bounds estimator passed to <c>IntervalUtil.GetConstraintViolations</c>.</param>
    /// <param name="decimalPlaces">Fractional digits the objectives are rounded to; a negative value disables
    /// rounding and values above 15 are treated as 15.</param>
    /// <returns>
    /// The NMSE (1.0 if it cannot be computed, capped at 1), followed by one entry per constraint violation;
    /// NaN or infinite violations are reported as <see cref="double.MaxValue"/>.
    /// </returns>
    public static double[] Calculate(
      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      ISymbolicExpressionTree solution, double lowerEstimationLimit,
      double upperEstimationLimit,
      IRegressionProblemData problemData, IEnumerable<int> rows, IBoundsEstimator estimator, int decimalPlaces) {
      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, problemData.Dataset, rows);
      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
      var constraints = Enumerable.Empty<ShapeConstraint>();
      if (problemData is ShapeConstrainedRegressionProblemData scProbData) {
        constraints = scProbData.ShapeConstraints.EnabledConstraints;
      }
      var intervalCollection = problemData.VariableRanges;

      var boundedEstimatedValues = estimatedValues.LimitToRange(lowerEstimationLimit, upperEstimationLimit);
      var nmse = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(targetValues, boundedEstimatedValues, out var errorState);
      if (errorState != OnlineCalculatorError.None) nmse = 1.0;

      nmse = RoundToDecimalPlaces(nmse, decimalPlaces);
      if (nmse > 1) nmse = 1.0;

      var objectives = new List<double> { nmse };
      var violations = IntervalUtil.GetConstraintViolations(constraints, estimator, intervalCollection, solution);
      foreach (var violation in violations) {
        objectives.Add(double.IsNaN(violation) || double.IsInfinity(violation)
          ? double.MaxValue
          : RoundToDecimalPlaces(violation, decimalPlaces));
      }

      return objectives.ToArray();
    }

    /// <summary>
    /// Rounds <paramref name="value"/> to <paramref name="decimalPlaces"/> fractional digits.
    /// A negative <paramref name="decimalPlaces"/> disables rounding (the original NMSE convention);
    /// values above <see cref="MaxRoundingDigits"/> are clamped because <see cref="Math.Round(double, int)"/>
    /// throws for them.
    /// </summary>
    private static double RoundToDecimalPlaces(double value, int decimalPlaces) {
      if (decimalPlaces < 0) return value;
      return Math.Round(value, Math.Min(decimalPlaces, MaxRoundingDigits));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.