Changeset 14818 for branches/EfficientGlobalOptimization/HeuristicLab.Algorithms.EGO/InfillCriteria/RobustImprovement.cs
- Timestamp:
- 04/04/17 12:37:52 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/EfficientGlobalOptimization/HeuristicLab.Algorithms.EGO/InfillCriteria/RobustImprovement.cs
r14768 r14818 21 21 22 22 using System; 23 using System.Collections.Generic; 24 using System.Linq; 23 25 using HeuristicLab.Common; 24 26 using HeuristicLab.Core; … … 33 35 34 36 [StorableClass] 35 [Item(" ConfidenceBoundMeassure", "Adding or Subtracting the variance * factor to the model estimation")]36 public class ConfidenceBound: InfillCriterionBase {37 [Item("RobustImprovementMeassure", "Adding or Subtracting the variance * factor to the model estimation")] 38 public class RobustImprovement : InfillCriterionBase { 37 39 38 40 #region ParameterNames 39 private const string ConfidenceWeightParameterName = "ConfidenceWeight";41 private const string KParameterName = "NearestNeighbours"; 40 42 #endregion 41 43 42 44 #region ParameterProperties 43 public IFixedValueParameter<DoubleValue> ConfidenceWeightParameter 44 { 45 get { return Parameters[ConfidenceWeightParameterName] as IFixedValueParameter<DoubleValue>; } 46 } 45 public IFixedValueParameter<IntValue> KParameter => Parameters[KParameterName] as IFixedValueParameter<IntValue>; 46 47 47 #endregion 48 48 49 49 #region Properties 50 private double ConfidenceWeight 51 { 52 get { return ConfidenceWeightParameter.Value.Value; } 53 } 50 private int K => KParameter.Value.Value; 51 52 [Storable] 53 private double MaxSolutionDist; 54 55 [Storable] 56 //TODO use VP-Tree instead of array 57 private RealVector[] Data; 54 58 #endregion 55 59 56 60 #region HL-Constructors, Serialization and Cloning 57 61 [StorableConstructor] 58 private ConfidenceBound(bool deserializing) : base(deserializing) { } 59 private ConfidenceBound(ConfidenceBound original, Cloner cloner) : base(original, cloner) { } 60 public ConfidenceBound() { 61 Parameters.Add(new FixedValueParameter<DoubleValue>(ConfidenceWeightParameterName, "A value between 0 and 1 indicating the focus on exploration (0) or exploitation (1)", new DoubleValue(0.5))); 62 private RobustImprovement(bool deserializing) : base(deserializing) { } 63 64 private RobustImprovement(RobustImprovement original, Cloner cloner) : base(original, cloner) { 65 MaxSolutionDist = original.MaxSolutionDist; 66 Data = original.Data != null ? original.Data.Select(cloner.Clone).ToArray() : null; 67 } 68 public RobustImprovement() { 69 Parameters.Add(new FixedValueParameter<IntValue>(KParameterName, "A value larger than 0 indicating how many nearestNeighbours shall be used to determine the RI meassure", new IntValue(3))); 62 70 } 63 71 public override IDeepCloneable Clone(Cloner cloner) { 64 return new ConfidenceBound(this, cloner);72 return new RobustImprovement(this, cloner); 65 73 } 66 74 #endregion 67 75 68 public override double Evaluate(IRegressionSolution solution, RealVector vector, bool maximization) { 69 var model = solution.Model as IConfidenceRegressionModel; 76 77 public override double Evaluate(RealVector vector) { 78 List<RealVector> nearestNeighbours; 79 List<double> distances; 80 Search(vector, K, out nearestNeighbours, out distances); 81 var distVectors = nearestNeighbours.Select(x => Minus(x, vector)).ToList(); 82 var sum = 0.0; 83 var wsum = 1.0; //weights for angular distance 84 var used = new HashSet<RealVector>(); 85 foreach (var distVector in distVectors) { 86 var d = Math.Pow(distances[used.Count], 0.5); 87 if (used.Count == 0) { 88 sum += d; 89 } else { 90 var w = used.Select(x => Angular(distVector, x)).Min(); 91 sum += w * d; 92 wsum += w; 93 } 94 used.Add(distVector); 95 } 96 sum /= wsum * MaxSolutionDist; //normalize 97 return sum; 98 } 99 public override bool Maximization() { 100 return ExpensiveMaximization; 101 } 102 protected override void Initialize() { 103 var model = RegressionSolution.Model as IConfidenceRegressionModel; 70 104 if (model == null) throw new ArgumentException("can not calculate EI without confidence measure"); 71 var yhat = model.GetEstimation(vector); 72 var s = Math.Sqrt(model.GetVariance(vector)) * ConfidenceWeight; 73 return maximization ? yhat + s : yhat - s; 105 Data = new RealVector[RegressionSolution.ProblemData.Dataset.Rows]; 106 for (var i = 0; i < Data.Length; i++) { 107 Data[i] = new RealVector(Encoding.Length); 108 for (var j = 0; j < Encoding.Length; j++) 109 Data[i][j] = RegressionSolution.ProblemData.Dataset.GetDoubleValue(i, j); 110 } 111 112 var maxSolution = new double[Encoding.Length]; 113 var minSolution = new double[Encoding.Length]; 114 for (var i = 0; i < Encoding.Length; i++) { 115 var j = i % Encoding.Bounds.Rows; 116 maxSolution[i] = Encoding.Bounds[j, 1]; 117 minSolution[i] = Encoding.Bounds[j, 0]; 118 } 119 MaxSolutionDist = Euclidian(maxSolution, minSolution) / Data.Length; 74 120 } 75 121 122 #region Helpers 123 private static double Euclidian(IEnumerable<double> a, IEnumerable<double> b) { 124 return Math.Sqrt(a.Zip(b, (d, d1) => d - d1).Sum(d => d * d)); 125 } 126 private static double Angular(RealVector a, RealVector b) { 127 var innerProduct = a.Zip(b, (x, y) => x * y).Sum(); 128 var res = Math.Acos(innerProduct / (Norm(a) * Norm(b))) / Math.PI; 129 return double.IsNaN(res) ? 0 : res; 130 } 131 private static double Norm(IEnumerable<double> a) { 132 return Math.Sqrt(a.Sum(d => d * d)); 133 } 134 private static RealVector Minus(RealVector a, RealVector b) { 135 return new RealVector(a.Zip(b, (d, d1) => d - d1).ToArray()); 136 } 137 138 private void Search(RealVector vector, int k, out List<RealVector> nearestNeighbours, out List<double> distances) { 139 var neighbours = new SortedList<double, RealVector>(new DuplicateKeyComparer<double>()); 140 foreach (var n in Data) neighbours.Add(Euclidian(n, vector), n); 141 nearestNeighbours = new List<RealVector>(); 142 143 distances = new List<double>(); 144 foreach (var entry in neighbours) { 145 nearestNeighbours.Add(entry.Value); 146 distances.Add(entry.Key); 147 if (distances.Count == k) break; 148 } 149 } 150 #endregion 151 152 public class DuplicateKeyComparer<TKey> : IComparer<TKey> where TKey : IComparable { 153 public int Compare(TKey x, TKey y) { 154 var result = x.CompareTo(y); 155 return result == 0 ? 1 : result; 156 } 157 } 76 158 } 77 159 }
Note: See TracChangeset
for help on using the changeset viewer.