source: branches/2205_OptimizationNetworks/HeuristicLab.Networks.IntegratedOptimization.SurrogateModeling/3.3/FullEvaluationAlgorithm.cs @ 15897

Last change on this file since 15897 was 15896, checked in by jkarder, 4 years ago

#2205: added surrogate modeling network

File size: 12.5 KB
Line 
1using System;
2using System.Linq;
3using System.Threading;
4using System.Threading.Tasks;
5using HeuristicLab.Analysis;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Data;
9using HeuristicLab.Encodings.RealVectorEncoding;
10using HeuristicLab.Optimization;
11using HeuristicLab.Parameters;
12using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
13using HeuristicLab.Problems.DataAnalysis;
14using HeuristicLab.Problems.TestFunctions;
15
namespace HeuristicLab.Networks.IntegratedOptimization.SurrogateModeling {
  /// <summary>
  /// Evaluates candidate solutions with the problem's evaluator. Candidates are pushed by other
  /// components into two bounded priority queues (exploration and exploitation). The run loop
  /// alternates between both queues and keeps at most <see cref="MaxConcurrentEvaluations"/>
  /// evaluations in flight at the same time.
  /// </summary>
  /// <remarks>
  /// Each recorded evaluation is reported through <see cref="NotifyEvaluation"/>. When the run loop
  /// ends (evaluation budget exhausted, pause or stop), the callback is invoked once more with an
  /// empty <see cref="RealVector"/> and <see cref="double.NaN"/> as an end-of-run signal.
  /// </remarks>
  [Item("Full Evaluation Algorithm", "")]
  [Creatable]
  [StorableClass]
  public sealed class FullEvaluationAlgorithm : BasicAlgorithm {
    // How long the dispatch loop waits after finding both queues empty in a row, so an idle run
    // does not busy-spin on the semaphore and the queue locks.
    private static readonly TimeSpan IdleWaitInterval = TimeSpan.FromMilliseconds(10);

    // Guards the result objects that concurrently completing evaluations update.
    private readonly object locker = new object();

    // Limits the number of concurrently running evaluations; created in Initialize.
    private SemaphoreSlim tickets;
    // Number of tickets 'tickets' was created with; used to drain and restore it at the end of Run.
    private int ticketCount;
    private LimitedPriorityQueue exploitationQueue;
    private LimitedPriorityQueue explorationQueue;

    public override bool SupportsPause { get { return true; } }

    public override Type ProblemType { get { return typeof(SingleObjectiveTestFunctionProblem); } }
    public new SingleObjectiveTestFunctionProblem Problem {
      get { return (SingleObjectiveTestFunctionProblem)base.Problem; }
      set { base.Problem = value; }
    }

    /// <summary>
    /// Callback invoked with a clone of every recorded solution and its quality. At the end of a run
    /// it is invoked once with an empty <see cref="RealVector"/> and <see cref="double.NaN"/> to
    /// signal completion. May be <c>null</c>.
    /// </summary>
    public Action<RealVector, double> NotifyEvaluation { get; set; }

    #region Parameter Names
    private const string MaxEvaluationsParameterName = "MaxEvaluations";
    private const string MaxConcurrentEvaluationsParameterName = "MaxConcurrentEvaluations";
    private const string MaxBufferLengthParameterName = "MaxBufferLength";
    private const string AdditionalEvaluationDurationParameterName = "AdditionalEvaluationDuration";
    private const string RegressionSolutionParameterName = "RegressionSolution";
    private const string SortAccordingToExpectedImprovementParameterName = "SortAccordingToExpectedImprovement";
    #endregion Parameter Names

