#region License Information
/* HeuristicLab
* Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.RealVectorEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
// ReSharper disable once CheckNamespace
namespace HeuristicLab.Algorithms.EGO {
  /// <summary>
  /// Weighted Expected Improvement infill criterion (Sóbester, Leary and Keane).
  /// The criterion blends the "exploitation" term (predicted gain times the probability of
  /// improving) and the "exploration" term (predictive standard deviation times the normal
  /// density) using a weight in [0, 1].
  /// </summary>
  [StorableClass]
  [Item("ExpectedImprovementMeassure", "Extension of the Expected Improvement to a weighted version by ANDRAS SÓBESTER , STEPHEN J. LEARY and ANDY J. KEANE in \n On the Design of Optimization Strategies Based on Global Response Surface Approximation Models")]
  public class ExpectedImprovement : InfillCriterionBase {
    #region ParameterNames
    private const string ExploitationWeightParameterName = "ExploitationWeight";
    #endregion

    #region ParameterProperties
    /// <summary>
    /// Parameter holding the exploitation weight w: 0 focuses on exploration, 1 on exploitation.
    /// </summary>
    public IFixedValueParameter<DoubleValue> ExploitationWeightParameter {
      // Direct cast: a parameter of the wrong type fails here instead of surfacing later
      // as a NullReferenceException at the first use of the property.
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ExploitationWeightParameterName]; }
    }
    #endregion

    #region Properties
    /// <summary>Current exploitation weight, read from <see cref="ExploitationWeightParameter"/>.</summary>
    private double ExploitationWeight {
      get { return ExploitationWeightParameter.Value.Value; }
    }
    #endregion

    #region HL-Constructors, Serialization and Cloning
    [StorableConstructor]
    private ExpectedImprovement(bool deserializing) : base(deserializing) { }

    /// <summary>Re-attaches the weight-clamping handler after the instance was deserialized.</summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventhandlers();
    }

    private ExpectedImprovement(ExpectedImprovement original, Cloner cloner) : base(original, cloner) {
      RegisterEventhandlers();
    }

    /// <summary>Creates the criterion with an exploitation weight of 0.5.</summary>
    public ExpectedImprovement() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(ExploitationWeightParameterName, "A value between 0 and 1 indicating the focus on exploration (0) or exploitation (1)", new DoubleValue(0.5)));
      RegisterEventhandlers();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ExpectedImprovement(this, cloner);
    }
    #endregion

    /// <summary>
    /// Computes the weighted expected improvement of <paramref name="vector"/> over the best
    /// target value found in the training data of <paramref name="solution"/>.
    /// </summary>
    /// <param name="solution">Regression solution whose model must be an <c>IConfidenceRegressionModel</c>.</param>
    /// <param name="vector">Point at which the model estimate and variance are queried.</param>
    /// <param name="maximization">
    /// <c>true</c> if larger target values are better (improvement is measured above the training
    /// maximum); <c>false</c> if smaller values are better (improvement below the training minimum).
    /// </param>
    /// <returns>The weighted expected improvement; 0 if it cannot be computed as a finite number.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="solution"/> is null.</exception>
    /// <exception cref="ArgumentException">The model does not provide a confidence measure.</exception>
    public override double Evaluate(IRegressionSolution solution, RealVector vector, bool maximization) {
      if (solution == null) throw new ArgumentNullException("solution");
      var model = solution.Model as IConfidenceRegressionModel;
      if (model == null) throw new ArgumentException("can not calculate EI without confidence measure");

      var yhat = model.GetEstimation(vector);
      var s = Math.Sqrt(model.GetVariance(vector));

      // Signed predicted gain over the best known value. Expressing both optimization
      // directions as a gain lets the same closed-form expression serve either of them.
      var trainingTargets = solution.ProblemData.TargetVariableTrainingValues;
      var predictedGain = maximization
        ? yhat - trainingTargets.Max()
        : trainingTargets.Min() - yhat;

      return GetEstimatedImprovement(predictedGain, s, ExploitationWeight);
    }

    /// <summary>
    /// Always returns <c>true</c>, independent of <paramref name="expensiveProblemMaximization"/>
    /// (presumably: larger criterion values are always preferred - confirm against the base class contract).
    /// </summary>
    public override bool Maximization(bool expensiveProblemMaximization) {
      return true;
    }

    #region Eventhandling
    /// <summary>Attaches the clamping handler exactly once (any earlier registration is removed first).</summary>
    private void RegisterEventhandlers() {
      DeregisterEventhandlers();
      ExploitationWeightParameter.Value.ValueChanged += OnExploitationWeightChanged;
    }

    private void DeregisterEventhandlers() {
      ExploitationWeightParameter.Value.ValueChanged -= OnExploitationWeightChanged;
    }

    /// <summary>
    /// Clamps a changed exploitation weight into [0, 1]. The handler is detached while the
    /// clamped value is written so that the correction does not re-enter this method.
    /// </summary>
    private void OnExploitationWeightChanged(object sender, EventArgs e) {
      ExploitationWeightParameter.Value.ValueChanged -= OnExploitationWeightChanged;
      ExploitationWeightParameter.Value.Value = Math.Max(0, Math.Min(ExploitationWeight, 1));
      ExploitationWeightParameter.Value.ValueChanged += OnExploitationWeightChanged;
    }
    #endregion

    #region Helpers
    // Normalization constant of the standard normal density, computed once.
    private static readonly double SqrtTwoPi = Math.Sqrt(2 * Math.PI);

    /// <summary>
    /// Weighted expected improvement:
    /// w * gain * Phi(gain / s) + (1 - w) * s * phi(gain / s).
    /// </summary>
    /// <param name="predictedGain">Signed predicted improvement over the best known value.</param>
    /// <param name="s">Predictive standard deviation.</param>
    /// <param name="w">Exploitation weight.</param>
    /// <returns>The criterion value, or 0 if <paramref name="s"/> is not positive or the result is not finite.</returns>
    private static double GetEstimatedImprovement(double predictedGain, double s, double w) {
      // No predictive uncertainty (or an invalid one, e.g. the square root of a negative
      // variance) yields no expected improvement. Note that double.Epsilon is the smallest
      // positive subnormal value, so comparing against it would only ever detect s == 0.
      if (double.IsNaN(s) || s <= 0) return 0;

      var u = predictedGain / s;
      var res = w * predictedGain * StandardNormalDistribution(u) + (1 - w) * s * StandardNormalDensity(u);
      return double.IsInfinity(res) || double.IsNaN(res) ? 0 : res;
    }

    /// <summary>Density of the standard normal distribution; treated as 0 for |x| &gt; 10.</summary>
    private static double StandardNormalDensity(double x) {
      if (Math.Abs(x) > 10) return 0;
      return Math.Exp(-0.5 * x * x) / SqrtTwoPi;
    }

    /// <summary>
    /// Cumulative distribution function of the standard normal distribution, saturated to
    /// 0 / 1 for |x| &gt; 10. Uses a polynomial approximation of the error function
    /// (the constants match Abramowitz and Stegun formula 7.1.26).
    /// </summary>
    //taken from https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
    private static double StandardNormalDistribution(double x) {
      if (x > 10) return 1;
      if (x < -10) return 0;

      const double a1 = 0.254829592;
      const double a2 = -0.284496736;
      const double a3 = 1.421413741;
      const double a4 = -1.453152027;
      const double a5 = 1.061405429;
      const double p = 0.3275911;

      // The approximation is defined for non-negative arguments; remember the sign and mirror.
      var sign = x < 0 ? -1 : 1;
      x = Math.Abs(x) / Math.Sqrt(2.0);

      var t = 1.0 / (1.0 + p * x);
      var y = 1.0 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.Exp(-x * x);
      return 0.5 * (1.0 + sign * y);
    }
    #endregion
  }
}