#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System.Collections.Generic;
using System.Linq;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Grammar that keeps scalar and vector expressions apart.
  /// </summary>
  /// <remarks>
  /// The tree root and all scalar functions accept only scalar subtrees. Vector subtrees
  /// can enter a scalar expression only through an aggregation symbol (Sum, Mean,
  /// StandardDeviation). The vector symbols reuse the scalar symbol classes and differ
  /// from them in this grammar by name and by the child constraints configured in
  /// <see cref="Initialize"/>.
  /// </remarks>
  [StorableType("7EC7B4A7-0E27-4011-B983-B0E15A6944EC")]
  [Item("TypeCoherentVectorExpressionGrammar", "Represents a grammar for functional expressions in which special syntactic constraints are enforced so that vector and scalar expressions are not mixed.")]
  public class TypeCoherentVectorExpressionGrammar : DataAnalysisGrammar, ISymbolicDataAnalysisGrammar {
    // Names of the (group) symbols; also used to select groups that are disabled by default.
    private const string ArithmeticFunctionsName = "Arithmetic Functions";
    private const string TrigonometricFunctionsName = "Trigonometric Functions";
    private const string ExponentialFunctionsName = "Exponential and Logarithmic Functions";
    private const string PowerFunctionsName = "Power Functions";
    private const string TerminalsName = "Terminals";
    private const string VectorAggregationName = "Aggregations";
    private const string ScalarSymbolsName = "Scalar Symbols";

    private const string VectorArithmeticFunctionsName = "Vector Arithmetic Functions";
    private const string VectorTrigonometricFunctionsName = "Vector Trigonometric Functions";
    private const string VectorExponentialFunctionsName = "Vector Exponential and Logarithmic Functions";
    private const string VectorPowerFunctionsName = "Vector Power Functions";
    private const string VectorTerminalsName = "Vector Terminals";
    private const string VectorSymbolsName = "Vector Symbols";

    // Scalar and vector variables are both instances of Variable and are told apart
    // by these names in ConfigureVariableSymbols.
    private const string VariableSymbolName = "Variable";
    private const string VectorVariableSymbolName = "Vector Variable";

    [StorableConstructor]
    protected TypeCoherentVectorExpressionGrammar(StorableConstructorFlag _) : base(_) { }

    protected TypeCoherentVectorExpressionGrammar(TypeCoherentVectorExpressionGrammar original, Cloner cloner)
      : base(original, cloner) {
    }

    public TypeCoherentVectorExpressionGrammar()
      : base(ItemAttribute.GetName(typeof(TypeCoherentVectorExpressionGrammar)), ItemAttribute.GetDescription(typeof(TypeCoherentVectorExpressionGrammar))) {
      Initialize();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new TypeCoherentVectorExpressionGrammar(this, cloner);
    }

    /// <summary>
    /// Declares all scalar and vector symbols, groups them, and registers them with the grammar.
    /// Also sets their arities, their allowed children, and which groups start disabled.
    /// </summary>
    private void Initialize() {
      #region scalar symbol declaration
      var add = new Addition();
      var sub = new Subtraction();
      var mul = new Multiplication();
      var div = new Division();

      var sin = new Sine();
      var cos = new Cosine();
      var tan = new Tangent();

      var exp = new Exponential();
      var log = new Logarithm();

      var square = new Square();
      var sqrt = new SquareRoot();
      var cube = new Cube();
      var cubeRoot = new CubeRoot();
      var power = new Power();
      var root = new Root();

      var constant = new Constant { MinValue = -20, MaxValue = 20 };
      var variable = new Variable();
      var binFactorVariable = new BinaryFactorVariable();
      var factorVariable = new FactorVariable();

      var sum = new Sum();
      var mean = new Average { Name = "Mean" };
      var sd = new StandardDeviation();
      #endregion

      #region vector symbol declaration
      // Same symbol classes as above, renamed so that they can be configured independently.
      var vectoradd = new Addition() { Name = "Vector Addition" };
      var vectorsub = new Subtraction() { Name = "Vector Subtraction" };
      var vectormul = new Multiplication() { Name = "Vector Multiplication" };
      var vectordiv = new Division() { Name = "Vector Division" };

      var vectorsin = new Sine() { Name = "Vector Sine" };
      var vectorcos = new Cosine() { Name = "Vector Cosine" };
      var vectortan = new Tangent() { Name = "Vector Tangent" };

      var vectorexp = new Exponential() { Name = "Vector Exponential" };
      var vectorlog = new Logarithm() { Name = "Vector Logarithm" };

      var vectorsquare = new Square() { Name = "Vector Square" };
      var vectorsqrt = new SquareRoot() { Name = "Vector SquareRoot" };
      var vectorcube = new Cube() { Name = "Vector Cube" };
      var vectorcubeRoot = new CubeRoot() { Name = "Vector CubeRoot" };
      var vectorpower = new Power() { Name = "Vector Power" };
      var vectorroot = new Root() { Name = "Vector Root" };

      var vectorvariable = new Variable() { Name = VectorVariableSymbolName };
      #endregion

      #region group symbol declaration
      var arithmeticSymbols = new GroupSymbol(ArithmeticFunctionsName, new List<ISymbol> { add, sub, mul, div });
      var trigonometricSymbols = new GroupSymbol(TrigonometricFunctionsName, new List<ISymbol> { sin, cos, tan });
      var exponentialAndLogarithmicSymbols = new GroupSymbol(ExponentialFunctionsName, new List<ISymbol> { exp, log });
      var powerSymbols = new GroupSymbol(PowerFunctionsName, new List<ISymbol> { square, sqrt, cube, cubeRoot, power, root });
      var terminalSymbols = new GroupSymbol(TerminalsName, new List<ISymbol> { constant, variable, binFactorVariable, factorVariable });
      var aggregationSymbols = new GroupSymbol(VectorAggregationName, new List<ISymbol> { sum, mean, sd });
      var scalarSymbols = new GroupSymbol(ScalarSymbolsName, new List<ISymbol> {
        arithmeticSymbols, trigonometricSymbols, exponentialAndLogarithmicSymbols, powerSymbols, terminalSymbols, aggregationSymbols
      });

      var vectorarithmeticSymbols = new GroupSymbol(VectorArithmeticFunctionsName, new List<ISymbol> { vectoradd, vectorsub, vectormul, vectordiv });
      var vectortrigonometricSymbols = new GroupSymbol(VectorTrigonometricFunctionsName, new List<ISymbol> { vectorsin, vectorcos, vectortan });
      var vectorexponentialAndLogarithmicSymbols = new GroupSymbol(VectorExponentialFunctionsName, new List<ISymbol> { vectorexp, vectorlog });
      var vectorpowerSymbols = new GroupSymbol(VectorPowerFunctionsName, new List<ISymbol> { vectorsquare, vectorsqrt, vectorcube, vectorcubeRoot, vectorpower, vectorroot });
      var vectorterminalSymbols = new GroupSymbol(VectorTerminalsName, new List<ISymbol> { vectorvariable });
      var vectorSymbols = new GroupSymbol(VectorSymbolsName, new List<ISymbol> {
        vectorarithmeticSymbols, vectortrigonometricSymbols, vectorexponentialAndLogarithmicSymbols, vectorpowerSymbols, vectorterminalSymbols
      });
      #endregion

      AddSymbol(scalarSymbols);
      AddSymbol(vectorSymbols);

      #region subtree count configuration
      SetSubtreeCount(arithmeticSymbols, 2, 2);
      SetSubtreeCount(trigonometricSymbols, 1, 1);
      SetSubtreeCount(exponentialAndLogarithmicSymbols, 1, 1);
      SetSubtreeCount(square, 1, 1);
      SetSubtreeCount(sqrt, 1, 1);
      SetSubtreeCount(cube, 1, 1);
      SetSubtreeCount(cubeRoot, 1, 1);
      SetSubtreeCount(power, 2, 2);
      SetSubtreeCount(root, 2, 2);
      SetSubtreeCount(terminalSymbols, 0, 0);
      SetSubtreeCount(aggregationSymbols, 1, 1);

      SetSubtreeCount(vectorarithmeticSymbols, 2, 2);
      SetSubtreeCount(vectortrigonometricSymbols, 1, 1);
      SetSubtreeCount(vectorexponentialAndLogarithmicSymbols, 1, 1);
      SetSubtreeCount(vectorsquare, 1, 1);
      SetSubtreeCount(vectorsqrt, 1, 1);
      SetSubtreeCount(vectorcube, 1, 1);
      SetSubtreeCount(vectorcubeRoot, 1, 1);
      SetSubtreeCount(vectorpower, 2, 2);
      SetSubtreeCount(vectorroot, 2, 2);
      SetSubtreeCount(vectorterminalSymbols, 0, 0);
      #endregion

      #region allowed child symbols configuration
      // Scalar part: the root and every scalar function take scalar arguments only.
      AddAllowedChildSymbol(StartSymbol, scalarSymbols);
      AddAllowedChildSymbol(arithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(trigonometricSymbols, scalarSymbols);
      AddAllowedChildSymbol(exponentialAndLogarithmicSymbols, scalarSymbols);
      AddAllowedChildSymbol(powerSymbols, scalarSymbols, 0);
      // The second argument of power/root is restricted to a constant.
      AddAllowedChildSymbol(power, constant, 1);
      AddAllowedChildSymbol(root, constant, 1);
      // Aggregations are the only place where a vector subtree is allowed below a scalar symbol.
      AddAllowedChildSymbol(aggregationSymbols, vectorSymbols);

      // Vector part.
      AddAllowedChildSymbol(vectorarithmeticSymbols, vectorSymbols);
      // Vector arithmetic also accepts scalar operands (mixed vector/scalar arithmetic).
      // NOTE(review): this also admits two scalar operands; confirm the interpreter handles
      // that case where a vector result is expected (e.g. below an aggregation).
      AddAllowedChildSymbol(vectorarithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(vectortrigonometricSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorexponentialAndLogarithmicSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorpowerSymbols, vectorSymbols, 0);
      AddAllowedChildSymbol(vectorpower, constant, 1);
      AddAllowedChildSymbol(vectorroot, constant, 1);
      #endregion

      #region default enabled/disabled
      var disabledByDefault = new[] {
        TrigonometricFunctionsName, ExponentialFunctionsName, PowerFunctionsName,
        VectorTrigonometricFunctionsName, VectorExponentialFunctionsName, VectorPowerFunctionsName
      };
      foreach (var grp in Symbols.Where(sym => disabledByDefault.Contains(sym.Name)))
        grp.Enabled = false;
      #endregion
    }

    /// <summary>
    /// Runs the base configuration and then restricts the variable symbols by value type.
    /// </summary>
    /// <remarks>
    /// The scalar "Variable" symbol may only reference <c>double</c> columns. The
    /// "Vector Variable" symbol may only reference <c>DoubleVector</c> columns.
    /// Symbols marked as <c>Fixed</c> are left untouched.
    /// </remarks>
    /// <param name="problemData">Problem data whose dataset and input variables are consulted.</param>
    public override void ConfigureVariableSymbols(IDataAnalysisProblemData problemData) {
      base.ConfigureVariableSymbols(problemData);

      RestrictVariableNamesToType<double>(problemData, VariableSymbolName);
      RestrictVariableNamesToType<DoubleVector>(problemData, VectorVariableSymbolName);
    }

    /// <summary>
    /// For every non-fixed <see cref="Variable"/> symbol named <paramref name="symbolName"/>,
    /// keeps only the input variables whose dataset column has type <typeparamref name="T"/>.
    /// </summary>
    private void RestrictVariableNamesToType<T>(IDataAnalysisProblemData problemData, string symbolName) {
      var dataset = problemData.Dataset;
      // Scalar and vector variables share the Variable class, so they are selected by name.
      foreach (var varSymbol in Symbols.OfType<Variable>().Where(sym => sym.Name == symbolName)) {
        if (varSymbol.Fixed) continue;
        varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<T>(x));
        varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<T>(x));
      }
    }
  }
}