Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Grammars/TypeCoherentVectorExpressionGrammar.cs

Last change on this file was 18060, checked in by pfleck, 3 years ago

#3040 Added a subvector symbol with ranges as subtrees.

File size: 13.1 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
29
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Grammar for symbolic data-analysis expressions that declares a scalar and a
  /// vector variant of the arithmetic, trigonometric, exponential/logarithmic and
  /// power symbols and wires them together through allowed-child rules.
  /// </summary>
  /// <remarks>
  /// Rules configured in <see cref="Initialize"/>:
  /// <list type="bullet">
  /// <item><description>The start symbol accepts only the scalar symbol group as child.</description></item>
  /// <item><description>Aggregation symbols (statistics and distances) are part of the scalar group and accept the vector symbol group as child; statistics additionally accept the sub-vector symbols.</description></item>
  /// <item><description>Vector arithmetic symbols accept both the vector and the scalar symbol group as child.</description></item>
  /// <item><description>Trigonometric, exponential/logarithmic and power groups (scalar and vector) as well as the vector manipulation group are disabled by default.</description></item>
  /// </list>
  /// </remarks>
  [StorableType("7EC7B4A7-0E27-4011-B983-B0E15A6944EC")]
  [Item("TypeCoherentVectorExpressionGrammar", "Represents a grammar for functional expressions in which special syntactic constraints are enforced so that vector and scalar expressions are not mixed.")]
  public class TypeCoherentVectorExpressionGrammar : DataAnalysisGrammar, ISymbolicDataAnalysisGrammar {
    // Names of the scalar symbol groups.
    private const string ArithmeticFunctionsName = "Arithmetic Functions";
    private const string TrigonometricFunctionsName = "Trigonometric Functions";
    private const string ExponentialFunctionsName = "Exponential and Logarithmic Functions";
    private const string PowerFunctionsName = "Power Functions";
    private const string TerminalsName = "Terminals";
    private const string VectorAggregationName = "Aggregations";
    private const string VectorStatisticsName = "Vector Statistics";
    private const string VectorDistancesName = "Vector Distances";
    private const string ScalarSymbolsName = "Scalar Symbols";

    // Names of the vector symbol groups.
    private const string VectorArithmeticFunctionsName = "Vector Arithmetic Functions";
    private const string VectorTrigonometricFunctionsName = "Vector Trigonometric Functions";
    private const string VectorExponentialFunctionsName = "Vector Exponential and Logarithmic Functions";
    private const string VectorPowerFunctionsName = "Vector Power Functions";
    private const string VectorTerminalsName = "Vector Terminals";
    private const string VectorSymbolsName = "Vector Symbols";

    // Names of the vector manipulation symbol groups.
    private const string VectorManipulationSymbolsName = "Vector Manipulation Symbols";
    private const string VectorSubVectorSymbolsName = "Vector SubVector Symbols";

    // Name under which ConfigureVariableSymbols looks up the scalar variable symbol.
    // Initialize creates that symbol without an explicit name, so this presumably
    // equals the default name of Variable -- TODO confirm against the Variable class.
    private const string VariableSymbolName = "Variable";

    // Name given to the vector variable symbol in Initialize and used to find it
    // again in ConfigureVariableSymbols; both places must agree.
    private const string VectorVariableSymbolName = "Vector Variable";

    [StorableConstructor]
    protected TypeCoherentVectorExpressionGrammar(StorableConstructorFlag _) : base(_) { }

    /// <summary>Cloning constructor; forwards to the base class.</summary>
    protected TypeCoherentVectorExpressionGrammar(TypeCoherentVectorExpressionGrammar original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the grammar with the name and description taken from the class's
    /// <c>Item</c> attribute and registers all symbols and constraints.
    /// </summary>
    public TypeCoherentVectorExpressionGrammar()
      : base(ItemAttribute.GetName(typeof(TypeCoherentVectorExpressionGrammar)), ItemAttribute.GetDescription(typeof(TypeCoherentVectorExpressionGrammar))) {
      Initialize();
    }

    /// <summary>Returns a clone of this grammar created through the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new TypeCoherentVectorExpressionGrammar(this, cloner);
    }

    /// <summary>
    /// Declares all symbols, bundles them into groups, registers the groups and
    /// configures subtree counts, allowed child symbols and the default
    /// enabled/disabled state.
    /// </summary>
    private void Initialize() {
      #region scalar symbol declaration
      var add = new Addition();
      var sub = new Subtraction();
      var mul = new Multiplication();
      var div = new Division();

      var sin = new Sine();
      var cos = new Cosine();
      var tan = new Tangent();

      var exp = new Exponential();
      var log = new Logarithm();

      var square = new Square();
      var sqrt = new SquareRoot();
      var cube = new Cube();
      var cubeRoot = new CubeRoot();
      var power = new Power();
      var root = new Root();

      var constant = new Constant { MinValue = -20, MaxValue = 20 };
      var variable = new Variable();
      var binFactorVariable = new BinaryFactorVariable();
      var factorVariable = new FactorVariable();

      // Only mean, standard deviation and sum start out enabled; the remaining
      // aggregation symbols are registered but individually disabled.
      var mean = new Mean();
      var sd = new StandardDeviation();
      var sum = new Sum();
      var length = new Length() { Enabled = false };
      var min = new Min() { Enabled = false };
      var max = new Max() { Enabled = false };
      var variance = new Variance() { Enabled = false };
      var skewness = new Skewness() { Enabled = false };
      var kurtosis = new Kurtosis() { Enabled = false };
      var euclideanDistance = new EuclideanDistance() { Enabled = false };
      var covariance = new Covariance() { Enabled = false };
      #endregion

      #region vector symbol declaration
      // The vector variants reuse the scalar symbol classes and differ only by name.
      var vectoradd = new Addition() { Name = "Vector Addition" };
      var vectorsub = new Subtraction() { Name = "Vector Subtraction" };
      var vectormul = new Multiplication() { Name = "Vector Multiplication" };
      var vectordiv = new Division() { Name = "Vector Division" };

      var vectorsin = new Sine() { Name = "Vector Sine" };
      var vectorcos = new Cosine() { Name = "Vector Cosine" };
      var vectortan = new Tangent() { Name = "Vector Tangent" };

      var vectorexp = new Exponential() { Name = "Vector Exponential" };
      var vectorlog = new Logarithm() { Name = "Vector Logarithm" };

      var vectorsquare = new Square() { Name = "Vector Square" };
      var vectorsqrt = new SquareRoot() { Name = "Vector SquareRoot" };
      var vectorcube = new Cube() { Name = "Vector Cube" };
      var vectorcubeRoot = new CubeRoot() { Name = "Vector CubeRoot" };
      var vectorpower = new Power() { Name = "Vector Power" };
      var vectorroot = new Root() { Name = "Vector Root" };

      var vectorvariable = new Variable() { Name = VectorVariableSymbolName };
      #endregion

      #region vector manipulation symbol declaration
      var subvectorLocal = new SubVector();
      var subvectorSubtree = new SubVectorSubtree();
      #endregion

      #region group symbol declaration
      var arithmeticSymbols = new GroupSymbol(ArithmeticFunctionsName, new List<ISymbol>() { add, sub, mul, div });
      var trigonometricSymbols = new GroupSymbol(TrigonometricFunctionsName, new List<ISymbol>() { sin, cos, tan });
      var exponentialAndLogarithmicSymbols = new GroupSymbol(ExponentialFunctionsName, new List<ISymbol> { exp, log });
      var powerSymbols = new GroupSymbol(PowerFunctionsName, new List<ISymbol> { square, sqrt, cube, cubeRoot, power, root });
      var terminalSymbols = new GroupSymbol(TerminalsName, new List<ISymbol> { constant, variable, binFactorVariable, factorVariable });
      var statisticsSymbols = new GroupSymbol(VectorStatisticsName, new List<ISymbol> { mean, sd, sum, length, min, max, variance, skewness, kurtosis });
      var distancesSymbols = new GroupSymbol(VectorDistancesName, new List<ISymbol> { euclideanDistance, covariance });
      var aggregationSymbols = new GroupSymbol(VectorAggregationName, new List<ISymbol> { statisticsSymbols, distancesSymbols });
      var scalarSymbols = new GroupSymbol(ScalarSymbolsName, new List<ISymbol>() { arithmeticSymbols, trigonometricSymbols, exponentialAndLogarithmicSymbols, powerSymbols, terminalSymbols, aggregationSymbols });

      var vectorarithmeticSymbols = new GroupSymbol(VectorArithmeticFunctionsName, new List<ISymbol>() { vectoradd, vectorsub, vectormul, vectordiv });
      var vectortrigonometricSymbols = new GroupSymbol(VectorTrigonometricFunctionsName, new List<ISymbol>() { vectorsin, vectorcos, vectortan });
      var vectorexponentialAndLogarithmicSymbols = new GroupSymbol(VectorExponentialFunctionsName, new List<ISymbol> { vectorexp, vectorlog });
      var vectorpowerSymbols = new GroupSymbol(VectorPowerFunctionsName, new List<ISymbol> { vectorsquare, vectorsqrt, vectorcube, vectorcubeRoot, vectorpower, vectorroot });
      var vectorterminalSymbols = new GroupSymbol(VectorTerminalsName, new List<ISymbol> { vectorvariable });
      var vectorSymbols = new GroupSymbol(VectorSymbolsName, new List<ISymbol>() { vectorarithmeticSymbols, vectortrigonometricSymbols, vectorexponentialAndLogarithmicSymbols, vectorpowerSymbols, vectorterminalSymbols });

      var vectorSubVectorSymbols = new GroupSymbol(VectorSubVectorSymbolsName, new List<ISymbol>() { subvectorLocal, subvectorSubtree });
      var vectorManipulationSymbols = new GroupSymbol(VectorManipulationSymbolsName, new List<ISymbol>() { vectorSubVectorSymbols });
      #endregion

      // Scalar, vector and manipulation symbols are registered as three separate top-level groups.
      AddSymbol(scalarSymbols);
      AddSymbol(vectorSymbols);
      AddSymbol(vectorManipulationSymbols);

      #region subtree count configuration
      SetSubtreeCount(arithmeticSymbols, 2, 2);
      SetSubtreeCount(trigonometricSymbols, 1, 1);
      SetSubtreeCount(exponentialAndLogarithmicSymbols, 1, 1);
      // The power group mixes unary and binary symbols, so counts are set per symbol.
      SetSubtreeCount(square, 1, 1);
      SetSubtreeCount(sqrt, 1, 1);
      SetSubtreeCount(cube, 1, 1);
      SetSubtreeCount(cubeRoot, 1, 1);
      SetSubtreeCount(power, 2, 2);
      SetSubtreeCount(root, 2, 2);
      SetSubtreeCount(terminalSymbols, 0, 0);
      SetSubtreeCount(statisticsSymbols, 1, 1);
      SetSubtreeCount(distancesSymbols, 2, 2);

      SetSubtreeCount(vectorarithmeticSymbols, 2, 2);
      SetSubtreeCount(vectortrigonometricSymbols, 1, 1);
      SetSubtreeCount(vectorexponentialAndLogarithmicSymbols, 1, 1);
      SetSubtreeCount(vectorsquare, 1, 1);
      SetSubtreeCount(vectorsqrt, 1, 1);
      SetSubtreeCount(vectorcube, 1, 1);
      SetSubtreeCount(vectorcubeRoot, 1, 1);
      SetSubtreeCount(vectorpower, 2, 2);
      SetSubtreeCount(vectorroot, 2, 2);
      SetSubtreeCount(vectorterminalSymbols, 0, 0);

      SetSubtreeCount(subvectorLocal, 1, 1);
      SetSubtreeCount(subvectorSubtree, 3, 3);
      #endregion

      #region allowed child symbols configuration
      AddAllowedChildSymbol(StartSymbol, scalarSymbols);

      AddAllowedChildSymbol(arithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(trigonometricSymbols, scalarSymbols);
      AddAllowedChildSymbol(exponentialAndLogarithmicSymbols, scalarSymbols);
      // Power symbols: any scalar symbol at argument index 0; power and root
      // accept only a constant at argument index 1.
      AddAllowedChildSymbol(powerSymbols, scalarSymbols, 0);
      AddAllowedChildSymbol(power, constant, 1);
      AddAllowedChildSymbol(root, constant, 1);
      AddAllowedChildSymbol(aggregationSymbols, vectorSymbols);
      AddAllowedChildSymbol(statisticsSymbols, vectorSubVectorSymbols);

      // Vector arithmetic accepts vector as well as scalar symbols as children.
      AddAllowedChildSymbol(vectorarithmeticSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorarithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(vectortrigonometricSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorexponentialAndLogarithmicSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorpowerSymbols, vectorSymbols, 0);
      AddAllowedChildSymbol(vectorpower, constant, 1);
      AddAllowedChildSymbol(vectorroot, constant, 1);

      // SubVectorSubtree: vector symbols at argument index 0, scalar symbols at
      // argument indices 1 and 2.
      AddAllowedChildSymbol(subvectorLocal, vectorSymbols);
      AddAllowedChildSymbol(subvectorSubtree, vectorSymbols, 0);
      AddAllowedChildSymbol(subvectorSubtree, scalarSymbols, 1);
      AddAllowedChildSymbol(subvectorSubtree, scalarSymbols, 2);
      #endregion

      #region default enabled/disabled
      var disabledByDefault = new[] {
        TrigonometricFunctionsName, ExponentialFunctionsName, PowerFunctionsName,
        VectorTrigonometricFunctionsName, VectorExponentialFunctionsName, VectorPowerFunctionsName,
        VectorManipulationSymbolsName
      };
      foreach (var grp in Symbols.Where(sym => disabledByDefault.Contains(sym.Name))) {
        grp.Enabled = false;
      }
      #endregion
    }

    /// <summary>
    /// Runs the base configuration and then restricts the scalar variable symbol
    /// to dataset variables of type <c>double</c> and the vector variable symbol
    /// to dataset variables of type <c>DoubleVector</c>.
    /// </summary>
    /// <param name="problemData">Problem data providing the dataset and the input variables.</param>
    public override void ConfigureVariableSymbols(IDataAnalysisProblemData problemData) {
      base.ConfigureVariableSymbols(problemData);

      ConfigureTypedVariableSymbols<double>(problemData, VariableSymbolName);
      ConfigureTypedVariableSymbols<DoubleVector>(problemData, VectorVariableSymbolName);
    }

    /// <summary>
    /// For every non-fixed <c>VariableBase</c> symbol named <paramref name="symbolName"/>,
    /// limits the available and the allowed variable names to dataset variables
    /// of type <typeparamref name="T"/> and records <typeparamref name="T"/> as the
    /// symbol's variable data type.
    /// </summary>
    /// <typeparam name="T">Dataset variable type the symbol is restricted to.</typeparam>
    /// <param name="problemData">Problem data providing the dataset and the input variables.</param>
    /// <param name="symbolName">Exact name of the variable symbols to configure.</param>
    private void ConfigureTypedVariableSymbols<T>(IDataAnalysisProblemData problemData, string symbolName) {
      var dataset = problemData.Dataset;
      foreach (var varSymbol in Symbols.OfType<VariableBase>().Where(sym => sym.Name == symbolName)) {
        if (!varSymbol.Fixed) {
          varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<T>(x));
          varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<T>(x));
          varSymbol.VariableDataType = typeof(T);
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.