Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Grammars/TypeCoherentVectorExpressionGrammar.cs @ 17554

Last change on this file since 17554 was 17554, checked in by pfleck, 4 years ago

#3040 Added some symbols for statistical aggregation.

File size: 11.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
29
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Grammar for symbolic expressions over scalar and vector-valued inputs.
  /// Scalar symbols and vector symbols live in separate symbol groups with separate allowed-child rules,
  /// so vector-valued subtrees may only appear below vector symbols or below the aggregation symbols
  /// (statistics / distances) that take vector children and are themselves part of the scalar group.
  /// </summary>
  [StorableType("7EC7B4A7-0E27-4011-B983-B0E15A6944EC")]
  [Item("TypeCoherentVectorExpressionGrammar", "Represents a grammar for functional expressions in which special syntactic constraints are enforced so that vector and scalar expressions are not mixed.")]
  public class TypeCoherentVectorExpressionGrammar : DataAnalysisGrammar, ISymbolicDataAnalysisGrammar {
    private const string ArithmeticFunctionsName = "Arithmetic Functions";
    private const string TrigonometricFunctionsName = "Trigonometric Functions";
    private const string ExponentialFunctionsName = "Exponential and Logarithmic Functions";
    private const string PowerFunctionsName = "Power Functions";
    private const string TerminalsName = "Terminals";
    private const string VectorAggregationName = "Aggregations";
    private const string VectorStatisticsName = "Vector Statistics";
    private const string VectorDistancesName = "Vector Distances";
    private const string ScalarSymbolsName = "Scalar Symbols";

    private const string VectorArithmeticFunctionsName = "Vector Arithmetic Functions";
    private const string VectorTrigonometricFunctionsName = "Vector Trigonometric Functions";
    private const string VectorExponentialFunctionsName = "Vector Exponential and Logarithmic Functions";
    private const string VectorPowerFunctionsName = "Vector Power Functions";
    private const string VectorTerminalsName = "Vector Terminals";
    private const string VectorSymbolsName = "Vector Symbols";

    // Names used to find the variable symbols in ConfigureVariableSymbols. The vector variable's name is
    // also assigned from this constant in Initialize, so the two places cannot drift apart.
    // The scalar variable keeps the Variable symbol's default name, which the lookup expects to be "Variable".
    private const string VariableSymbolName = "Variable";
    private const string VectorVariableSymbolName = "Vector Variable";

    [StorableConstructor]
    protected TypeCoherentVectorExpressionGrammar(StorableConstructorFlag _) : base(_) { }
    protected TypeCoherentVectorExpressionGrammar(TypeCoherentVectorExpressionGrammar original, Cloner cloner) : base(original, cloner) { }
    public TypeCoherentVectorExpressionGrammar()
      : base(ItemAttribute.GetName(typeof(TypeCoherentVectorExpressionGrammar)), ItemAttribute.GetDescription(typeof(TypeCoherentVectorExpressionGrammar))) {
      Initialize();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new TypeCoherentVectorExpressionGrammar(this, cloner);
    }

    /// <summary>
    /// Creates all scalar and vector symbols, groups them, registers the groups,
    /// and configures subtree counts, allowed children and the default-disabled groups.
    /// </summary>
    private void Initialize() {
      #region scalar symbol declaration
      var add = new Addition();
      var sub = new Subtraction();
      var mul = new Multiplication();
      var div = new Division();

      var sin = new Sine();
      var cos = new Cosine();
      var tan = new Tangent();

      var exp = new Exponential();
      var log = new Logarithm();

      var square = new Square();
      var sqrt = new SquareRoot();
      var cube = new Cube();
      var cubeRoot = new CubeRoot();
      var power = new Power();
      var root = new Root();

      var constant = new Constant { MinValue = -20, MaxValue = 20 };
      var variable = new Variable();
      var binFactorVariable = new BinaryFactorVariable();
      var factorVariable = new FactorVariable();

      // Aggregations reduce vectors to scalars; only mean, sd and sum are enabled initially.
      var mean = new Mean();
      var sd = new StandardDeviation();
      var sum = new Sum();
      var length = new Length() { Enabled = false };
      var min = new Min() { Enabled = false };
      var max = new Max() { Enabled = false };
      var variance = new Variance() { Enabled = false };
      var skewness = new Skewness() { Enabled = false };
      var kurtosis = new Kurtosis() { Enabled = false };
      var euclideanDistance = new EuclideanDistance() { Enabled = false };
      var covariance = new Covariance() { Enabled = false };
      #endregion

      #region vector symbol declaration
      // Vector counterparts reuse the scalar symbol classes; they are distinguished only by their names
      // and by the allowed-child rules configured below.
      var vectoradd = new Addition() { Name = "Vector Addition" };
      var vectorsub = new Subtraction() { Name = "Vector Subtraction" };
      var vectormul = new Multiplication() { Name = "Vector Multiplication" };
      var vectordiv = new Division() { Name = "Vector Division" };

      var vectorsin = new Sine() { Name = "Vector Sine" };
      var vectorcos = new Cosine() { Name = "Vector Cosine" };
      var vectortan = new Tangent() { Name = "Vector Tangent" };

      var vectorexp = new Exponential() { Name = "Vector Exponential" };
      var vectorlog = new Logarithm() { Name = "Vector Logarithm" };

      var vectorsquare = new Square() { Name = "Vector Square" };
      var vectorsqrt = new SquareRoot() { Name = "Vector SquareRoot" };
      var vectorcube = new Cube() { Name = "Vector Cube" };
      var vectorcubeRoot = new CubeRoot() { Name = "Vector CubeRoot" };
      var vectorpower = new Power() { Name = "Vector Power" };
      var vectorroot = new Root() { Name = "Vector Root" };

      var vectorvariable = new Variable() { Name = VectorVariableSymbolName };
      #endregion

      #region group symbol declaration
      var arithmeticSymbols = new GroupSymbol(ArithmeticFunctionsName, new List<ISymbol>() { add, sub, mul, div });
      var trigonometricSymbols = new GroupSymbol(TrigonometricFunctionsName, new List<ISymbol>() { sin, cos, tan });
      var exponentialAndLogarithmicSymbols = new GroupSymbol(ExponentialFunctionsName, new List<ISymbol> { exp, log });
      var powerSymbols = new GroupSymbol(PowerFunctionsName, new List<ISymbol> { square, sqrt, cube, cubeRoot, power, root });
      var terminalSymbols = new GroupSymbol(TerminalsName, new List<ISymbol> { constant, variable, binFactorVariable, factorVariable });
      var statisticsSymbols = new GroupSymbol(VectorStatisticsName, new List<ISymbol> { mean, sd, sum, length, min, max, variance, skewness, kurtosis });
      var distancesSymbols = new GroupSymbol(VectorDistancesName, new List<ISymbol> { euclideanDistance, covariance });
      var aggregationSymbols = new GroupSymbol(VectorAggregationName, new List<ISymbol> { statisticsSymbols, distancesSymbols });
      var scalarSymbols = new GroupSymbol(ScalarSymbolsName, new List<ISymbol>() { arithmeticSymbols, trigonometricSymbols, exponentialAndLogarithmicSymbols, powerSymbols, terminalSymbols, aggregationSymbols });

      var vectorarithmeticSymbols = new GroupSymbol(VectorArithmeticFunctionsName, new List<ISymbol>() { vectoradd, vectorsub, vectormul, vectordiv });
      var vectortrigonometricSymbols = new GroupSymbol(VectorTrigonometricFunctionsName, new List<ISymbol>() { vectorsin, vectorcos, vectortan });
      var vectorexponentialAndLogarithmicSymbols = new GroupSymbol(VectorExponentialFunctionsName, new List<ISymbol> { vectorexp, vectorlog });
      var vectorpowerSymbols = new GroupSymbol(VectorPowerFunctionsName, new List<ISymbol> { vectorsquare, vectorsqrt, vectorcube, vectorcubeRoot, vectorpower, vectorroot });
      var vectorterminalSymbols = new GroupSymbol(VectorTerminalsName, new List<ISymbol> { vectorvariable });
      var vectorSymbols = new GroupSymbol(VectorSymbolsName, new List<ISymbol>() { vectorarithmeticSymbols, vectortrigonometricSymbols, vectorexponentialAndLogarithmicSymbols, vectorpowerSymbols, vectorterminalSymbols });
      #endregion

      AddSymbol(scalarSymbols);
      AddSymbol(vectorSymbols);

      #region subtree count configuration
      SetSubtreeCount(arithmeticSymbols, 2, 2);
      SetSubtreeCount(trigonometricSymbols, 1, 1);
      SetSubtreeCount(exponentialAndLogarithmicSymbols, 1, 1);
      SetSubtreeCount(square, 1, 1);
      SetSubtreeCount(sqrt, 1, 1);
      SetSubtreeCount(cube, 1, 1);
      SetSubtreeCount(cubeRoot, 1, 1);
      SetSubtreeCount(power, 2, 2);
      SetSubtreeCount(root, 2, 2);
      SetSubtreeCount(terminalSymbols, 0, 0);
      SetSubtreeCount(statisticsSymbols, 1, 1); // one vector argument
      SetSubtreeCount(distancesSymbols, 2, 2);  // two vector arguments

      SetSubtreeCount(vectorarithmeticSymbols, 2, 2);
      SetSubtreeCount(vectortrigonometricSymbols, 1, 1);
      SetSubtreeCount(vectorexponentialAndLogarithmicSymbols, 1, 1);
      SetSubtreeCount(vectorsquare, 1, 1);
      SetSubtreeCount(vectorsqrt, 1, 1);
      SetSubtreeCount(vectorcube, 1, 1);
      SetSubtreeCount(vectorcubeRoot, 1, 1);
      SetSubtreeCount(vectorpower, 2, 2);
      SetSubtreeCount(vectorroot, 2, 2);
      SetSubtreeCount(vectorterminalSymbols, 0, 0);
      #endregion

      #region allowed child symbols configuration
      // Whole expressions are scalar-valued.
      AddAllowedChildSymbol(StartSymbol, scalarSymbols);

      AddAllowedChildSymbol(arithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(trigonometricSymbols, scalarSymbols);
      AddAllowedChildSymbol(exponentialAndLogarithmicSymbols, scalarSymbols);
      AddAllowedChildSymbol(powerSymbols, scalarSymbols, 0);
      // The exponent / root index (argument 1) is restricted to a constant.
      AddAllowedChildSymbol(power, constant, 1);
      AddAllowedChildSymbol(root, constant, 1);
      // Aggregations are the only bridge from vector subtrees back into the scalar part of the tree.
      AddAllowedChildSymbol(aggregationSymbols, vectorSymbols);

      AddAllowedChildSymbol(vectorarithmeticSymbols, vectorSymbols);
      // Vector arithmetic also accepts scalar operands (e.g. vector * constant).
      // NOTE(review): because both argument positions accept scalars, a "Vector" arithmetic node can have
      // two scalar children and therefore be scalar-valued below an aggregation — confirm the interpreter
      // handles that, or restrict scalars to one argument position.
      AddAllowedChildSymbol(vectorarithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(vectortrigonometricSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorexponentialAndLogarithmicSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorpowerSymbols, vectorSymbols, 0);
      AddAllowedChildSymbol(vectorpower, constant, 1);
      AddAllowedChildSymbol(vectorroot, constant, 1);
      #endregion

      #region default enabled/disabled
      var disabledByDefault = new[] {
        TrigonometricFunctionsName, ExponentialFunctionsName, PowerFunctionsName,
        VectorTrigonometricFunctionsName, VectorExponentialFunctionsName, VectorPowerFunctionsName
      };
      foreach (var grp in Symbols.Where(sym => disabledByDefault.Contains(sym.Name))) {
        grp.Enabled = false;
      }
      #endregion
    }

    /// <summary>
    /// Runs the base configuration. It then restricts the non-fixed scalar variable symbol(s) to
    /// <see cref="double"/>-typed input variables, and the non-fixed vector variable symbol(s) to
    /// <c>DoubleVector</c>-typed input variables.
    /// </summary>
    /// <param name="problemData">Problem data providing the dataset and the (allowed) input variables.</param>
    public override void ConfigureVariableSymbols(IDataAnalysisProblemData problemData) {
      base.ConfigureVariableSymbols(problemData);

      ConfigureVariableSymbolsOfType<double>(problemData, VariableSymbolName);
      ConfigureVariableSymbolsOfType<DoubleVector>(problemData, VectorVariableSymbolName);
    }

    /// <summary>
    /// Handles every variable symbol named <paramref name="symbolName"/> that is not fixed.
    /// For each one, it sets <c>AllVariableNames</c> to the input variables and <c>VariableNames</c>
    /// to the allowed input variables, keeping only names whose dataset values have type
    /// <typeparamref name="T"/>.
    /// </summary>
    private void ConfigureVariableSymbolsOfType<T>(IDataAnalysisProblemData problemData, string symbolName) {
      var dataset = problemData.Dataset;
      foreach (var varSymbol in Symbols.OfType<VariableBase>().Where(sym => sym.Name == symbolName)) {
        if (!varSymbol.Fixed) {
          varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<T>(x));
          varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<T>(x));
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.