Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2825-NSGA3/HeuristicLab.Algorithms.NSGA3/3.3/NSGA3.cs @ 17669

Last change on this file since 17669 was 17669, checked in by dleko, 4 years ago

#2825 Add scatter plot for resulting fitness values.

File size: 16.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Threading;
6using HEAL.Attic;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Encodings.RealVectorEncoding;
11using HeuristicLab.Optimization;
12using HeuristicLab.Parameters;
13using HeuristicLab.Problems.TestFunctions.MultiObjective;
14using HeuristicLab.Random;
15
namespace HeuristicLab.Algorithms.NSGA3
{
    /// <summary>
    /// The Reference Point Based Non-dominated Sorting Genetic Algorithm III was introduced in Deb
    /// et al. 2013. An Evolutionary Many-Objective Optimization Algorithm Using Reference Point
    /// Based Non-dominated Sorting Approach. IEEE Transactions on Evolutionary Computation, 18(4),
    /// pp. 577-601.
    /// </summary>
    [Item("NSGA-III", "The Reference Point Based Non-dominated Sorting Genetic Algorithm III was introduced in Deb et al. 2013. An Evolutionary Many-Objective Optimization Algorithm Using Reference Point Based Non-dominated Sorting Approach. IEEE Transactions on Evolutionary Computation, 18(4), pp. 577-601.")]
    [Creatable(Category = CreatableAttribute.Categories.PopulationBasedAlgorithms, Priority = 136)]
    [StorableType("07C745F7-A8A3-4F99-8B2C-F97E639F9AC3")]
    public class NSGA3 : BasicAlgorithm
    {
        /// <summary>
        /// This algorithm cannot be paused; it can only be run to completion or stopped.
        /// </summary>
        public override bool SupportsPause => false;

        #region ProblemProperties

        /// <summary>
        /// The algorithm accepts multi-objective problems with a real vector encoding.
        /// </summary>
        public override Type ProblemType
        {
            get { return typeof(MultiObjectiveBasicProblem<RealVectorEncoding>); }
        }

        /// <summary>
        /// The problem to optimize, typed as a real-vector multi-objective problem.
        /// </summary>
        public new MultiObjectiveBasicProblem<RealVectorEncoding> Problem
        {
            get { return (MultiObjectiveBasicProblem<RealVectorEncoding>)base.Problem; }
            set { base.Problem = value; }
        }

        /// <summary>
        /// The number of objectives of the current problem.
        /// </summary>
        /// <exception cref="NotSupportedException">
        /// Thrown when the problem is not a <see cref="MultiObjectiveTestFunctionProblem"/>.
        /// </exception>
        public int NumberOfObjectives
        {
            get
            {
                if (!(Problem is MultiObjectiveTestFunctionProblem testFunctionProblem))
                    throw new NotSupportedException("Only multi objective test function problems are supported");
                return testFunctionProblem.Objectives;
            }
        }

        #endregion ProblemProperties

        #region Storable fields

        // Random number generator used for initialization, variation and selection.
        [Storable]
        private IRandom random;

        // The current population P_t.
        [Storable]
        private List<Solution> solutions; // maybe todo: rename to nextGeneration (see Run method)

        #endregion Storable fields

        #region ParameterAndResultsNames

        // Parameter Names

        private const string SeedName = "Seed";
        private const string SetSeedRandomlyName = "Set Seed Randomly";
        private const string PopulationSizeName = "Population Size";
        private const string CrossoverProbabilityName = "Crossover Probability";
        private const string CrossoverContiguityName = "Crossover Contiguity";
        private const string MutationProbabilityName = "Mutation Probability";
        private const string MaximumGenerationsName = "Maximum Generations";
        private const string DominateOnEqualQualitiesName = "Dominate On Equal Qualities";

        // Results Names

