Free cookie consent management tool by TermsFeed Policy Generator

source: branches/EfficientGlobalOptimization/HeuristicLab.Algorithms.EGO/InfillCriteria/ExpectedImprovementBase.cs @ 15332

Last change on this file since 15332 was 15064, checked in by bwerth, 7 years ago

#2745 implemented EGO as EngineAlgorithm + some simplifications in the IInfillCriterion interface

File size: 5.5 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.RealVectorEncoding;
27using HeuristicLab.Parameters;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
// ReSharper disable once CheckNamespace
namespace HeuristicLab.Algorithms.EGO {

  /// <summary>
  /// Base class for infill criteria that are built on an expected-improvement formula.
  /// It stores the best fitness determined in <see cref="Initialize"/>, exposes an exploitation
  /// weight that is kept inside [0, 1], and queries the confidence regression model for the
  /// estimate and the standard deviation at a point before delegating to the criterion-specific
  /// evaluation implemented by derived classes.
  /// </summary>
  [StorableClass]
  public abstract class ExpectedImprovementBase : InfillCriterionBase {

    #region ParameterNames
    private const string ExploitationWeightParameterName = "ExploitationWeight";
    #endregion

    #region ParameterProperties
    /// <summary>
    /// The parameter holding the exploitation weight.
    /// </summary>
    // A direct cast (instead of 'as') fails immediately with an InvalidCastException if the stored
    // parameter has an unexpected type, rather than returning null and failing later on dereference.
    public IFixedValueParameter<DoubleValue> ExploitationWeightParameter => (IFixedValueParameter<DoubleValue>)Parameters[ExploitationWeightParameterName];
    #endregion

    #region Properties
    /// <summary>
    /// Current value of the exploitation weight parameter.
    /// </summary>
    protected double ExploitationWeight => ExploitationWeightParameter.Value.Value;

    /// <summary>
    /// The reference fitness computed by <see cref="FindBestFitness"/> during <see cref="Initialize"/>.
    /// </summary>
    [Storable]
    protected double BestFitness;
    #endregion

    #region Constructors, Serialization and Cloning
    [StorableConstructor]
    protected ExpectedImprovementBase(bool deserializing) : base(deserializing) { }

    // Event subscriptions are re-attached here after deserialization.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventhandlers();
    }

    /// <summary>
    /// Cloning constructor: copies the best fitness and attaches the event handlers to the clone.
    /// </summary>
    protected ExpectedImprovementBase(ExpectedImprovementBase original, Cloner cloner) : base(original, cloner) {
      BestFitness = original.BestFitness;
      RegisterEventhandlers();
    }

