source: branches/HeuristicLab.EvolutionTracking/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Tracking/SchemaDiversification/SchemaCreator.cs @ 15906

Last change on this file since 15906 was 15906, checked in by bburlacu, 4 years ago

#1772: Refactoring and speed optimization of diversification operators

File size: 20.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.EvolutionTracking;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.PluginInfrastructure;
33
34namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
35  using Vertices = IEnumerable<IGenealogyGraphNode<ISymbolicExpressionTree>>;
36
37  [Item("SchemaCreator", "An operator that builds schemas based on the heredity relationship in the genealogy graph.")]
38  [StorableClass]
39  public class SchemaCreator : EvolutionTrackingOperator<ISymbolicExpressionTree> {
40    #region parameter names
41    // criteria to trigger schema-based diversification
42    private const string MinimumSchemaLengthParameterName = "MinimumSchemaLength";
43    private const string MinimumSchemaFrequencyParameterName = "MinimumSchemaFrequency";
44    private const string MinimumPhenotypicSimilarityParameterName = "MinimumPhenotypicSimilarity";
45    // parameters controlling the diversification strategy
46    private const string UseAdaptiveMutationRateParameterName = "UseAdaptiveMutationRate"; // dynamically control mutation rate
47    private const string MutationRateUpdateRuleParameterName = "ReplacementRatioUpdateRule";
48    private const string MutationRateParameterName = "MutationRate"; // fixed mutation rate when not using adaptive update
49    private const string ExclusiveMatchingParameterName = "ExclusiveMatching"; // an individual can belong to only 1 schema
50    private const string ParentsRatio = "ParentsRatio"; // use best % parents to generate schemas from
51    private const string StrictSchemaMatchingParameterName = "StrictSchemaMatching"; // use strict node comparison (for constant values and variable weights)
52    private const string SchemaDefinitionParameterName = "SchemaDefinition"; // which schema definition to use: {=}, {#} or {=,#}
53    private const string SchemaManipulatorParameterName = "SchemaManipulator"; // mutation operator to apply within schemas
54
55    // control parallel behaviour
56    private const string ExecuteInParallelParameterName = "ExecuteInParallel";
57    private const string MaxDegreeOfParalellismParameterName = "MaxDegreeOfParallelism";
58    private const string ScaleEstimatedValuesParameterName = "ScaleEstimatedValues";
59
60    private const string UpdateCounterParameterName = "UpdateCounter";
61    private const string UpdateIntervalParameterName = "UpdateInterval";
62
63    #region information parameters
64    private const string NumberOfChangedTreesParameterName = "NumberOfChangedTrees";
65    private const string NumberOfSchemasParameterName = "NumberOfSchemas";
66    private const string AverageSchemaLengthParameterName = "AverageSchemaLength";
67    #endregion
68    #endregion
69
70    #region parameters
71    public IConstrainedValueParameter<ISymbolicExpressionTreeManipulator> SchemaManipulatorParameter {
72      get { return (IConstrainedValueParameter<ISymbolicExpressionTreeManipulator>)Parameters[SchemaManipulatorParameterName]; }
73    }
74    public IConstrainedValueParameter<StringValue> SchemaDefinitionParameter {
75      get { return (IConstrainedValueParameter<StringValue>)Parameters[SchemaDefinitionParameterName]; }
76    }
77    public IConstrainedValueParameter<StringValue> MutationRateUpdateRuleParameter {
78      get { return (IConstrainedValueParameter<StringValue>)Parameters[MutationRateUpdateRuleParameterName]; }
79    }
80    public IFixedValueParameter<BoolValue> UseAdaptiveMutationRateParameter {
81      get { return (IFixedValueParameter<BoolValue>)Parameters[UseAdaptiveMutationRateParameterName]; }
82    }
83    public IFixedValueParameter<BoolValue> StrictSchemaMatchingParameter {
84      get { return (IFixedValueParameter<BoolValue>)Parameters[StrictSchemaMatchingParameterName]; }
85    }
86    public IFixedValueParameter<BoolValue> ExclusiveMatchingParameter {
87      get { return (IFixedValueParameter<BoolValue>)Parameters[ExclusiveMatchingParameterName]; }
88    }
89    public IFixedValueParameter<BoolValue> ScaleEstimatedValuesParameter {
90      get { return (IFixedValueParameter<BoolValue>)Parameters[ScaleEstimatedValuesParameterName]; }
91    }
92    public IFixedValueParameter<PercentValue> ParentsRatioParameter {
93      get { return (IFixedValueParameter<PercentValue>)Parameters[ParentsRatio]; }
94    }
95    public IFixedValueParameter<IntValue> MinimumSchemaLengthParameter {
96      get { return (IFixedValueParameter<IntValue>)Parameters[MinimumSchemaLengthParameterName]; }
97    }
98    public IFixedValueParameter<BoolValue> ExecuteInParallelParameter {
99      get { return (IFixedValueParameter<BoolValue>)Parameters[ExecuteInParallelParameterName]; }
100    }
101    public IFixedValueParameter<IntValue> MaxDegreeOfParallelismParameter {
102      get { return (IFixedValueParameter<IntValue>)Parameters[MaxDegreeOfParalellismParameterName]; }
103    }
104    public IFixedValueParameter<PercentValue> MinimumSchemaFrequencyParameter {
105      get { return (IFixedValueParameter<PercentValue>)Parameters[MinimumSchemaFrequencyParameterName]; }
106    }
107    public IFixedValueParameter<PercentValue> MinimumPhenotypicSimilarityParameter {
108      get { return (IFixedValueParameter<PercentValue>)Parameters[MinimumPhenotypicSimilarityParameterName]; }
109    }
110    public IFixedValueParameter<PercentValue> MutationRateParameter {
111      get { return (IFixedValueParameter<PercentValue>)Parameters[MutationRateParameterName]; }
112    }
113    public IValueParameter<IntValue> NumberOfSchemasParameter {
114      get { return (IValueParameter<IntValue>)Parameters[NumberOfSchemasParameterName]; }
115    }
116    public IValueParameter<DoubleValue> AverageSchemaLengthParameter {
117      get { return (IValueParameter<DoubleValue>)Parameters[AverageSchemaLengthParameterName]; }
118    }
119    public IValueParameter<IntValue> NumberOfChangedTreesParameter {
120      get { return (IValueParameter<IntValue>)Parameters[NumberOfChangedTreesParameterName]; }
121    }
122    public IFixedValueParameter<IntValue> UpdateCounterParameter {
123      get { return (IFixedValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
124    }
125    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
126      get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
127    }
128    #endregion
129
130    #region parameter properties
131    public int MinimumSchemaLength { get { return MinimumSchemaLengthParameter.Value.Value; } }
132    public int MaxDegreeOfParallelism { get { return MaxDegreeOfParallelismParameter.Value.Value; } }
133    public bool ExecuteInParallel { get { return ExecuteInParallelParameter.Value.Value; } }
134    public double PercentageOfPopulation { get { return ParentsRatioParameter.Value.Value; } }
135    public bool StrictSchemaMatching { get { return StrictSchemaMatchingParameter.Value.Value; } }
136    public IntValue UpdateCounter { get { return UpdateCounterParameter.Value; } }
137    public IntValue UpdateInterval { get { return UpdateIntervalParameter.Value; } }
138
139    #endregion
140
141    private UpdateQualityOperator updateQualityOperator;
142    private DiversificationStatisticsOperator diversificationStatisticsOperator;
143
144    public override void InitializeState() {
145      base.InitializeState();
146      NumberOfChangedTreesParameter.Value.Value = 0;
147      NumberOfChangedTreesParameter.Value.Value = 0;
148      AverageSchemaLengthParameter.Value.Value = 0;
149      UpdateCounter.Value = 0;
150    }
151
152    public override void ClearState() {
153      NumberOfChangedTreesParameter.Value.Value = 0;
154      NumberOfChangedTreesParameter.Value.Value = 0;
155      AverageSchemaLengthParameter.Value.Value = 0;
156      UpdateCounter.Value = 0;
157      base.ClearState();
158    }
159
160    public SchemaCreator() {
161      #region add parameters
162      Parameters.Add(new FixedValueParameter<IntValue>(MinimumSchemaLengthParameterName, new IntValue(10)));
163      Parameters.Add(new FixedValueParameter<PercentValue>(MinimumSchemaFrequencyParameterName, new PercentValue(0.01)));
164      Parameters.Add(new FixedValueParameter<PercentValue>(MinimumPhenotypicSimilarityParameterName, new PercentValue(0.9)));
165      Parameters.Add(new FixedValueParameter<PercentValue>(MutationRateParameterName, new PercentValue(0.9)));
166      Parameters.Add(new FixedValueParameter<PercentValue>(ParentsRatio, new PercentValue(1)));
167      Parameters.Add(new FixedValueParameter<BoolValue>(ExecuteInParallelParameterName, new BoolValue(false)));
168      Parameters.Add(new FixedValueParameter<IntValue>(MaxDegreeOfParalellismParameterName, new IntValue(-1)));
169      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleEstimatedValuesParameterName, new BoolValue(true)));
170      Parameters.Add(new FixedValueParameter<BoolValue>(ExclusiveMatchingParameterName, new BoolValue(false)));
171      Parameters.Add(new FixedValueParameter<BoolValue>(StrictSchemaMatchingParameterName, new BoolValue(false)));
172      Parameters.Add(new ValueParameter<IntValue>(NumberOfChangedTreesParameterName, new IntValue(0)));
173      Parameters.Add(new ValueParameter<IntValue>(NumberOfSchemasParameterName, new IntValue(0)));
174      Parameters.Add(new ValueParameter<DoubleValue>(AverageSchemaLengthParameterName, new DoubleValue(0)));
175      Parameters.Add(new FixedValueParameter<BoolValue>(UseAdaptiveMutationRateParameterName, new BoolValue(true)));
176      Parameters.Add(new FixedValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
177      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
178
179      // add update rules
180      var mutationRateUpdateRules = new ItemSet<StringValue>(new[] {
181        new StringValue("f(x) = x"),
182        new StringValue("f(x) = tanh(x)"),
183        new StringValue("f(x) = tanh(2x)"),
184        new StringValue("f(x) = tanh(3x)"),
185        new StringValue("f(x) = tanh(4x)"),
186        new StringValue("f(x) = 1-sqrt(1-x)")
187      });
188      var mutationRateUpdateRuleParameter = new ConstrainedValueParameter<StringValue>(MutationRateUpdateRuleParameterName, mutationRateUpdateRules);
189      mutationRateUpdateRuleParameter.Value = mutationRateUpdateRules.First();
190      Parameters.Add(mutationRateUpdateRuleParameter);
191
192      // add schema definitions
193      var schemaDefinitions = new ItemSet<StringValue>(new[] { "=", "#", "=,#" }.Select(x => new StringValue(x)));
194      var schemaDefinitionParameter = new ConstrainedValueParameter<StringValue>(SchemaDefinitionParameterName, schemaDefinitions);
195      schemaDefinitionParameter.Value = schemaDefinitions.First();
196      Parameters.Add(schemaDefinitionParameter);
197
198      // use a separate set of manipulators in order to allow the user to specify different mutations for schemas
199      // and not be limited to the manipulator parameter for the whole algorithm
200      var manipulators = ApplicationManager.Manager.GetTypes(typeof(ISymbolicExpressionTreeManipulator))
201                                           .Where(x => !typeof(IMultiOperator<ISymbolicExpressionTreeManipulator>).IsAssignableFrom(x)
202                                                       && !typeof(ISymbolicExpressionTreeArchitectureAlteringOperator).IsAssignableFrom(x))
203                                           .Select(x => (ISymbolicExpressionTreeManipulator)Activator.CreateInstance(x)).ToList();
204      manipulators.Add(new MultiSymbolicExpressionTreeManipulator());
205
206      // add individual manipulators
207      var manipulatorParameter = new ConstrainedValueParameter<ISymbolicExpressionTreeManipulator>(SchemaManipulatorParameterName, new ItemSet<ISymbolicExpressionTreeManipulator>(manipulators));
208      // add a multi manipulator as well
209      manipulatorParameter.Value = manipulators.First();
210      Parameters.Add(manipulatorParameter);
211      #endregion
212
213      NumberOfChangedTreesParameter.Hidden = true;
214      NumberOfSchemasParameter.Hidden = true;
215      AverageSchemaLengthParameter.Hidden = true;
216
217      ExecuteInParallelParameter.Hidden = true;
218      MaxDegreeOfParallelismParameter.Hidden = true;
219    }
220
221    protected SchemaCreator(SchemaCreator original, Cloner cloner) : base(original, cloner) { }
222
223    public override IDeepCloneable Clone(Cloner cloner) {
224      return new SchemaCreator(this, cloner);
225    }
226
227    [StorableConstructor]
228    protected SchemaCreator(bool deserializing) : base(deserializing) { }
229
230
231    [StorableHook(HookType.AfterDeserialization)]
232    private void AfterDeserialization() {
233      if (!Parameters.ContainsKey(StrictSchemaMatchingParameterName))
234        Parameters.Add(new FixedValueParameter<BoolValue>(StrictSchemaMatchingParameterName, new BoolValue(false)));
235    }
236
237    public override IOperation Apply() {
238      UpdateCounter.Value++;
239      if (UpdateCounter.Value != UpdateInterval.Value)
240        return base.Apply();
241      UpdateCounter.Value = 0;
242
243      // apply only when at least one generation has passed
244      var gen = Generations.Value;
245      if (gen < 1 || GenealogyGraph == null)
246        return base.Apply();
247
248
249      var updateEstimatedValues = new OperationCollection { Parallel = true };
250      if (updateQualityOperator == null)
251        updateQualityOperator = new UpdateQualityOperator();
252
253      var scope = ExecutionContext.Scope;
254
255      foreach (var s in scope.SubScopes.Where(s => !s.Variables.ContainsKey("EstimatedValues"))) {
256        updateEstimatedValues.Add(ExecutionContext.CreateChildOperation(updateQualityOperator, s));
257      }
258
259      var updateRule = MutationRateUpdateRuleParameter.Value.Value;
260      var schemaManipulator = SchemaManipulatorParameter.Value;
261
262      var evaluateSchemas = new OperationCollection();
263
264      Func<IScope, double> getQuality = s => ((DoubleValue)s.Variables["Quality"].Value).Value;
265
266      var bestN = (int)Math.Round(scope.SubScopes.Count * PercentageOfPopulation);
267      var scopes = new ScopeList(scope.SubScopes.OrderByDescending(getQuality).Take(bestN));
268      // for now, only consider crossover offspring
269      var vertices = from s in scopes
270                     let t = (ISymbolicExpressionTree)s.Variables["SymbolicExpressionTree"].Value
271                     let v = GenealogyGraph.GetByContent(t)
272                     where v.InDegree == 2
273                     select v;
274
275      IEnumerable<ISymbolicExpressionTree> schemas;
276      switch (SchemaDefinitionParameter.Value.Value) {
277        case "=":
278          schemas = GenerateAnyNodeSchemas(vertices, MinimumSchemaLength, 0, StrictSchemaMatching);
279          break;
280        case "#":
281          schemas = GenerateAnySubtreeSchemas(vertices, MinimumSchemaLength, 0, StrictSchemaMatching);
282          break;
283        case "=,#":
284          schemas = GenerateCombinedSchemas(vertices, MinimumSchemaLength, 0, StrictSchemaMatching);
285          break;
286        default:
287          return base.Apply();
288      }
289
290      if (!schemas.Any())
291        return base.Apply();
292
293      #region create schemas and add subscopes representing the individuals
294      double avgLength = 0;
295      int count = 0;
296      foreach (var schema in schemas) {
297        avgLength += schema.Length;
298        ++count;
299        evaluateSchemas.Add(ExecutionContext.CreateChildOperation(new SchemaEvaluator { Schema = schema, MutationRateUpdateRule = updateRule, SchemaManipulator = schemaManipulator }, scope));
300      }
301      avgLength /= count;
302      #endregion
303
304      // set parameters for statistics
305      AverageSchemaLengthParameter.Value = new DoubleValue(avgLength);
306      NumberOfSchemasParameter.Value = new IntValue(count);
307      NumberOfChangedTreesParameter.Value = new IntValue(0);
308
309      if (diversificationStatisticsOperator == null)
310        diversificationStatisticsOperator = new DiversificationStatisticsOperator();
311
312      var calculateStatistics = ExecutionContext.CreateChildOperation(diversificationStatisticsOperator);
313
314      // return an operation collection containing all the scope operations + base.Apply()
315      return new OperationCollection { updateEstimatedValues, evaluateSchemas, calculateStatistics, base.Apply() };
316    }
317
318    #region schema generation
319    public static IEnumerable<ISymbolicExpressionTree> GenerateAnyNodeSchemas(Vertices vertices, int minimumSchemaLength, int minOffspringCount = 1, bool strict = true) {
320      return GenerateSchemas(vertices, ReplaceAnyNode, minimumSchemaLength, minOffspringCount, strict);
321    }
322
323    public static IEnumerable<ISymbolicExpressionTree> GenerateAnySubtreeSchemas(Vertices vertices, int minimumSchemaLength, int minOffspringCount = 1, bool strict = true) {
324      return GenerateSchemas(vertices, ReplaceAnySubtree, minimumSchemaLength, minOffspringCount, strict);
325    }
326
327    public static IEnumerable<ISymbolicExpressionTree> GenerateCombinedSchemas(Vertices vertices, int minimumSchemaLength, int minOffspringCount = 1, bool strict = true) {
328      return GenerateSchemas(vertices, ReplaceCombined, minimumSchemaLength, minOffspringCount, strict);
329    }
330
331    public static IEnumerable<ISymbolicExpressionTree> GenerateSchemas(Vertices vertices, Func<ISymbolicExpressionTree, ISymbolicExpressionTreeNode, int, bool> replaceFunc, int minimumSchemaLength, int minOffspringCount, bool strict = true) {
332      var anySubtreeSymbol = new AnySubtreeSymbol();
333      var groups = vertices.GroupBy(x => x.Parents.First()).Where(g => g.Skip(minOffspringCount - 1).Any()).OrderByDescending(g => g.Count());
334      var hash = new HashSet<string>();
335      foreach (var g in groups) {
336        var parent = g.Key;
337        if (parent.Data.Length < minimumSchemaLength)
338          continue;
339        bool replaced = false;
340        var schema = (ISymbolicExpressionTree)parent.Data.Clone();
341        var nodes = schema.IterateNodesPrefix().ToList();
342        var fragments = g.Select(x => x.InArcs.Last().Data).Where(x => x != null).Cast<IFragment<ISymbolicExpressionTreeNode>>();
343        var indices = fragments.Select(x => x.Index1).Distinct().OrderByDescending(x => schema.Root.GetBranchLevel(nodes[x]));
344        foreach (var i in indices) {
345          replaced |= replaceFunc(schema, nodes[i], minimumSchemaLength);
346        }
347        if (replaced) {
348          var str = schema.Root.GetSubtree(0).GetSubtree(0).FormatToString(strict);
349          if (hash.Add(str))
350            yield return schema;
351        }
352      }
353    }
354
355    // node replacement rules
356    private static bool ReplaceAnyNode(ISymbolicExpressionTree schema, ISymbolicExpressionTreeNode node, int minSchemaLength) {
357      var anyNodeSymbol = new AnyNodeSymbol(node.Symbol.MinimumArity, node.Symbol.MaximumArity);
358      var replacement = anyNodeSymbol.CreateTreeNode();
359      SchemaUtil.ReplaceSubtree(node, replacement, true);
360
361      return true;
362    }
363
364    private static bool ReplaceAnySubtree(ISymbolicExpressionTree schema, ISymbolicExpressionTreeNode node, int minSchemaLength) {
365      if (schema.Length - node.GetLength() + 1 < minSchemaLength)
366        return false;
367
368      var anySubtreeSymbol = new AnySubtreeSymbol();
369      var replacement = anySubtreeSymbol.CreateTreeNode();
370
371      SchemaUtil.ReplaceSubtree(node, replacement, false);
372      return true;
373    }
374
375    private static bool ReplaceCombined(ISymbolicExpressionTree schema, ISymbolicExpressionTreeNode node, int minSchemaLength) {
376      ISymbol wildcard;
377      if (node.SubtreeCount > 0)
378        wildcard = new AnyNodeSymbol(node.Symbol.MinimumArity, node.Symbol.MaximumArity);
379      else
380        wildcard = new AnySubtreeSymbol();
381
382      SchemaUtil.ReplaceSubtree(node, wildcard.CreateTreeNode(), node.SubtreeCount > 0);
383      return true;
384    }
385    #endregion
386  }
387}
Note: See TracBrowser for help on using the repository browser.