1 | #region License Information
|
---|
2 | /* HeuristicLab
|
---|
3 | * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
4 | *
|
---|
5 | * This file is part of HeuristicLab.
|
---|
6 | *
|
---|
7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
8 | * it under the terms of the GNU General Public License as published by
|
---|
9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
10 | * (at your option) any later version.
|
---|
11 | *
|
---|
12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
15 | * GNU General Public License for more details.
|
---|
16 | *
|
---|
17 | * You should have received a copy of the GNU General Public License
|
---|
18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
19 | */
|
---|
20 | #endregion
|
---|
21 |
|
---|
22 | using HEAL.Attic;
|
---|
23 | using HeuristicLab.Analysis;
|
---|
24 | using HeuristicLab.Common;
|
---|
25 | using HeuristicLab.Core;
|
---|
26 | using HeuristicLab.Data;
|
---|
27 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
28 | using HeuristicLab.Optimization;
|
---|
29 | using HeuristicLab.Parameters;
|
---|
30 | using System;
|
---|
31 | using System.Collections.Generic;
|
---|
32 | using System.Linq;
|
---|
33 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Calculates the relative frequencies of terminal-node kinds (variables, tree models and constants)
  /// over all trees in the population and appends them, one value per row, to a data table each time
  /// <see cref="Apply"/> is executed.
  /// </summary>
  [Item("TerminalNodesFrequencyAnalyzer", "Calculates the accumulated frequencies of Terminal Nodes over all trees in the population.")]
  [StorableType("638ECFFA-B441-4099-AB5F-DFCA1FE41154")]
  public sealed class TerminalNodesFrequencyAnalyzer : SymbolicDataAnalysisAnalyzer {
    private const string TerminalNodesFrequencyParameterName = "TerminalNodesFrequency";
    private const string AggregateTerminalNodesParameterName = "AggregateTerminalNodes";

    #region parameter properties
    /// <summary>
    /// Lookup parameter holding the data table into which the relative frequencies are written.
    /// </summary>
    [Storable]
    public ILookupParameter<DataTable> TerminalNodesFrequencyParameter {
      get { return (ILookupParameter<DataTable>)Parameters[TerminalNodesFrequencyParameterName]; }
    }
    /// <summary>
    /// Switch parameter (default <c>true</c>) intended to control aggregation of terminal-node references.
    /// NOTE(review): <see cref="Apply"/> does not currently read this parameter; presumably it is kept
    /// for compatibility with existing configurations and persisted instances — confirm before removing.
    /// </summary>
    [Storable]
    public IValueLookupParameter<BoolValue> AggregateTerminalNodesParameter {
      get { return (IValueLookupParameter<BoolValue>)Parameters[AggregateTerminalNodesParameterName]; }
    }
    #endregion
    #region properties
    /// <summary>
    /// Gets the actual (looked-up) aggregation switch; setting it assigns the parameter's own value.
    /// </summary>
    public BoolValue AggregateTerminalNodes {
      get { return AggregateTerminalNodesParameter.ActualValue; }
      set { AggregateTerminalNodesParameter.Value = value; }
    }
    /// <summary>
    /// Gets or sets the actual frequency table resolved through <see cref="TerminalNodesFrequencyParameter"/>.
    /// </summary>
    public DataTable TerminalNodesFrequency {
      get { return TerminalNodesFrequencyParameter.ActualValue; }
      set { TerminalNodesFrequencyParameter.ActualValue = value; }
    }

    #endregion
    [StorableConstructor]
    private TerminalNodesFrequencyAnalyzer(StorableConstructorFlag _) : base(_) { }
    private TerminalNodesFrequencyAnalyzer(TerminalNodesFrequencyAnalyzer original, Cloner cloner)
      : base(original, cloner) {
    }
    public TerminalNodesFrequencyAnalyzer()
      : base() {
      Parameters.Add(new LookupParameter<DataTable>(TerminalNodesFrequencyParameterName, "The relative Terminal Nodes reference frequencies aggregated over all trees in the population."));
      Parameters.Add(new ValueLookupParameter<BoolValue>(AggregateTerminalNodesParameterName, "Switch that determines whether all references to factor Terminal Nodes should be aggregated regardless of the value. Turn off to analyze all factor variable references with different values separately.", new BoolValue(true)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      // Instances persisted before the aggregation switch existed lack the parameter; add it with its default.
      if (!Parameters.ContainsKey(AggregateTerminalNodesParameterName)) {
        Parameters.Add(new ValueLookupParameter<BoolValue>(AggregateTerminalNodesParameterName, "Switch that determines whether all references to factor Terminal Nodes should be aggregated regardless of the value. Turn off to analyze all factor Terminal Nodes references with different values separately.", new BoolValue(true)));
      }
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new TerminalNodesFrequencyAnalyzer(this, cloner);
    }

    /// <summary>
    /// Counts the terminal nodes of all trees in the current population and appends their relative
    /// frequencies (rounded to three decimals) to the frequency table. On the first call the table is
    /// created, stored in <see cref="TerminalNodesFrequencyParameter"/> and added to the results.
    /// </summary>
    public override IOperation Apply() {
      ItemArray<ISymbolicExpressionTree> expressions = SymbolicExpressionTreeParameter.ActualValue;
      ResultCollection results = ResultCollection;
      DataTable datatable;
      if (TerminalNodesFrequencyParameter.ActualValue == null) {
        // First call: create the table, publish it via the lookup parameter and expose it as a result.
        datatable = new DataTable("Terminal Nodes frequencies", "Relative frequency of Terminal Nodes references aggregated over the whole population.");
        datatable.VisualProperties.XAxisTitle = "Generation";
        datatable.VisualProperties.YAxisTitle = "Relative Model Frequency";
        TerminalNodesFrequencyParameter.ActualValue = datatable;
        results.Add(new Result("Terminal Nodes frequencies", "Relative frequency of Terminal Nodes references aggregated over the whole population.", datatable));
      }

      datatable = TerminalNodesFrequencyParameter.ActualValue;
      // All rows must have the same number of values, so the first row's count is representative
      // (0 when the table has no rows yet).
      int numberOfValues = datatable.Rows.Select(r => r.Values.Count).DefaultIfEmpty().First();
      foreach (var pair in CalculateTerminalNodesFrequency(expressions).OrderByDescending(x => x.Value)) {
        if (!datatable.Rows.ContainsKey(pair.Key)) {
          // Terminal kind seen for the first time: create its row and back-fill zeros for earlier entries.
          DataRow row = new DataRow(pair.Key, "", Enumerable.Repeat(0.0, numberOfValues));
          row.VisualProperties.StartIndexZero = true;
          datatable.Rows.Add(row);
        }
        datatable.Rows[pair.Key].Values.Add(Math.Round(pair.Value, 3));
      }

      // Rows whose terminal kind did not occur in this population get a zero so all rows stay equally long.
      foreach (var row in datatable.Rows.Where(r => r.Values.Count != numberOfValues + 1)) {
        row.Values.Add(0.0);
      }

      return base.Apply();
    }

    /// <summary>
    /// Computes the relative frequency of each terminal-node kind over the given trees.
    /// </summary>
    /// <param name="trees">The trees to analyze; enumerated once.</param>
    /// <returns>
    /// Lazily evaluated pairs of (terminal-kind label, count / total count of counted terminals),
    /// ordered by label using <c>NaturalStringComparer</c>. Empty when no terminal node was counted.
    /// </returns>
    public static IEnumerable<KeyValuePair<string, double>> CalculateTerminalNodesFrequency(IEnumerable<ISymbolicExpressionTree> trees) {
      var terminalNodesFrequency = trees
        .SelectMany(t => GetTerminalNodesReferences(t))
        .GroupBy(pair => pair.Key, pair => pair.Value)
        .ToDictionary(g => g.Key, g => (double)g.Sum());

      double totalNumberOfSymbols = terminalNodesFrequency.Values.Sum();

      foreach (var pair in terminalNodesFrequency.OrderBy(p => p.Key, new NaturalStringComparer()))
        yield return new KeyValuePair<string, double>(pair.Key, pair.Value / totalNumberOfSymbols);
    }

    /// <summary>
    /// Counts the variable, tree-model and constant nodes of a single tree in one prefix traversal.
    /// </summary>
    /// <param name="tree">The tree to inspect.</param>
    /// <returns>A map from terminal-kind label ("Variable ", "Model ", "Constant") to its count; kinds that do not occur are absent.</returns>
    private static IEnumerable<KeyValuePair<string, int>> GetTerminalNodesReferences(ISymbolicExpressionTree tree) {
      Dictionary<string, int> references = new Dictionary<string, int>();
      foreach (var treeNode in tree.IterateNodesPrefix()) {
        // Independent checks (deliberately not else-if) so a node is counted for every kind it is an
        // instance of, exactly as separate OfType<T>() passes would count it.
        if (treeNode is VariableTreeNode) {
          IncrementReference(references, "Variable ");
        }
        if (treeNode is TreeModelTreeNode) {
          IncrementReference(references, "Model ");
        }
        if (treeNode is ConstantTreeNode) {
          IncrementReference(references, "Constant");
        }
      }
      return references;
    }

    /// <summary>
    /// Increments the count stored under <paramref name="referenceId"/>, starting from zero if absent.
    /// </summary>
    private static void IncrementReference(IDictionary<string, int> references, string referenceId) {
      int count;
      references.TryGetValue(referenceId, out count);
      references[referenceId] = count + 1;
    }
  }
}
|
---|