    /// <summary>
    /// Creates the criterion with an exploitation weight of 0.5.
    /// </summary>
    protected ExpectedImprovementBase() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(ExploitationWeightParameterName, "A value between 0 and 1 indicating the focus on exploration (0) or exploitation (1). 0.5 equates to the original EI by Jones et al.", new DoubleValue(0.5)));
      RegisterEventhandlers();
    }
    #endregion

    /// <summary>
    /// Determines and stores the best fitness of the current regression solution.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// Thrown if the regression solution is not an <c>IConfidenceRegressionSolution</c>.
    /// </exception>
    public override void Initialize() {
      var solution = RegressionSolution as IConfidenceRegressionSolution;
      if (solution == null) throw new ArgumentException("can not calculate EI without a regression solution providing confidence values");
      BestFitness = FindBestFitness(solution);
    }

    /// <summary>
    /// Computes the fitness value that is stored in <see cref="BestFitness"/> for the given solution.
    /// </summary>
    /// <param name="solution">The regression solution providing confidence values.</param>
    /// <returns>The value to be used as best fitness.</returns>
    protected abstract double FindBestFitness(IConfidenceRegressionSolution solution);

    /// <summary>
    /// Evaluates the criterion at <paramref name="vector"/> by querying the model for its estimate
    /// and variance and forwarding the estimate and the standard deviation to the derived class.
    /// </summary>
    /// <param name="vector">The point to evaluate.</param>
    /// <returns>The criterion value computed by the derived class.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown if there is no regression solution or its model is not an <c>IConfidenceRegressionModel</c>.
    /// </exception>
    public override double Evaluate(RealVector vector) {
      // Same precondition (and message) as Initialize; without this check a missing or unsuitable
      // model would only surface as a NullReferenceException on the next line.
      var model = RegressionSolution?.Model as IConfidenceRegressionModel;
      if (model == null) throw new ArgumentException("can not calculate EI without a regression solution providing confidence values");
      var yhat = model.GetEstimation(vector);
      // Math.Sqrt returns NaN for negative input, so a negative variance is treated as zero.
      var s = Math.Sqrt(Math.Max(0.0, model.GetVariance(vector)));
      return Evaluate(vector, yhat, s);
    }

    /// <summary>
    /// Criterion-specific evaluation.
    /// </summary>
    /// <param name="vector">The point to evaluate.</param>
    /// <param name="estimatedFitness">The model's estimate at <paramref name="vector"/>.</param>
    /// <param name="estimatedStandardDeviation">The square root of the model's variance at <paramref name="vector"/>.</param>
    /// <returns>The criterion value.</returns>
    protected abstract double Evaluate(RealVector vector, double estimatedFitness, double estimatedStandardDeviation);

    #region Eventhandling
    // Removes any existing subscription first so that the handler is attached exactly once.
    private void RegisterEventhandlers() {
      DeregisterEventhandlers();
      ExploitationWeightParameter.Value.ValueChanged += OnExploitationWeightChanged;
    }

    private void DeregisterEventhandlers() {
      ExploitationWeightParameter.Value.ValueChanged -= OnExploitationWeightChanged;
    }

    // Clamps the weight to [0, 1]. The handler is detached while the clamped value is written back
    // so that this assignment does not trigger the handler again.
    private void OnExploitationWeightChanged(object sender, EventArgs e) {
      DeregisterEventhandlers();
      ExploitationWeightParameter.Value.Value = Math.Max(0, Math.Min(ExploitationWeight, 1));
      RegisterEventhandlers();
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Computes the weighted expected improvement
    /// <c>weight * diff * Phi(diff / s) + (1 - weight) * s * phi(diff / s)</c>,
    /// where <c>s</c> is <paramref name="modelUncertainty"/>, <c>Phi</c> and <c>phi</c> are the
    /// standard normal distribution and density, and <c>diff</c> is
    /// <c>estimatedFitness - bestFitness</c> for maximization or
    /// <c>bestFitness - estimatedFitness</c> otherwise.
    /// </summary>
    /// <param name="bestFitness">The best fitness to improve upon.</param>
    /// <param name="estimatedFitness">The estimated fitness at the candidate point.</param>
    /// <param name="modelUncertainty">The uncertainty (standard deviation) of the estimate.</param>
    /// <param name="weight">Weight of the first term; the second term is weighted with <c>1 - weight</c>.</param>
    /// <param name="maximization">Whether larger fitness values are better.</param>
    /// <returns>
    /// The weighted expected improvement, or 0 if <paramref name="modelUncertainty"/> is zero or the
    /// result is infinite or NaN.
    /// </returns>
    public static double GetEstimatedImprovement(double bestFitness, double estimatedFitness, double modelUncertainty, double weight, bool maximization) {
      // double.Epsilon is the smallest positive double, so this is effectively an exact zero check
      // that protects the division below.
      if (Math.Abs(modelUncertainty) < double.Epsilon) return 0;
      var diff = maximization ? (estimatedFitness - bestFitness) : (bestFitness - estimatedFitness);
      var val = diff / modelUncertainty;
      var res = weight * diff * StandardNormalDistribution(val) + (1 - weight) * modelUncertainty * StandardNormalDensity(val);
      return double.IsInfinity(res) || double.IsNaN(res) ? 0 : res;
    }

    // Density of the standard normal distribution; arguments with |x| > 10 are cut off to 0.
    private static double StandardNormalDensity(double x) {
      if (Math.Abs(x) > 10) return 0;
      return Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI);
    }

    //taken from https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
    // Cumulative standard normal distribution computed as 0.5 * (1 + erf(x / sqrt(2))) using the
    // polynomial erf approximation from the article above; arguments beyond +/-10 are cut off to 1 / 0.
    private static double StandardNormalDistribution(double x) {
      if (x > 10) return 1;
      if (x < -10) return 0;
      const double a1 = 0.254829592;
      const double a2 = -0.284496736;
      const double a3 = 1.421413741;
      const double a4 = -1.453152027;
      const double a5 = 1.061405429;
      const double p = 0.3275911;
      // erf is odd: evaluate the approximation for |x| and restore the sign afterwards.
      var sign = x < 0 ? -1 : 1;
      x = Math.Abs(x) / Math.Sqrt(2.0);
      var t = 1.0 / (1.0 + p * x);
      var y = 1.0 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.Exp(-x * x);
      return 0.5 * (1.0 + sign * y);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.