    #region Parameters
    public IValueParameter<IntValue> MaxEvaluationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxEvaluationsParameterName]; }
    }

    public IValueParameter<IntValue> MaxConcurrentEvaluationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxConcurrentEvaluationsParameterName]; }
    }

    public IValueParameter<IntValue> MaxBufferLengthParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxBufferLengthParameterName]; }
    }

    public IValueParameter<TimeSpanValue> AdditionalEvaluationDurationParameter {
      get { return (IValueParameter<TimeSpanValue>)Parameters[AdditionalEvaluationDurationParameterName]; }
    }

    public IValueParameter<IRegressionSolution> RegressionSolutionParameter {
      get { return (IValueParameter<IRegressionSolution>)Parameters[RegressionSolutionParameterName]; }
    }

    public IValueParameter<BoolValue> SortAccordingToExpectedImprovementParameter {
      get { return (IValueParameter<BoolValue>)Parameters[SortAccordingToExpectedImprovementParameterName]; }
    }
    #endregion Parameters

    #region Parameter Properties
    private IntValue MaxEvaluations {
      get { return MaxEvaluationsParameter.Value; }
      set { MaxEvaluationsParameter.Value = value; }
    }

    /// <summary>Maximum number of evaluations that may run at the same time (read in Initialize).</summary>
    public int MaxConcurrentEvaluations {
      get { return MaxConcurrentEvaluationsParameter.Value.Value; }
      set { MaxConcurrentEvaluationsParameter.Value.Value = value; }
    }

    /// <summary>Capacity passed to both candidate queues when they are created in Initialize.</summary>
    public int MaxBufferLength {
      get { return MaxBufferLengthParameter.Value.Value; }
      set { MaxBufferLengthParameter.Value.Value = value; }
    }

    /// <summary>Delay awaited before each evaluation (presumably emulating an expensive evaluation).</summary>
    public TimeSpan AdditionalEvaluationDuration {
      get { return AdditionalEvaluationDurationParameter.Value.Value; }
      set { AdditionalEvaluationDurationParameter.Value.Value = value; }
    }

    /// <summary>Regression model used to re-rank the exploitation queue whenever it changes.</summary>
    public IRegressionSolution RegressionSolution {
      get { return RegressionSolutionParameter.Value; }
      set { RegressionSolutionParameter.Value = value; }
    }

    /// <summary>Whether re-ranking uses expected improvement instead of the raw model prediction.</summary>
    public BoolValue SortAccordingToExpectedImprovement {
      get { return SortAccordingToExpectedImprovementParameter.Value; }
      set { SortAccordingToExpectedImprovementParameter.Value = value; }
    }
    #endregion Parameter Properties

    #region Result Names
    private const string EvaluatedSolutionsResultName = "EvaluatedSolutions";
    private const string EvaluatedExplorationSolutionsResultName = "EvaluatedExplorationSolutions";
    private const string EvaluatedExploitationSolutionsResultName = "EvaluatedExploitationSolutions";
    private const string FullEvaluationQualitiesResultName = "FullEvaluationQualities";
    #endregion Result Names

    #region Results
    private IntValue EvaluatedSolutionsResult {
      get { return (IntValue)Results[EvaluatedSolutionsResultName].Value; }
    }

    private IItemList<RealVector> EvaluatedExplorationSolutionsResult {
      get { return (IItemList<RealVector>)Results[EvaluatedExplorationSolutionsResultName].Value; }
    }

    private IItemList<RealVector> EvaluatedExploitationSolutionsResult {
      get { return (IItemList<RealVector>)Results[EvaluatedExploitationSolutionsResultName].Value; }
    }

    private DataTable FullEvaluationQualitiesResult {
      get { return (DataTable)Results[FullEvaluationQualitiesResultName].Value; }
    }
    #endregion Results

    #region Constructors & Cloning
    [StorableConstructor]
    private FullEvaluationAlgorithm(bool deserializing) : base(deserializing) { }
    private FullEvaluationAlgorithm(FullEvaluationAlgorithm original, Cloner cloner) : base(original, cloner) {
      RegressionSolutionParameter.ValueChanged += RegressionSolutionParameter_ValueChanged;
    }
    public FullEvaluationAlgorithm() {
      Parameters.Add(new ValueParameter<IntValue>(MaxEvaluationsParameterName, new IntValue(200)));
      Parameters.Add(new ValueParameter<IntValue>(MaxConcurrentEvaluationsParameterName, new IntValue(1)));
      Parameters.Add(new ValueParameter<IntValue>(MaxBufferLengthParameterName, new IntValue(100)));
      Parameters.Add(new ValueParameter<TimeSpanValue>(AdditionalEvaluationDurationParameterName, new TimeSpanValue(TimeSpan.FromSeconds(5.0))));
      Parameters.Add(new ValueParameter<IRegressionSolution>(RegressionSolutionParameterName, ""));
      Parameters.Add(new ValueParameter<BoolValue>(SortAccordingToExpectedImprovementParameterName, "", new BoolValue(true)));

      RegressionSolutionParameter.ValueChanged += RegressionSolutionParameter_ValueChanged;

      Problem = new SingleObjectiveTestFunctionProblem();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FullEvaluationAlgorithm(this, cloner);
    }
    #endregion Constructors & Cloning

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegressionSolutionParameter.ValueChanged += RegressionSolutionParameter_ValueChanged;
    }

    // Re-ranks the buffered exploitation candidates with the new regression model so that the
    // dequeue order reflects the current surrogate.
    private void RegressionSolutionParameter_ValueChanged(object sender, EventArgs e) {
      var queue = exploitationQueue;
      var solution = RegressionSolutionParameter.Value;
      // Before Initialize there is no queue yet (locking on null would throw), and without a model
      // there is nothing to rank by; in both cases keep the current order.
      if (queue == null || solution == null) return;

      lock (queue) {
        var points = queue.Select(x => x.Item2).ToArray();
        if (points.Length == 0) return;

        var useExpectedImprovement = SortAccordingToExpectedImprovement.Value;
        var results = ExpectedImprovementHelpers.Evaluate(points, solution, useExpectedImprovement);
        // Negate expected improvement so larger improvements get smaller priority values
        // (the exploitation queue is created with an ascending comparer).
        var targets = (useExpectedImprovement ? results.Select(x => -x) : results).ToArray();

        queue.Clear();
        for (int i = 0; i < points.Length; i++)
          queue.Enqueue(targets[i], points[i]);
      }
    }

    // Creates the concurrency semaphore, both candidate queues and the result objects.
    protected override void Initialize(CancellationToken cancellationToken) {
      ticketCount = MaxConcurrentEvaluationsParameter.Value.Value;
      tickets = new SemaphoreSlim(ticketCount, ticketCount);

      var maxBufferLength = MaxBufferLengthParameter.Value.Value;
      exploitationQueue = new LimitedPriorityQueue(maxBufferLength, new LimitedPriorityQueue.AscendingPriorityComparer());
      explorationQueue = new LimitedPriorityQueue(maxBufferLength, new LimitedPriorityQueue.DescendingPriorityComparer());

      Results.Add(new Result(EvaluatedSolutionsResultName, new IntValue(0)));
      Results.Add(new Result(EvaluatedExplorationSolutionsResultName, new ItemList<RealVector>()));
      Results.Add(new Result(EvaluatedExploitationSolutionsResultName, new ItemList<RealVector>()));
      Results.Add(new Result(FullEvaluationQualitiesResultName, new DataTable(FullEvaluationQualitiesResultName) {
        Rows = { new DataRow("Exploration"), new DataRow("Exploitation"), new DataRow("Best") }
      }));
    }

    // Dispatch loop: repeatedly takes a ticket, dequeues a candidate (alternating between the
    // exploration and the exploitation queue) and starts its evaluation. Ends when cancelled or when
    // the evaluation budget is used up, waits for in-flight evaluations and then signals completion.
    protected override void Run(CancellationToken cancellationToken) {
      bool doExploration = false;
      int consecutiveMisses = 0; // dequeue attempts in a row that found their queue empty

      while (!cancellationToken.IsCancellationRequested && EvaluatedSolutionsResult.Value < MaxEvaluations.Value) {
        try { tickets.Wait(cancellationToken); } catch (OperationCanceledException) { break; }

        doExploration = !doExploration;
        var queue = doExploration ? explorationQueue : exploitationQueue;
        var solutionType = doExploration ? EvaluationSolutionType.Exploration : EvaluationSolutionType.Exploitation;

        RealVector realVector;
        bool dequeued;
        lock (queue) dequeued = queue.TryDequeue(out realVector);

        if (dequeued) {
          consecutiveMisses = 0;
          // Fire-and-forget: the evaluation returns its ticket when it has finished.
          DoFullEvaluationAsync(realVector, solutionType, cancellationToken);
        } else {
          tickets.Release();
          // Both queues were empty: back off briefly instead of busy-spinning
          // (the wait returns early when cancellation is requested).
          if (++consecutiveMisses >= 2) cancellationToken.WaitHandle.WaitOne(IdleWaitInterval);
        }
      }

      // Drain all tickets, i.e. wait until every in-flight evaluation has finished (cancelled ones
      // return quickly), then restore them so a resumed run starts with a full semaphore. This also
      // guarantees that no evaluation is reported after the completion signal below.
      for (int i = 0; i < ticketCount; i++) tickets.Wait();
      tickets.Release(ticketCount);

      // signal that we are done ...
      var notify = NotifyEvaluation;
      if (notify != null) notify(new RealVector(), double.NaN); // TODO: improve this
    }

    #region Helpers
    private enum EvaluationSolutionType { Exploration, Exploitation }

    /// <summary>Adds a candidate to the exploitation queue. Requires the algorithm to be initialized.</summary>
    public void EnqueueForExploitation(RealVector realVector, double priority) {
      lock (exploitationQueue) exploitationQueue.Enqueue(priority, realVector);
    }

    /// <summary>Adds a candidate to the exploration queue. Requires the algorithm to be initialized.</summary>
    public void EnqueueForExploration(RealVector realVector, double priority) {
      lock (explorationQueue) explorationQueue.Enqueue(priority, realVector);
    }

    // Waits AdditionalEvaluationDuration, evaluates the candidate, records and reports the result.
    // The caller must hold a ticket; it is returned only after recording and notification have
    // completed (or failed), so draining all tickets means no evaluation is still in progress.
    // A cancelled evaluation is neither counted nor reported.
    private async Task DoFullEvaluationAsync(RealVector realVector, EvaluationSolutionType evaluationSolutionType, CancellationToken cancellationToken) {
      try {
        double quality;
        try {
          await Task.Delay(AdditionalEvaluationDurationParameter.Value.Value, cancellationToken);
          quality = Problem.Evaluator.Evaluate(realVector);
        } catch (OperationCanceledException) {
          return; // paused/stopped before the evaluation completed: nothing to record
        }

        if (!TryRecordEvaluation(realVector, evaluationSolutionType, quality)) return;

        var notify = NotifyEvaluation;
        if (notify != null) notify((RealVector)realVector.Clone(), quality);
      } finally {
        tickets.Release();
      }
    }

    // Atomically checks the evaluation budget and, if it is not yet exhausted, counts the evaluation
    // and appends it to the result collections and quality rows. Returns false if the budget was
    // already used up (the result is discarded).
    private bool TryRecordEvaluation(RealVector realVector, EvaluationSolutionType evaluationSolutionType, double quality) {
      lock (locker) {
        // Checked under the lock so concurrently finishing evaluations cannot overshoot the budget.
        if (EvaluatedSolutionsResult.Value >= MaxEvaluations.Value) return false;
        EvaluatedSolutionsResult.Value++;

        var rows = FullEvaluationQualitiesResult.Rows;
        var explorationValues = rows["Exploration"].Values;
        var exploitationValues = rows["Exploitation"].Values;

        // Rows may start with NaN placeholders (the "last value" of a row that has none yet),
        // so only actual values are considered.
        var actualExplorationValues = explorationValues.Where(x => !double.IsNaN(x)).ToArray();
        var minExplorationValue = actualExplorationValues.Length > 0 ? actualExplorationValues.Min() : double.MaxValue;
        var lastExplorationValue = actualExplorationValues.Length > 0 ? actualExplorationValues[actualExplorationValues.Length - 1] : double.NaN;

        var actualExploitationValues = exploitationValues.Where(x => !double.IsNaN(x)).ToArray();
        var minExploitationValue = actualExploitationValues.Length > 0 ? actualExploitationValues.Min() : double.MaxValue;
        var lastExploitationValue = actualExploitationValues.Length > 0 ? actualExploitationValues[actualExploitationValues.Length - 1] : double.NaN;

        // Append the new quality to its own row and carry the last value of the other row forward,
        // so both rows stay the same length.
        switch (evaluationSolutionType) {
          case EvaluationSolutionType.Exploration:
            EvaluatedExplorationSolutionsResult.Add(realVector);
            explorationValues.Add(quality);
            exploitationValues.Add(lastExploitationValue);
            break;
          case EvaluationSolutionType.Exploitation:
            EvaluatedExploitationSolutionsResult.Add(realVector);
            exploitationValues.Add(quality);
            explorationValues.Add(lastExplorationValue);
            break;
          default: throw new ArgumentOutOfRangeException();
        }

        var minEvaluatedValue = Math.Min(minExplorationValue, minExploitationValue);
        rows["Best"].Values.Add(Math.Min(minEvaluatedValue, quality));
        return true;
      }
    }
    #endregion Helpers
  }
}
Note: See TracBrowser for help on using the repository browser.