Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Grammars/TypeCoherentVectorExpressionGrammar.cs @ 17604

Last change on this file since 17604 was 17604, checked in by pfleck, 4 years ago

#3040 Stores the datatype of a tree node (e.g. variable nodes) in the tree itself for the interpreter to derive the datatypes for subtrees. This way, the interpreter (and simplifier) do not need an actual dataset to figure out datatypes for subtrees.

File size: 12.0 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
29
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Grammar that declares its function symbols twice -- once as scalar symbols and once as
  /// vector symbols -- and configures the allowed child symbols so that scalar functions only
  /// take scalar arguments, while vector aggregations (statistics and distances) are the only
  /// scalar symbols that take vector arguments.
  /// </summary>
  [StorableType("7EC7B4A7-0E27-4011-B983-B0E15A6944EC")]
  [Item("TypeCoherentVectorExpressionGrammar", "Represents a grammar for functional expressions in which special syntactic constraints are enforced so that vector and scalar expressions are not mixed.")]
  public class TypeCoherentVectorExpressionGrammar : DataAnalysisGrammar, ISymbolicDataAnalysisGrammar {
    // Names of the scalar symbol groups.
    private const string ArithmeticFunctionsName = "Arithmetic Functions";
    private const string TrigonometricFunctionsName = "Trigonometric Functions";
    private const string ExponentialFunctionsName = "Exponential and Logarithmic Functions";
    private const string PowerFunctionsName = "Power Functions";
    private const string TerminalsName = "Terminals";
    private const string VectorAggregationName = "Aggregations";
    private const string VectorStatisticsName = "Vector Statistics";
    private const string VectorDistancesName = "Vector Distances";
    private const string ScalarSymbolsName = "Scalar Symbols";

    // Names of the vector symbol groups.
    private const string VectorArithmeticFunctionsName = "Vector Arithmetic Functions";
    private const string VectorTrigonometricFunctionsName = "Vector Trigonometric Functions";
    private const string VectorExponentialFunctionsName = "Vector Exponential and Logarithmic Functions";
    private const string VectorPowerFunctionsName = "Vector Power Functions";
    private const string VectorTerminalsName = "Vector Terminals";
    private const string VectorSymbolsName = "Vector Symbols";

    // Names by which ConfigureVariableSymbols looks up the variable symbols. The vector name is
    // also the name assigned to the vector variable symbol in Initialize, so both must agree.
    private const string VariableSymbolName = "Variable";
    private const string VectorVariableSymbolName = "Vector Variable";

    /// <summary>Constructor marked with <c>[StorableConstructor]</c>; only forwards the flag to the base class.</summary>
    [StorableConstructor]
    protected TypeCoherentVectorExpressionGrammar(StorableConstructorFlag _) : base(_) { }

    /// <summary>Cloning constructor; forwards the original instance and the cloner to the base class.</summary>
    protected TypeCoherentVectorExpressionGrammar(TypeCoherentVectorExpressionGrammar original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the grammar using the name and description from this type's <c>Item</c> attribute
    /// and sets up all symbols via <see cref="Initialize"/>.
    /// </summary>
    public TypeCoherentVectorExpressionGrammar()
      : base(ItemAttribute.GetName(typeof(TypeCoherentVectorExpressionGrammar)), ItemAttribute.GetDescription(typeof(TypeCoherentVectorExpressionGrammar))) {
      Initialize();
    }

    /// <summary>Creates a copy of this grammar through the cloning constructor.</summary>
    /// <param name="cloner">The cloner that is passed on to the cloning constructor.</param>
    /// <returns>A new <see cref="TypeCoherentVectorExpressionGrammar"/> instance.</returns>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new TypeCoherentVectorExpressionGrammar(this, cloner);
    }

    /// <summary>
    /// Declares the scalar and vector symbols, arranges them in groups, registers the two
    /// top-level groups, and configures subtree counts, allowed child symbols and the groups
    /// that are disabled by default.
    /// </summary>
    private void Initialize() {
      #region scalar symbol declaration
      var add = new Addition();
      var sub = new Subtraction();
      var mul = new Multiplication();
      var div = new Division();

      var sin = new Sine();
      var cos = new Cosine();
      var tan = new Tangent();

      var exp = new Exponential();
      var log = new Logarithm();

      var square = new Square();
      var sqrt = new SquareRoot();
      var cube = new Cube();
      var cubeRoot = new CubeRoot();
      var power = new Power();
      var root = new Root();

      var constant = new Constant { MinValue = -20, MaxValue = 20 };
      var variable = new Variable();
      var binFactorVariable = new BinaryFactorVariable();
      var factorVariable = new FactorVariable();

      // Only mean, standard deviation and sum are enabled initially; the remaining
      // aggregation symbols are declared but start out disabled.
      var mean = new Mean();
      var sd = new StandardDeviation();
      var sum = new Sum();
      var length = new Length() { Enabled = false };
      var min = new Min() { Enabled = false };
      var max = new Max() { Enabled = false };
      var variance = new Variance() { Enabled = false };
      var skewness = new Skewness() { Enabled = false };
      var kurtosis = new Kurtosis() { Enabled = false };
      var euclideanDistance = new EuclideanDistance() { Enabled = false };
      var covariance = new Covariance() { Enabled = false };
      #endregion

      #region vector symbol declaration
      // The vector symbols reuse the scalar symbol classes and differ only in their name.
      var vectoradd = new Addition() { Name = "Vector Addition" };
      var vectorsub = new Subtraction() { Name = "Vector Subtraction" };
      var vectormul = new Multiplication() { Name = "Vector Multiplication" };
      var vectordiv = new Division() { Name = "Vector Division" };

      var vectorsin = new Sine() { Name = "Vector Sine" };
      var vectorcos = new Cosine() { Name = "Vector Cosine" };
      var vectortan = new Tangent() { Name = "Vector Tangent" };

      var vectorexp = new Exponential() { Name = "Vector Exponential" };
      var vectorlog = new Logarithm() { Name = "Vector Logarithm" };

      var vectorsquare = new Square() { Name = "Vector Square" };
      var vectorsqrt = new SquareRoot() { Name = "Vector SquareRoot" };
      var vectorcube = new Cube() { Name = "Vector Cube" };
      var vectorcubeRoot = new CubeRoot() { Name = "Vector CubeRoot" };
      var vectorpower = new Power() { Name = "Vector Power" };
      var vectorroot = new Root() { Name = "Vector Root" };

      var vectorvariable = new Variable() { Name = VectorVariableSymbolName };
      #endregion

      #region group symbol declaration
      var arithmeticSymbols = new GroupSymbol(ArithmeticFunctionsName, new List<ISymbol>() { add, sub, mul, div });
      var trigonometricSymbols = new GroupSymbol(TrigonometricFunctionsName, new List<ISymbol>() { sin, cos, tan });
      var exponentialAndLogarithmicSymbols = new GroupSymbol(ExponentialFunctionsName, new List<ISymbol> { exp, log });
      var powerSymbols = new GroupSymbol(PowerFunctionsName, new List<ISymbol> { square, sqrt, cube, cubeRoot, power, root });
      var terminalSymbols = new GroupSymbol(TerminalsName, new List<ISymbol> { constant, variable, binFactorVariable, factorVariable });
      var statisticsSymbols = new GroupSymbol(VectorStatisticsName, new List<ISymbol> { mean, sd, sum, length, min, max, variance, skewness, kurtosis });
      var distancesSymbols = new GroupSymbol(VectorDistancesName, new List<ISymbol> { euclideanDistance, covariance });
      var aggregationSymbols = new GroupSymbol(VectorAggregationName, new List<ISymbol> { statisticsSymbols, distancesSymbols });
      var scalarSymbols = new GroupSymbol(ScalarSymbolsName, new List<ISymbol>() { arithmeticSymbols, trigonometricSymbols, exponentialAndLogarithmicSymbols, powerSymbols, terminalSymbols, aggregationSymbols });

      var vectorarithmeticSymbols = new GroupSymbol(VectorArithmeticFunctionsName, new List<ISymbol>() { vectoradd, vectorsub, vectormul, vectordiv });
      var vectortrigonometricSymbols = new GroupSymbol(VectorTrigonometricFunctionsName, new List<ISymbol>() { vectorsin, vectorcos, vectortan });
      var vectorexponentialAndLogarithmicSymbols = new GroupSymbol(VectorExponentialFunctionsName, new List<ISymbol> { vectorexp, vectorlog });
      var vectorpowerSymbols = new GroupSymbol(VectorPowerFunctionsName, new List<ISymbol> { vectorsquare, vectorsqrt, vectorcube, vectorcubeRoot, vectorpower, vectorroot });
      var vectorterminalSymbols = new GroupSymbol(VectorTerminalsName, new List<ISymbol> { vectorvariable });
      var vectorSymbols = new GroupSymbol(VectorSymbolsName, new List<ISymbol>() { vectorarithmeticSymbols, vectortrigonometricSymbols, vectorexponentialAndLogarithmicSymbols, vectorpowerSymbols, vectorterminalSymbols });
      #endregion

      // Scalar and vector symbols are registered as two separate top-level groups.
      AddSymbol(scalarSymbols);
      AddSymbol(vectorSymbols);

      #region subtree count configuration
      SetSubtreeCount(arithmeticSymbols, 2, 2);
      SetSubtreeCount(trigonometricSymbols, 1, 1);
      SetSubtreeCount(exponentialAndLogarithmicSymbols, 1, 1);
      // The power group mixes unary and binary symbols, so the counts are set per symbol.
      SetSubtreeCount(square, 1, 1);
      SetSubtreeCount(sqrt, 1, 1);
      SetSubtreeCount(cube, 1, 1);
      SetSubtreeCount(cubeRoot, 1, 1);
      SetSubtreeCount(power, 2, 2);
      SetSubtreeCount(root, 2, 2);
      SetSubtreeCount(terminalSymbols, 0, 0);
      SetSubtreeCount(statisticsSymbols, 1, 1);
      SetSubtreeCount(distancesSymbols, 2, 2);

      SetSubtreeCount(vectorarithmeticSymbols, 2, 2);
      SetSubtreeCount(vectortrigonometricSymbols, 1, 1);
      SetSubtreeCount(vectorexponentialAndLogarithmicSymbols, 1, 1);
      SetSubtreeCount(vectorsquare, 1, 1);
      SetSubtreeCount(vectorsqrt, 1, 1);
      SetSubtreeCount(vectorcube, 1, 1);
      SetSubtreeCount(vectorcubeRoot, 1, 1);
      SetSubtreeCount(vectorpower, 2, 2);
      SetSubtreeCount(vectorroot, 2, 2);
      SetSubtreeCount(vectorterminalSymbols, 0, 0);
      #endregion

      #region allowed child symbols configuration
      // Only scalar symbols are allowed directly below the start symbol.
      AddAllowedChildSymbol(StartSymbol, scalarSymbols);

      AddAllowedChildSymbol(arithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(trigonometricSymbols, scalarSymbols);
      AddAllowedChildSymbol(exponentialAndLogarithmicSymbols, scalarSymbols);
      // Argument 0 of every power symbol accepts any scalar symbol; argument 1 of the
      // binary symbols (power, root) only accepts a constant.
      AddAllowedChildSymbol(powerSymbols, scalarSymbols, 0);
      AddAllowedChildSymbol(power, constant, 1);
      AddAllowedChildSymbol(root, constant, 1);
      // Aggregations are the scalar symbols that take vector symbols as children.
      AddAllowedChildSymbol(aggregationSymbols, vectorSymbols);

      // Vector arithmetic accepts both vector and scalar children.
      AddAllowedChildSymbol(vectorarithmeticSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorarithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(vectortrigonometricSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorexponentialAndLogarithmicSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorpowerSymbols, vectorSymbols, 0);
      AddAllowedChildSymbol(vectorpower, constant, 1);
      AddAllowedChildSymbol(vectorroot, constant, 1);
      #endregion

      #region default enabled/disabled
      var disabledByDefault = new[] {
        TrigonometricFunctionsName, ExponentialFunctionsName, PowerFunctionsName,
        VectorTrigonometricFunctionsName, VectorExponentialFunctionsName, VectorPowerFunctionsName
      };
      foreach (var grp in Symbols.Where(sym => disabledByDefault.Contains(sym.Name))) {
        grp.Enabled = false;
      }
      #endregion
    }

    /// <summary>
    /// Runs the base configuration and then restricts the variable symbols by data type:
    /// symbols named "Variable" only get the input variables of type <c>double</c>, symbols
    /// named "Vector Variable" only get the input variables of type <c>DoubleVector</c>.
    /// Symbols whose <c>Fixed</c> flag is set are left unchanged.
    /// </summary>
    /// <param name="problemData">Problem data providing the dataset and the (allowed) input variables.</param>
    public override void ConfigureVariableSymbols(IDataAnalysisProblemData problemData) {
      base.ConfigureVariableSymbols(problemData);

      RestrictVariableSymbolsToType<double>(problemData, VariableSymbolName);
      RestrictVariableSymbolsToType<DoubleVector>(problemData, VectorVariableSymbolName);
    }

    /// <summary>
    /// Assigns to every non-fixed variable symbol with the given name the input variables for
    /// which the dataset reports the data type <typeparamref name="T"/>, and stores
    /// <typeparamref name="T"/> as the symbol's variable data type.
    /// </summary>
    /// <typeparam name="T">Data type the dataset variables must have to be offered to the symbol.</typeparam>
    /// <param name="problemData">Problem data providing the dataset and the (allowed) input variables.</param>
    /// <param name="symbolName">Name of the variable symbols to configure.</param>
    private void RestrictVariableSymbolsToType<T>(IDataAnalysisProblemData problemData, string symbolName) {
      var dataset = problemData.Dataset;
      foreach (var varSymbol in Symbols.OfType<VariableBase>().Where(sym => sym.Name == symbolName)) {
        if (varSymbol.Fixed) continue;

        varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<T>(x));
        varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<T>(x));
        varSymbol.VariableDataType = typeof(T);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.