Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Grammars/TypeCoherentVectorExpressionGrammar.cs

Last change on this file was 18060, checked in by pfleck, 3 years ago

#3040 Added a subvector symbol with ranges as subtrees.

File size: 13.1 KB
RevLine 
[5333]1#region License Information
2/* HeuristicLab
[17180]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5333]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[5393]22using System.Collections.Generic;
[5333]23using System.Linq;
[17456]24using HEAL.Attic;
[5333]25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[17460]28using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
29
[5333]30namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
[17460]31  [StorableType("7EC7B4A7-0E27-4011-B983-B0E15A6944EC")]
32  [Item("TypeCoherentVectorExpressionGrammar", "Represents a grammar for functional expressions in which special syntactic constraints are enforced so that vector and scalar expressions are not mixed.")]
33  public class TypeCoherentVectorExpressionGrammar : DataAnalysisGrammar, ISymbolicDataAnalysisGrammar {
[6803]34    private const string ArithmeticFunctionsName = "Arithmetic Functions";
35    private const string TrigonometricFunctionsName = "Trigonometric Functions";
36    private const string ExponentialFunctionsName = "Exponential and Logarithmic Functions";
[17463]37    private const string PowerFunctionsName = "Power Functions";
[6803]38    private const string TerminalsName = "Terminals";
[17463]39    private const string VectorAggregationName = "Aggregations";
[17554]40    private const string VectorStatisticsName = "Vector Statistics";
41    private const string VectorDistancesName = "Vector Distances";
[17463]42    private const string ScalarSymbolsName = "Scalar Symbols";
[5333]43
[17463]44    private const string VectorArithmeticFunctionsName = "Vector Arithmetic Functions";
45    private const string VectorTrigonometricFunctionsName = "Vector Trigonometric Functions";
46    private const string VectorExponentialFunctionsName = "Vector Exponential and Logarithmic Functions";
47    private const string VectorPowerFunctionsName = "Vector Power Functions";
48    private const string VectorTerminalsName = "Vector Terminals";
49    private const string VectorSymbolsName = "Vector Symbols";
50
[17752]51    private const string VectorManipulationSymbolsName = "Vector Manipulation Symbols";
[18060]52    private const string VectorSubVectorSymbolsName = "Vector SubVector Symbols";
[17752]53
[17463]54    private const string RealValuedSymbolsName = "Real Valued Symbols";
55
[5333]56    [StorableConstructor]
[17460]57    protected TypeCoherentVectorExpressionGrammar(StorableConstructorFlag _) : base(_) { }
58    protected TypeCoherentVectorExpressionGrammar(TypeCoherentVectorExpressionGrammar original, Cloner cloner) : base(original, cloner) { }
59    public TypeCoherentVectorExpressionGrammar()
60      : base(ItemAttribute.GetName(typeof(TypeCoherentVectorExpressionGrammar)), ItemAttribute.GetDescription(typeof(TypeCoherentVectorExpressionGrammar))) {
[5333]61      Initialize();
62    }
63    public override IDeepCloneable Clone(Cloner cloner) {
[17460]64      return new TypeCoherentVectorExpressionGrammar(this, cloner);
[5333]65    }
66
67    private void Initialize() {
[17463]68      #region scalar symbol declaration
[5333]69      var add = new Addition();
70      var sub = new Subtraction();
71      var mul = new Multiplication();
72      var div = new Division();
[17463]73
[5333]74      var sin = new Sine();
75      var cos = new Cosine();
76      var tan = new Tangent();
[17463]77
78      var exp = new Exponential();
[5333]79      var log = new Logarithm();
[17463]80
[7695]81      var square = new Square();
82      var sqrt = new SquareRoot();
[16356]83      var cube = new Cube();
84      var cubeRoot = new CubeRoot();
[17463]85      var power = new Power();
86      var root = new Root();
[7696]87
[17463]88      var constant = new Constant { MinValue = -20, MaxValue = 20 };
89      var variable = new Variable();
90      var binFactorVariable = new BinaryFactorVariable();
91      var factorVariable = new FactorVariable();
[7696]92
[17466]93      var mean = new Mean();
94      var sd = new StandardDeviation();
[17463]95      var sum = new Sum();
[17554]96      var length = new Length() { Enabled = false };
97      var min = new Min() { Enabled = false };
98      var max = new Max() { Enabled = false };
99      var variance = new Variance() { Enabled = false };
100      var skewness = new Skewness() { Enabled = false };
101      var kurtosis = new Kurtosis() { Enabled = false };
102      var euclideanDistance = new EuclideanDistance() { Enabled = false };
103      var covariance = new Covariance() { Enabled = false };
[17463]104      #endregion
[5393]105
[17463]106      #region vector symbol declaration
107      var vectoradd = new Addition() { Name = "Vector Addition" };
108      var vectorsub = new Subtraction() { Name = "Vector Subtraction" };
109      var vectormul = new Multiplication() { Name = "Vector Multiplication" };
110      var vectordiv = new Division() { Name = "Vector Division" };
[5393]111
[17463]112      var vectorsin = new Sine() { Name = "Vector Sine" };
113      var vectorcos = new Cosine() { Name = "Vector Cosine" };
114      var vectortan = new Tangent() { Name = "Vector Tangent" };
115
116      var vectorexp = new Exponential() { Name = "Vector Exponential" };
117      var vectorlog = new Logarithm() { Name = "Vector Logarithm" };
118
119      var vectorsquare = new Square() { Name = "Vector Square" };
120      var vectorsqrt = new SquareRoot() { Name = "Vector SquareRoot" };
121      var vectorcube = new Cube() { Name = "Vector Cube" };
122      var vectorcubeRoot = new CubeRoot() { Name = "Vector CubeRoot" };
123      var vectorpower = new Power() { Name = "Vector Power" };
124      var vectorroot = new Root() { Name = "Vector Root" };
125
126      var vectorvariable = new Variable() { Name = "Vector Variable" };
[6803]127      #endregion
[5333]128
[17752]129      #region vector manipulation symbol declaration
[18060]130      var subvectorLocal = new SubVector();
131      var subvectorSubtree = new SubVectorSubtree();
[17752]132      #endregion
133
[6803]134      #region group symbol declaration
[17463]135      var arithmeticSymbols = new GroupSymbol(ArithmeticFunctionsName, new List<ISymbol>() { add, sub, mul, div });
136      var trigonometricSymbols = new GroupSymbol(TrigonometricFunctionsName, new List<ISymbol>() { sin, cos, tan });
[17369]137      var exponentialAndLogarithmicSymbols = new GroupSymbol(ExponentialFunctionsName, new List<ISymbol> { exp, log });
[17463]138      var powerSymbols = new GroupSymbol(PowerFunctionsName, new List<ISymbol> { square, sqrt, cube, cubeRoot, power, root });
139      var terminalSymbols = new GroupSymbol(TerminalsName, new List<ISymbol> { constant, variable, binFactorVariable, factorVariable });
[17604]140      var statisticsSymbols = new GroupSymbol(VectorStatisticsName, new List<ISymbol> { mean, sd, sum, length, min, max, variance, skewness, kurtosis });
[17554]141      var distancesSymbols = new GroupSymbol(VectorDistancesName, new List<ISymbol> { euclideanDistance, covariance });
142      var aggregationSymbols = new GroupSymbol(VectorAggregationName, new List<ISymbol> { statisticsSymbols, distancesSymbols });
[17463]143      var scalarSymbols = new GroupSymbol(ScalarSymbolsName, new List<ISymbol>() { arithmeticSymbols, trigonometricSymbols, exponentialAndLogarithmicSymbols, powerSymbols, terminalSymbols, aggregationSymbols });
[5333]144
[17463]145      var vectorarithmeticSymbols = new GroupSymbol(VectorArithmeticFunctionsName, new List<ISymbol>() { vectoradd, vectorsub, vectormul, vectordiv });
146      var vectortrigonometricSymbols = new GroupSymbol(VectorTrigonometricFunctionsName, new List<ISymbol>() { vectorsin, vectorcos, vectortan });
147      var vectorexponentialAndLogarithmicSymbols = new GroupSymbol(VectorExponentialFunctionsName, new List<ISymbol> { vectorexp, vectorlog });
148      var vectorpowerSymbols = new GroupSymbol(VectorPowerFunctionsName, new List<ISymbol> { vectorsquare, vectorsqrt, vectorcube, vectorcubeRoot, vectorpower, vectorroot });
149      var vectorterminalSymbols = new GroupSymbol(VectorTerminalsName, new List<ISymbol> { vectorvariable });
150      var vectorSymbols = new GroupSymbol(VectorSymbolsName, new List<ISymbol>() { vectorarithmeticSymbols, vectortrigonometricSymbols, vectorexponentialAndLogarithmicSymbols, vectorpowerSymbols, vectorterminalSymbols });
[5333]151
[18060]152      var vectorSubVectorSymbols = new GroupSymbol(VectorSubVectorSymbolsName, new List<ISymbol>() { subvectorLocal, subvectorSubtree });
153      var vectorManipulationSymbols = new GroupSymbol(VectorManipulationSymbolsName, new List<ISymbol>() { vectorSubVectorSymbols });
154     
[17752]155
[17463]156      //var realValuedSymbols = new GroupSymbol(RealValuedSymbolsName, new List<ISymbol> { scalarSymbols, vectorSymbols });
[6803]157      #endregion
[5333]158
[17463]159      //AddSymbol(realValuedSymbols);
160      AddSymbol(scalarSymbols);
161      AddSymbol(vectorSymbols);
[17752]162      AddSymbol(vectorManipulationSymbols);
[5333]163
[6803]164      #region subtree count configuration
165      SetSubtreeCount(arithmeticSymbols, 2, 2);
166      SetSubtreeCount(trigonometricSymbols, 1, 1);
[17463]167      SetSubtreeCount(exponentialAndLogarithmicSymbols, 1, 1);
[7695]168      SetSubtreeCount(square, 1, 1);
[17463]169      SetSubtreeCount(sqrt, 1, 1);
[16356]170      SetSubtreeCount(cube, 1, 1);
171      SetSubtreeCount(cubeRoot, 1, 1);
[17463]172      SetSubtreeCount(power, 2, 2);
173      SetSubtreeCount(root, 2, 2);
[6803]174      SetSubtreeCount(exponentialAndLogarithmicSymbols, 1, 1);
175      SetSubtreeCount(terminalSymbols, 0, 0);
[17554]176      SetSubtreeCount(statisticsSymbols, 1, 1);
177      SetSubtreeCount(distancesSymbols, 2, 2);
[5333]178
[17463]179      SetSubtreeCount(vectorarithmeticSymbols, 2, 2);
180      SetSubtreeCount(vectortrigonometricSymbols, 1, 1);
181      SetSubtreeCount(vectorexponentialAndLogarithmicSymbols, 1, 1);
182      SetSubtreeCount(vectorsquare, 1, 1);
183      SetSubtreeCount(vectorsqrt, 1, 1);
184      SetSubtreeCount(vectorcube, 1, 1);
185      SetSubtreeCount(vectorcubeRoot, 1, 1);
186      SetSubtreeCount(vectorpower, 2, 2);
187      SetSubtreeCount(vectorroot, 2, 2);
188      SetSubtreeCount(vectorexponentialAndLogarithmicSymbols, 1, 1);
189      SetSubtreeCount(vectorterminalSymbols, 0, 0);
[17752]190
[18060]191      SetSubtreeCount(subvectorLocal, 1, 1);
192      SetSubtreeCount(subvectorSubtree, 3, 3);
[6803]193      #endregion
[5333]194
[6819]195      #region allowed child symbols configuration
[17463]196      AddAllowedChildSymbol(StartSymbol, scalarSymbols);
[9459]197
[17463]198      AddAllowedChildSymbol(arithmeticSymbols, scalarSymbols);
199      AddAllowedChildSymbol(trigonometricSymbols, scalarSymbols);
200      AddAllowedChildSymbol(exponentialAndLogarithmicSymbols, scalarSymbols);
201      AddAllowedChildSymbol(powerSymbols, scalarSymbols, 0);
202      AddAllowedChildSymbol(power, constant, 1);
203      AddAllowedChildSymbol(root, constant, 1);
204      AddAllowedChildSymbol(aggregationSymbols, vectorSymbols);
[18060]205      AddAllowedChildSymbol(statisticsSymbols, vectorSubVectorSymbols);
[5333]206
[17463]207      AddAllowedChildSymbol(vectorarithmeticSymbols, vectorSymbols);
208      AddAllowedChildSymbol(vectorarithmeticSymbols, scalarSymbols);
209      AddAllowedChildSymbol(vectortrigonometricSymbols, vectorSymbols);
210      AddAllowedChildSymbol(vectorexponentialAndLogarithmicSymbols, vectorSymbols);
211      AddAllowedChildSymbol(vectorpowerSymbols, vectorSymbols, 0);
212      AddAllowedChildSymbol(vectorpower, constant, 1);
213      AddAllowedChildSymbol(vectorroot, constant, 1);
[17752]214
[18060]215      AddAllowedChildSymbol(subvectorLocal, vectorSymbols);
216      AddAllowedChildSymbol(subvectorSubtree, vectorSymbols, 0);
217      AddAllowedChildSymbol(subvectorSubtree, scalarSymbols, 1);
218      AddAllowedChildSymbol(subvectorSubtree, scalarSymbols, 2);
[6803]219      #endregion
[17465]220
221      #region default enabled/disabled
222      var disabledByDefault = new[] {
223        TrigonometricFunctionsName, ExponentialFunctionsName, PowerFunctionsName,
[17752]224        VectorTrigonometricFunctionsName, VectorExponentialFunctionsName, VectorPowerFunctionsName,
225        VectorManipulationSymbolsName
[17465]226      };
227      foreach (var grp in Symbols.Where(sym => disabledByDefault.Contains(sym.Name)))
228        grp.Enabled = false;
229      #endregion
[5333]230    }
[6803]231
[17460]232    public override void ConfigureVariableSymbols(IDataAnalysisProblemData problemData) {
233      base.ConfigureVariableSymbols(problemData);
[6803]234
[17460]235      var dataset = problemData.Dataset;
236      foreach (var varSymbol in Symbols.OfType<VariableBase>().Where(sym => sym.Name == "Variable")) {
237        if (!varSymbol.Fixed) {
238          varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<double>(x));
239          varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<double>(x));
[17604]240          varSymbol.VariableDataType = typeof(double);
[17460]241        }
242      }
243      foreach (var varSymbol in Symbols.OfType<VariableBase>().Where(sym => sym.Name == "Vector Variable")) {
244        if (!varSymbol.Fixed) {
245          varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<DoubleVector>(x));
246          varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<DoubleVector>(x));
[17604]247          varSymbol.VariableDataType = typeof(DoubleVector);
[17460]248        }
249      }
[6803]250    }
[5333]251  }
252}
Note: See TracBrowser for help on using the repository browser.