#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.RealVectorEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

// ReSharper disable once CheckNamespace
namespace HeuristicLab.Algorithms.EGO {
  /// <summary>
  /// Weighted expected-improvement infill criterion.
  /// The exploitation weight w in [0, 1] trades off the "mean improvement" term
  /// (weighted by w) against the "uncertainty" term (weighted by 1 - w).
  /// </summary>
  [StorableClass]
  [Item("ExpectedImprovementMeassure", "Extension of the Expected Improvement to a weighted version by ANDRAS SÓBESTER , STEPHEN J. LEARY and ANDY J. KEANE in \n On the Design of Optimization Strategies Based on Global Response Surface Approximation Models")]
  public class ExpectedImprovement : InfillCriterionBase {

    #region ParameterNames
    private const string ExploitationWeightParameterName = "ExploitationWeight";
    #endregion

    #region ParameterProperties
    /// <summary>
    /// Parameter holding the exploitation weight; its value is clamped to [0, 1]
    /// whenever it changes (see <see cref="OnExploitationWeightChanged"/>).
    /// </summary>
    public IFixedValueParameter<DoubleValue> ExploitationWeightParameter {
      get { return Parameters[ExploitationWeightParameterName] as IFixedValueParameter<DoubleValue>; }
    }
    #endregion

    #region Properties
    private double ExploitationWeight {
      get { return ExploitationWeightParameter.Value.Value; }
    }
    #endregion

    #region HL-Constructors, Serialization and Cloning
    [StorableConstructor]
    private ExpectedImprovement(bool deserializing) : base(deserializing) { }

    // Event subscriptions are not persisted, so they are re-attached after loading.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventhandlers();
    }

    private ExpectedImprovement(ExpectedImprovement original, Cloner cloner) : base(original, cloner) {
      RegisterEventhandlers();
    }

    public ExpectedImprovement() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(ExploitationWeightParameterName, "A value between 0 and 1 indicating the focus on exploration (0) or exploitation (1)", new DoubleValue(0.5)));
      RegisterEventhandlers();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ExpectedImprovement(this, cloner);
    }
    #endregion

    /// <summary>
    /// Computes the weighted expected improvement of sampling <paramref name="vector"/>.
    /// The reference value is the best training target of <paramref name="solution"/>:
    /// the minimum when minimizing, the maximum when maximizing.
    /// </summary>
    /// <param name="solution">Surrogate solution; its model must be an <see cref="IConfidenceRegressionModel"/>.</param>
    /// <param name="vector">Candidate point to evaluate.</param>
    /// <param name="maximization">
    /// Presumably the optimization direction of the underlying (expensive) problem — TODO confirm against callers.
    /// </param>
    /// <returns>The weighted expected improvement (0 when the predicted deviation is zero or the result is not finite).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="solution"/> is null.</exception>
    /// <exception cref="ArgumentException">The model provides no confidence (variance) estimate.</exception>
    public override double Evaluate(IRegressionSolution solution, RealVector vector, bool maximization) {
      if (solution == null) throw new ArgumentNullException("solution");
      var model = solution.Model as IConfidenceRegressionModel;
      if (model == null) throw new ArgumentException("can not calculate EI without confidence measure");
      var yhat = model.GetEstimation(vector);
      var s = Math.Sqrt(model.GetVariance(vector));
      var targets = solution.ProblemData.TargetVariableTrainingValues;
      if (maximization) {
        // EI for maximization equals EI for minimizing the negated objective:
        // (-ymax) - (-yhat) = yhat - ymax, while the standard deviation is unchanged.
        var max = targets.Max();
        return GetEstimatedImprovement(-max, -yhat, s, ExploitationWeight);
      }
      var min = targets.Min();
      return GetEstimatedImprovement(min, yhat, s, ExploitationWeight);
    }

    /// <summary>
    /// Expected improvement is always to be maximized, independent of the
    /// direction of the expensive problem.
    /// </summary>
    public override bool Maximization(bool expensiveProblemMaximization) {
      return true;
    }

    #region Eventhandling
    private void RegisterEventhandlers() {
      // Detach first so repeated registration (clone/deserialize) never subscribes twice.
      DeregisterEventhandlers();
      ExploitationWeightParameter.Value.ValueChanged += OnExploitationWeightChanged;
    }

    private void DeregisterEventhandlers() {
      ExploitationWeightParameter.Value.ValueChanged -= OnExploitationWeightChanged;
    }

    /// <summary>
    /// Clamps the exploitation weight into [0, 1].
    /// </summary>
    private void OnExploitationWeightChanged(object sender, EventArgs e) {
      // Temporarily unsubscribe so writing the clamped value does not re-enter this handler.
      ExploitationWeightParameter.Value.ValueChanged -= OnExploitationWeightChanged;
      ExploitationWeightParameter.Value.Value = Math.Max(0, Math.Min(ExploitationWeight, 1));
      ExploitationWeightParameter.Value.ValueChanged += OnExploitationWeightChanged;
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Weighted expected improvement for minimization:
    /// w * (ymin - yhat) * Phi(u) + (1 - w) * s * phi(u), with u = (ymin - yhat) / s.
    /// </summary>
    /// <param name="ymin">Best (smallest) observed value.</param>
    /// <param name="yhat">Predicted mean at the candidate.</param>
    /// <param name="s">Predicted standard deviation at the candidate.</param>
    /// <param name="w">Exploitation weight.</param>
    /// <returns>The weighted EI, or 0 if <paramref name="s"/> is zero or the result is infinite/NaN.</returns>
    private static double GetEstimatedImprovement(double ymin, double yhat, double s, double w) {
      if (Math.Abs(s) < double.Epsilon) return 0;
      var val = (ymin - yhat) / s;
      var res = w * (ymin - yhat) * StandardNormalDistribution(val) + (1 - w) * s * StandardNormalDensity(val);
      return double.IsInfinity(res) || double.IsNaN(res) ? 0 : res;
    }

    /// <summary>
    /// Probability density of the standard normal distribution; returns 0 for |x| &gt; 10.
    /// </summary>
    private static double StandardNormalDensity(double x) {
      if (Math.Abs(x) > 10) return 0;
      return Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI);
    }

    /// <summary>
    /// Cumulative distribution function of the standard normal distribution,
    /// computed via the erf approximation of Abramowitz &amp; Stegun 7.1.26
    /// (saturates to 0/1 outside [-10, 10]).
    /// </summary>
    //taken from https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
    private static double StandardNormalDistribution(double x) {
      if (x > 10) return 1;
      if (x < -10) return 0;
      const double a1 = 0.254829592;
      const double a2 = -0.284496736;
      const double a3 = 1.421413741;
      const double a4 = -1.453152027;
      const double a5 = 1.061405429;
      const double p = 0.3275911;
      var sign = x < 0 ? -1 : 1;
      // Phi(x) = 0.5 * (1 + erf(x / sqrt(2))); erf is odd, so evaluate on |x| and restore the sign.
      x = Math.Abs(x) / Math.Sqrt(2.0);
      var t = 1.0 / (1.0 + p * x);
      var y = 1.0 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.Exp(-x * x);
      return 0.5 * (1.0 + sign * y);
    }
    #endregion
  }
}