        private const string GeneratedReferencePointsResultName = "Generated Reference Points";
        private const string CurrentGenerationResultName = "Generations";
        private const string ScatterPlotResultName = "Scatter Plot";
        private const string CurrentFrontResultName = "Pareto Front"; // Do not touch this

        #endregion ParameterAndResultsNames

        #region ParameterProperties

        private IFixedValueParameter<IntValue> SeedParameter
        {
            get { return (IFixedValueParameter<IntValue>)Parameters[SeedName]; }
        }

        private IFixedValueParameter<BoolValue> SetSeedRandomlyParameter
        {
            get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyName]; }
        }

        private IFixedValueParameter<IntValue> PopulationSizeParameter
        {
            get { return (IFixedValueParameter<IntValue>)Parameters[PopulationSizeName]; }
        }

        private IFixedValueParameter<PercentValue> CrossoverProbabilityParameter
        {
            get { return (IFixedValueParameter<PercentValue>)Parameters[CrossoverProbabilityName]; }
        }

        private IFixedValueParameter<DoubleValue> CrossoverContiguityParameter
        {
            get { return (IFixedValueParameter<DoubleValue>)Parameters[CrossoverContiguityName]; }
        }

        private IFixedValueParameter<PercentValue> MutationProbabilityParameter
        {
            get { return (IFixedValueParameter<PercentValue>)Parameters[MutationProbabilityName]; }
        }

        private IFixedValueParameter<IntValue> MaximumGenerationsParameter
        {
            get { return (IFixedValueParameter<IntValue>)Parameters[MaximumGenerationsName]; }
        }

        private IFixedValueParameter<BoolValue> DominateOnEqualQualitiesParameter
        {
            get { return (IFixedValueParameter<BoolValue>)Parameters[DominateOnEqualQualitiesName]; }
        }

        #endregion ParameterProperties

        #region Properties

        public IntValue Seed => SeedParameter.Value;

        public BoolValue SetSeedRandomly => SetSeedRandomlyParameter.Value;

        public IntValue PopulationSize => PopulationSizeParameter.Value;

        public PercentValue CrossoverProbability => CrossoverProbabilityParameter.Value;

        public DoubleValue CrossoverContiguity => CrossoverContiguityParameter.Value;

        public PercentValue MutationProbability => MutationProbabilityParameter.Value;

        public IntValue MaximumGenerations => MaximumGenerationsParameter.Value;

        public BoolValue DominateOnEqualQualities => DominateOnEqualQualitiesParameter.Value;

        public List<List<Solution>> Fronts { get; private set; }

        /// <summary>
        /// The reference points generated during initialization. They are copied before each
        /// selection step so the originals are preserved across generations.
        /// </summary>
        public List<ReferencePoint> ReferencePoints { get; private set; }

        // todo: create one property for the Generated Reference Points and one for the current
        // generations reference points

        #endregion Properties

        #region ResultsProperties

        public DoubleMatrix ResultsGeneratedReferencePoints
        {
            get { return (DoubleMatrix)Results[GeneratedReferencePointsResultName].Value; }
            set { Results[GeneratedReferencePointsResultName].Value = value; }
        }

        public DoubleMatrix ResultsSolutions
        {
            get { return (DoubleMatrix)Results[CurrentFrontResultName].Value; }
            set { Results[CurrentFrontResultName].Value = value; }
        }

        public IntValue ResultsCurrentGeneration
        {
            get { return (IntValue)Results[CurrentGenerationResultName].Value; }
            set { Results[CurrentGenerationResultName].Value = value; }
        }

        public ParetoFrontScatterPlot ResultsScatterPlot
        {
            get { return (ParetoFrontScatterPlot)Results[ScatterPlotResultName].Value; }
            set { Results[ScatterPlotResultName].Value = value; }
        }

        #endregion ResultsProperties

        #region Constructors

        /// <summary>
        /// Creates a new NSGA-III instance and registers all of its parameters with default values.
        /// </summary>
        public NSGA3() : base()
        {
            Parameters.Add(new FixedValueParameter<IntValue>(SeedName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
            Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
            Parameters.Add(new FixedValueParameter<IntValue>(PopulationSizeName, "The size of the population of Individuals.", new IntValue(200)));
            Parameters.Add(new FixedValueParameter<PercentValue>(CrossoverProbabilityName, "The probability that the crossover operator is applied on two parents.", new PercentValue(0.9)));
            Parameters.Add(new FixedValueParameter<DoubleValue>(CrossoverContiguityName, "The contiguity value for the Simulated Binary Crossover that specifies how close a child should be to its parents (larger value means closer). The value must be greater than or equal than 0. Typical values are in the range [2;5]."));
            Parameters.Add(new FixedValueParameter<PercentValue>(MutationProbabilityName, "The probability that the mutation operator is applied on a Individual.", new PercentValue(0.05)));
            Parameters.Add(new FixedValueParameter<IntValue>(MaximumGenerationsName, "The maximum number of generations which should be processed.", new IntValue(1000)));
            Parameters.Add(new FixedValueParameter<BoolValue>(DominateOnEqualQualitiesName, "Flag which determines wether Individuals with equal quality values should be treated as dominated.", new BoolValue(false)));
        }

        // Persistence uses this ctor to improve deserialization efficiency. If we would use the
        // default ctor instead this would completely initialize the object (e.g. creating
        // parameters) even though the data is later overwritten by the stored data.
        [StorableConstructor]
        public NSGA3(StorableConstructorFlag _) : base(_) { }

        // Each clonable item must have a cloning ctor (deep cloning, the cloner is used to handle
        // cyclic object references). Don't forget to call the cloning ctor of the base class
        public NSGA3(NSGA3 original, Cloner cloner) : base(original, cloner)
        {
            random = cloner.Clone(original.random);
            // Keep a null population null: the List<T> copy constructor throws on a null source,
            // which happened when cloning an algorithm that was never initialized.
            solutions = original.solutions?.Select(s => cloner.Clone(s)).ToList();
            // Reference points are not storable, but the selection step needs them; copy them so
            // the clone does not share mutable reference point state with the original.
            ReferencePoints = original.GetCopyOfReferencePoints();
        }

        public override IDeepCloneable Clone(Cloner cloner)
        {
            return new NSGA3(this, cloner);
        }

        #endregion Constructors

        #region Initialization

        /// <summary>
        /// Prepares a run: derives the population size, creates the results, seeds the random
        /// number generator, creates and evaluates the initial population, generates the
        /// reference points and analyzes the initial population.
        /// </summary>
        protected override void Initialize(CancellationToken cancellationToken)
        {
            base.Initialize(cancellationToken);

            // The population size is derived from the number of reference points, rounded up to
            // the next multiple of four. This overwrites the user-supplied Population Size.
            int numberOfGeneratedReferencePoints = ReferencePoint.GetNumberOfGeneratedReferencePoints(NumberOfObjectives);
            PopulationSize.Value = ((numberOfGeneratedReferencePoints + 3) / 4) * 4;

            InitResults();
            InitFields();          // must run before InitReferencePoints: it creates 'random'
            InitReferencePoints();
            Analyze();
        }

        /// <summary>
        /// Generates the reference points and publishes them as a result.
        /// </summary>
        private void InitReferencePoints()
        {
            ReferencePoints = ReferencePoint.GenerateReferencePoints(random, NumberOfObjectives);
            ResultsGeneratedReferencePoints = Utility.ConvertToDoubleMatrix(ReferencePoints);
        }

        /// <summary>
        /// Creates the random number generator and the initial population.
        /// </summary>
        private void InitFields()
        {
            // Honour the Seed / Set Seed Randomly parameters so that a run can be reproduced;
            // the effective seed is written back into the Seed parameter.
            if (SetSeedRandomly.Value) Seed.Value = RandomSeedGenerator.GetSeed();
            random = new MersenneTwister((uint)Seed.Value);
            InitSolutions();
        }

        /// <summary>
        /// Creates <see cref="PopulationSize"/> random solutions within the bounds of the
        /// problem's encoding and evaluates each of them.
        /// </summary>
        private void InitSolutions()
        {
            int populationSize = PopulationSize.Value;
            solutions = new List<Solution>(populationSize);
            for (int i = 0; i < populationSize; i++)
            {
                // Sample within the encoding's bounds (the same bounds used by crossover and
                // mutation) rather than a hard-coded [0, 1) box.
                RealVector chromosome = UniformRandomRealVectorCreator.Apply(random, Problem.Encoding.Length, Problem.Encoding.Bounds);
                var solution = new Solution(chromosome);
                solution.Fitness = Evaluate(chromosome);
                solutions.Add(solution);
            }
        }

        /// <summary>
        /// Registers the result entries. When the problem is a
        /// <see cref="MultiObjectiveTestFunctionProblem"/> the scatter plot is recreated with its
        /// number of objectives and problem size.
        /// </summary>
        private void InitResults()
        {
            Results.Add(new Result(GeneratedReferencePointsResultName, "The initially generated reference points", new DoubleMatrix()));
            Results.Add(new Result(CurrentFrontResultName, "The Pareto Front", new DoubleMatrix()));
            Results.Add(new Result(CurrentGenerationResultName, "The current generation", new IntValue(1)));
            Results.Add(new Result(ScatterPlotResultName, "A scatterplot displaying the evaluated solutions and (if available) the analytically optimal front", new ParetoFrontScatterPlot()));

            if (!(Problem is MultiObjectiveTestFunctionProblem problem)) return;
            // todo: add BestKnownFront parameter
            ResultsScatterPlot = new ParetoFrontScatterPlot(new double[0][], new double[0][], null, problem.Objectives, problem.ProblemSize);
        }

        #endregion Initialization

        #region Overriden Methods

        /// <summary>
        /// Runs generations until <see cref="MaximumGenerations"/> is reached or the run is
        /// cancelled.
        /// </summary>
        /// <exception cref="OperationCanceledException">The run was stopped.</exception>
        /// <exception cref="InvalidOperationException">
        /// Wraps any other failure, annotated with the generation in which it occurred.
        /// </exception>
        protected override void Run(CancellationToken cancellationToken)
        {
            while (ResultsCurrentGeneration.Value < MaximumGenerations.Value)
            {
                // Stop between generations when the user requests it.
                cancellationToken.ThrowIfCancellationRequested();

                try
                {
                    // Offspring Q_t: recombine, mutate, then evaluate the final chromosomes once.
                    List<Solution> qt = Mutate(Recombine(solutions));
                    EvaluateAll(qt);

                    // Combined population R_t of parents and offspring.
                    List<Solution> rt = Utility.Concat(solutions, qt);

                    // Selection works on copies of the reference points so the originals are
                    // preserved for the next generation.
                    solutions = NSGA3Selection.SelectSolutionsForNextGeneration(rt, GetCopyOfReferencePoints(), Problem.Maximization, random);

                    ResultsCurrentGeneration.Value++;
                    Analyze();
                }
                catch (Exception ex) when (!(ex is OperationCanceledException))
                {
                    // Cancellation must propagate unchanged so the run is reported as stopped.
                    throw new InvalidOperationException($"Failed in generation {ResultsCurrentGeneration.Value}", ex);
                }
            }
        }

        #endregion Overriden Methods

        #region Private Methods

        /// <summary>
        /// Returns a list of copies of <see cref="ReferencePoints"/>, or <c>null</c> when no
        /// reference points have been generated yet.
        /// </summary>
        private List<ReferencePoint> GetCopyOfReferencePoints()
        {
            return ReferencePoints?.Select(referencePoint => new ReferencePoint(referencePoint)).ToList();
        }

        /// <summary>
        /// Publishes the current population to the results (scatter plot and solution matrix)
        /// and forwards it to the problem's own Analyze method.
        /// </summary>
        private void Analyze()
        {
            ResultsScatterPlot = new ParetoFrontScatterPlot(solutions.Select(x => x.Fitness).ToArray(), solutions.Select(x => x.Chromosome.ToArray()).ToArray(), ResultsScatterPlot.ParetoFront, ResultsScatterPlot.Objectives, ResultsScatterPlot.ProblemSize);
            ResultsSolutions = solutions.Select(s => s.Chromosome.ToArray()).ToMatrix();
            Problem.Analyze(
                solutions.Select(s => (Individual)new SingleEncodingIndividual(Problem.Encoding, new Scope { Variables = { new Variable(Problem.Encoding.Name, s.Chromosome) } })).ToArray(),
                solutions.Select(s => s.Fitness).ToArray(),
                Results,
                random
                );
        }

        /// <summary>
        /// Returns the fitness of the given <paramref name="chromosome" /> by applying the Evaluate
        /// method of the Problem.
        /// </summary>
        /// <param name="chromosome">The real vector to evaluate.</param>
        /// <returns>The objective values returned by the problem.</returns>
        private double[] Evaluate(RealVector chromosome)
        {
            return Problem.Evaluate(new SingleEncodingIndividual(Problem.Encoding, new Scope { Variables = { new Variable(Problem.Encoding.Name, chromosome) } }), random);
        }

        /// <summary>
        /// Evaluates every solution in <paramref name="toEvaluate"/> and stores the result in its
        /// <c>Fitness</c>.
        /// </summary>
        private void EvaluateAll(IEnumerable<Solution> toEvaluate)
        {
            foreach (var solution in toEvaluate)
                solution.Fitness = Evaluate(solution.Chromosome);
        }

        /// <summary>
        /// Creates offspring by applying simulated binary crossover to randomly chosen parent
        /// pairs. Each pair yields two children, so the result holds
        /// 2 * ceil(parents.Count / 2) solutions. The children are NOT evaluated here.
        /// </summary>
        private List<Solution> Recombine(List<Solution> parents)
        {
            var childSolutions = new List<Solution>(parents.Count + parents.Count % 2);

            for (int i = 0; i < parents.Count; i += 2)
            {
                int parentIndex1 = random.Next(parents.Count);
                int parentIndex2 = random.Next(parents.Count);
                // ensure that the parents are not the same object (impossible for a single parent)
                if (parentIndex1 == parentIndex2) parentIndex2 = (parentIndex2 + 1) % parents.Count;
                var parent1 = parents[parentIndex1];
                var parent2 = parents[parentIndex2];

                // Do crossover with crossoverProbabilty == 1 in order to guarantee that a crossover happens
                var children = SimulatedBinaryCrossover.Apply(random, Problem.Encoding.Bounds, parent1.Chromosome, parent2.Chromosome, 1);
                Debug.Assert(children != null);

                childSolutions.Add(new Solution(children.Item1));
                childSolutions.Add(new Solution(children.Item2));
            }

            return childSolutions;
        }

        /// <summary>
        /// Mutates each solution in place with probability <see cref="MutationProbability"/>.
        /// Fitness values are not updated; evaluate the solutions afterwards.
        /// </summary>
        /// <returns>The same list that was passed in.</returns>
        private List<Solution> Mutate(List<Solution> offspring)
        {
            double mutationProbability = MutationProbability.Value;
            foreach (var solution in offspring)
            {
                if (random.NextDouble() < mutationProbability)
                    UniformOnePositionManipulator.Apply(random, solution.Chromosome, Problem.Encoding.Bounds);
            }
            return offspring;
        }

        #endregion Private Methods
    }
}
Note: See TracBrowser for help on using the repository browser.