Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Grammars/TypeCoherentVectorExpressionGrammar.cs @ 17465

Last change on this file since 17465 was 17465, checked in by pfleck, 4 years ago

#3040 Simplified default vector grammar.

File size: 11.0 KB
RevLine 
[5333]1#region License Information
2/* HeuristicLab
[17180]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5333]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[5393]22using System.Collections.Generic;
[5333]23using System.Linq;
[17456]24using HEAL.Attic;
[5333]25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[17460]28using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
29
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Grammar for vector-based symbolic expressions that keeps scalar and vector sub-expressions apart.
  /// Scalar operators only accept scalar operands. Vector operators accept vector operands (vector
  /// arithmetic additionally accepts scalar operands). The aggregation symbols (sum, mean, standard
  /// deviation) are scalar symbols whose operands are vector expressions. The start symbol only
  /// accepts scalar expressions.
  /// </summary>
  [StorableType("7EC7B4A7-0E27-4011-B983-B0E15A6944EC")]
  [Item("TypeCoherentVectorExpressionGrammar", "Represents a grammar for functional expressions in which special syntactic constraints are enforced so that vector and scalar expressions are not mixed.")]
  public class TypeCoherentVectorExpressionGrammar : DataAnalysisGrammar, ISymbolicDataAnalysisGrammar {
    private const string ArithmeticFunctionsName = "Arithmetic Functions";
    private const string TrigonometricFunctionsName = "Trigonometric Functions";
    private const string ExponentialFunctionsName = "Exponential and Logarithmic Functions";
    private const string PowerFunctionsName = "Power Functions";
    private const string TerminalsName = "Terminals";
    private const string VectorAggregationName = "Aggregations";
    private const string ScalarSymbolsName = "Scalar Symbols";

    private const string VectorArithmeticFunctionsName = "Vector Arithmetic Functions";
    private const string VectorTrigonometricFunctionsName = "Vector Trigonometric Functions";
    private const string VectorExponentialFunctionsName = "Vector Exponential and Logarithmic Functions";
    private const string VectorPowerFunctionsName = "Vector Power Functions";
    private const string VectorTerminalsName = "Vector Terminals";
    private const string VectorSymbolsName = "Vector Symbols";

    // Symbol names used to tell the scalar and the vector variable symbol apart in
    // ConfigureVariableSymbols. The scalar variable keeps the default name of Variable,
    // which is expected to be "Variable".
    private const string VariableSymbolName = "Variable";
    private const string VectorVariableSymbolName = "Vector Variable";

    [StorableConstructor]
    protected TypeCoherentVectorExpressionGrammar(StorableConstructorFlag _) : base(_) { }
    protected TypeCoherentVectorExpressionGrammar(TypeCoherentVectorExpressionGrammar original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the grammar with its default symbol set and syntactic constraints.
    /// </summary>
    public TypeCoherentVectorExpressionGrammar()
      : base(ItemAttribute.GetName(typeof(TypeCoherentVectorExpressionGrammar)), ItemAttribute.GetDescription(typeof(TypeCoherentVectorExpressionGrammar))) {
      Initialize();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new TypeCoherentVectorExpressionGrammar(this, cloner);
    }

    /// <summary>
    /// Declares all scalar and vector symbols, groups them, and configures arities and allowed
    /// children. Trigonometric, exponential/logarithmic and power groups (scalar and vector)
    /// are disabled by default.
    /// </summary>
    private void Initialize() {
      #region scalar symbol declaration
      var add = new Addition();
      var sub = new Subtraction();
      var mul = new Multiplication();
      var div = new Division();

      var sin = new Sine();
      var cos = new Cosine();
      var tan = new Tangent();

      var exp = new Exponential();
      var log = new Logarithm();

      var square = new Square();
      var sqrt = new SquareRoot();
      var cube = new Cube();
      var cubeRoot = new CubeRoot();
      var power = new Power();
      var root = new Root();

      var constant = new Constant { MinValue = -20, MaxValue = 20 };
      var variable = new Variable();
      var binFactorVariable = new BinaryFactorVariable();
      var factorVariable = new FactorVariable();

      var sum = new Sum();
      var mean = new Average { Name = "Mean" };
      var sd = new StandardDeviation();
      #endregion

      #region vector symbol declaration
      // Vector operators reuse the scalar symbol types; they are distinguished by name only.
      var vectoradd = new Addition() { Name = "Vector Addition" };
      var vectorsub = new Subtraction() { Name = "Vector Subtraction" };
      var vectormul = new Multiplication() { Name = "Vector Multiplication" };
      var vectordiv = new Division() { Name = "Vector Division" };

      var vectorsin = new Sine() { Name = "Vector Sine" };
      var vectorcos = new Cosine() { Name = "Vector Cosine" };
      var vectortan = new Tangent() { Name = "Vector Tangent" };

      var vectorexp = new Exponential() { Name = "Vector Exponential" };
      var vectorlog = new Logarithm() { Name = "Vector Logarithm" };

      var vectorsquare = new Square() { Name = "Vector Square" };
      var vectorsqrt = new SquareRoot() { Name = "Vector SquareRoot" };
      var vectorcube = new Cube() { Name = "Vector Cube" };
      var vectorcubeRoot = new CubeRoot() { Name = "Vector CubeRoot" };
      var vectorpower = new Power() { Name = "Vector Power" };
      var vectorroot = new Root() { Name = "Vector Root" };

      var vectorvariable = new Variable() { Name = VectorVariableSymbolName };
      #endregion

      #region group symbol declaration
      var arithmeticSymbols = new GroupSymbol(ArithmeticFunctionsName, new List<ISymbol>() { add, sub, mul, div });
      var trigonometricSymbols = new GroupSymbol(TrigonometricFunctionsName, new List<ISymbol>() { sin, cos, tan });
      var exponentialAndLogarithmicSymbols = new GroupSymbol(ExponentialFunctionsName, new List<ISymbol> { exp, log });
      var powerSymbols = new GroupSymbol(PowerFunctionsName, new List<ISymbol> { square, sqrt, cube, cubeRoot, power, root });
      var terminalSymbols = new GroupSymbol(TerminalsName, new List<ISymbol> { constant, variable, binFactorVariable, factorVariable });
      var aggregationSymbols = new GroupSymbol(VectorAggregationName, new List<ISymbol> { sum, mean, sd });
      var scalarSymbols = new GroupSymbol(ScalarSymbolsName, new List<ISymbol>() { arithmeticSymbols, trigonometricSymbols, exponentialAndLogarithmicSymbols, powerSymbols, terminalSymbols, aggregationSymbols });

      var vectorarithmeticSymbols = new GroupSymbol(VectorArithmeticFunctionsName, new List<ISymbol>() { vectoradd, vectorsub, vectormul, vectordiv });
      var vectortrigonometricSymbols = new GroupSymbol(VectorTrigonometricFunctionsName, new List<ISymbol>() { vectorsin, vectorcos, vectortan });
      var vectorexponentialAndLogarithmicSymbols = new GroupSymbol(VectorExponentialFunctionsName, new List<ISymbol> { vectorexp, vectorlog });
      var vectorpowerSymbols = new GroupSymbol(VectorPowerFunctionsName, new List<ISymbol> { vectorsquare, vectorsqrt, vectorcube, vectorcubeRoot, vectorpower, vectorroot });
      var vectorterminalSymbols = new GroupSymbol(VectorTerminalsName, new List<ISymbol> { vectorvariable });
      var vectorSymbols = new GroupSymbol(VectorSymbolsName, new List<ISymbol>() { vectorarithmeticSymbols, vectortrigonometricSymbols, vectorexponentialAndLogarithmicSymbols, vectorpowerSymbols, vectorterminalSymbols });
      #endregion

      AddSymbol(scalarSymbols);
      AddSymbol(vectorSymbols);

      #region subtree count configuration
      SetSubtreeCount(arithmeticSymbols, 2, 2);
      SetSubtreeCount(trigonometricSymbols, 1, 1);
      SetSubtreeCount(exponentialAndLogarithmicSymbols, 1, 1);
      // The power group mixes unary and binary symbols, so arities are set per symbol.
      SetSubtreeCount(square, 1, 1);
      SetSubtreeCount(sqrt, 1, 1);
      SetSubtreeCount(cube, 1, 1);
      SetSubtreeCount(cubeRoot, 1, 1);
      SetSubtreeCount(power, 2, 2);
      SetSubtreeCount(root, 2, 2);
      SetSubtreeCount(terminalSymbols, 0, 0);
      SetSubtreeCount(aggregationSymbols, 1, 1);

      SetSubtreeCount(vectorarithmeticSymbols, 2, 2);
      SetSubtreeCount(vectortrigonometricSymbols, 1, 1);
      SetSubtreeCount(vectorexponentialAndLogarithmicSymbols, 1, 1);
      SetSubtreeCount(vectorsquare, 1, 1);
      SetSubtreeCount(vectorsqrt, 1, 1);
      SetSubtreeCount(vectorcube, 1, 1);
      SetSubtreeCount(vectorcubeRoot, 1, 1);
      SetSubtreeCount(vectorpower, 2, 2);
      SetSubtreeCount(vectorroot, 2, 2);
      SetSubtreeCount(vectorterminalSymbols, 0, 0);
      #endregion

      #region allowed child symbols configuration
      // The tree root must be a scalar expression.
      AddAllowedChildSymbol(StartSymbol, scalarSymbols);

      AddAllowedChildSymbol(arithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(trigonometricSymbols, scalarSymbols);
      AddAllowedChildSymbol(exponentialAndLogarithmicSymbols, scalarSymbols);
      // First argument of every power symbol is any scalar expression;
      // the second argument of power/root (the exponent) is restricted to a constant.
      AddAllowedChildSymbol(powerSymbols, scalarSymbols, 0);
      AddAllowedChildSymbol(power, constant, 1);
      AddAllowedChildSymbol(root, constant, 1);
      // Aggregations are the bridge from vector to scalar expressions.
      AddAllowedChildSymbol(aggregationSymbols, vectorSymbols);

      // Vector arithmetic may combine vectors with vectors or with scalars.
      AddAllowedChildSymbol(vectorarithmeticSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorarithmeticSymbols, scalarSymbols);
      AddAllowedChildSymbol(vectortrigonometricSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorexponentialAndLogarithmicSymbols, vectorSymbols);
      AddAllowedChildSymbol(vectorpowerSymbols, vectorSymbols, 0);
      AddAllowedChildSymbol(vectorpower, constant, 1);
      AddAllowedChildSymbol(vectorroot, constant, 1);
      #endregion

      #region default enabled/disabled
      var disabledByDefault = new[] {
        TrigonometricFunctionsName, ExponentialFunctionsName, PowerFunctionsName,
        VectorTrigonometricFunctionsName, VectorExponentialFunctionsName, VectorPowerFunctionsName
      };
      foreach (var grp in Symbols.Where(sym => disabledByDefault.Contains(sym.Name)))
        grp.Enabled = false;
      #endregion
    }

    /// <summary>
    /// Configures the variable symbols for the given problem data. After the base configuration,
    /// the non-fixed scalar variable symbol is restricted to input variables of type
    /// <c>double</c>, and the non-fixed vector variable symbol to input variables of type
    /// <c>DoubleVector</c>.
    /// </summary>
    /// <param name="problemData">Problem data providing the dataset and the (allowed) input variables.</param>
    public override void ConfigureVariableSymbols(IDataAnalysisProblemData problemData) {
      base.ConfigureVariableSymbols(problemData);

      var dataset = problemData.Dataset;
      var variableSymbols = Symbols.OfType<VariableBase>().Where(sym => !sym.Fixed);

      foreach (var varSymbol in variableSymbols.Where(sym => sym.Name == VariableSymbolName)) {
        varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<double>(x));
        varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<double>(x));
      }
      foreach (var varSymbol in variableSymbols.Where(sym => sym.Name == VectorVariableSymbolName)) {
        varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => dataset.VariableHasType<DoubleVector>(x));
        varSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => dataset.VariableHasType<DoubleVector>(x));
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.