#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HEAL.Attic;
using MathNet.Numerics.LinearAlgebra;
using MathNet.Numerics.Statistics;

using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interprets symbolic expression trees whose nodes may evaluate either to a scalar
  /// (<see cref="double"/>) or to a vector (<see cref="DoubleVector"/>).
  /// <see cref="GetSymbolicExpressionTreeValues"/> reports only scalar tree results;
  /// every non-scalar result is reported as <see cref="double.NaN"/>.
  /// </summary>
  [StorableType("DE68A1D9-5AFC-4DDD-AB62-29F3B8FC28E0")]
  [Item("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.")]
  public class SymbolicDataAnalysisExpressionTreeVectorInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    public override bool CanChangeName {
      get { return false; }
    }

    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Counter incremented each time an enumeration of <see cref="GetSymbolicExpressionTreeValues"/>
    /// starts; reset to 0 by <see cref="InitializeState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(StorableConstructorFlag _) : base(_) { }

    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(SymbolicDataAnalysisExpressionTreeVectorInterpreter original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeVectorInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeVectorInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.") {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    #region IStatefulItem
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    public void ClearState() { }
    #endregion

    private readonly object syncRoot = new object();

    /// <summary>
    /// Evaluates <paramref name="tree"/> on each row in <paramref name="rows"/>.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">Dataset providing the variable values.</param>
    /// <param name="rows">Row indices to evaluate; rows outside the dataset yield NaN.</param>
    /// <returns>
    /// One value per row: the scalar result of the tree, or <see cref="double.NaN"/> when the
    /// tree evaluates to a vector or to NaN for that row.
    /// </returns>
    /// <remarks>
    /// This is an iterator; the counter increment and compilation happen lazily when
    /// enumeration starts.
    /// </remarks>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      lock (syncRoot) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }
      var state = PrepareInterpreterState(tree, dataset);

      foreach (var rowEnum in rows) {
        // Evaluate takes the row by ref and a foreach variable cannot be passed by ref,
        // so copy it into a mutable local.
        int row = rowEnum;
        var result = Evaluate(dataset, ref row, state);
        // Non-scalar results are reported as NaN rather than aborting the whole evaluation.
        yield return result.IsScalar ? result.Scalar : double.NaN;
        state.Reset();
      }
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> to linear code and attaches each variable instruction's
    /// column data from <paramref name="dataset"/>, so Evaluate can index it by row.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// A <c>Variable</c> node refers to a column that is neither double nor vector typed.
    /// </exception>
    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      int necessaryArgStackSize = 0;
      foreach (Instruction instr in code) {
        if (instr.opCode == OpCodes.Variable) {
          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
          if (dataset.VariableHasType<double>(variableTreeNode.VariableName))
            instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
          else if (dataset.VariableHasType<DoubleVector>(variableTreeNode.VariableName))
            instr.data = dataset.GetReadOnlyDoubleVectorValues(variableTreeNode.VariableName);
          else
            throw new NotSupportedException($"Type of variable {variableTreeNode.VariableName} is not supported.");
        } else if (instr.opCode == OpCodes.FactorVariable) {
          var factorTreeNode = (FactorVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
          var factorTreeNode = (BinaryFactorVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.VariableCondition) {
          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.Call) {
          necessaryArgStackSize += instr.nArguments + 1;
        }
      }
      return new InterpreterState(code, necessaryArgStackSize);
    }

    /// <summary>
    /// Result of evaluating a (sub)tree: a scalar, a vector, or NaN.
    /// A scalar <see cref="double.NaN"/> and a one-element vector holding NaN are both
    /// indistinguishable from <see cref="NaN"/>.
    /// </summary>
    public struct EvaluationResult {
      /// <summary>Scalar value; <see cref="double.NaN"/> when this is not a scalar result.</summary>
      public double Scalar { get; }
      public bool IsScalar => !double.IsNaN(Scalar);

      /// <summary>Vector value; a shared one-element NaN vector when this is not a vector result.</summary>
      public DoubleVector Vector { get; }
      // The null check covers default(EvaluationResult), whose Vector was never assigned.
      public bool IsVector => Vector != null && !(Vector.Count == 1 && double.IsNaN(Vector[0]));

      public bool IsNaN => !IsScalar && !IsVector;

      public EvaluationResult(double scalar) {
        Scalar = scalar;
        Vector = NaNVector;
      }

      public EvaluationResult(DoubleVector vector) {
        if (vector == null) throw new ArgumentNullException(nameof(vector));
        Vector = vector;
        Scalar = double.NaN;
      }

      public override string ToString() {
        if (IsScalar) return Scalar.ToString();
        if (IsVector) return Vector.ToVectorString();
        return "NaN";
      }

      // Sentinel "not a vector" value shared by every scalar result; it must never be mutated.
      // It must be declared before NaN: static initializers run in textual order, and NaN's
      // constructor reads NaNVector.
      private static readonly DoubleVector NaNVector = DoubleVector.Build.Dense(1, double.NaN);
      public static readonly EvaluationResult NaN = new EvaluationResult(double.NaN);
    }

    /// <summary>
    /// Applies a binary operation, choosing the delegate that matches the operand kinds
    /// (scalar/scalar, scalar/vector, vector/scalar, vector/vector).
    /// </summary>
    /// <returns>
    /// The operation's result, or <see cref="EvaluationResult.NaN"/> if an operand is NaN or no
    /// delegate was supplied for that combination.
    /// </returns>
    /// <remarks>
    /// NOTE(review): vector/vector operands of different lengths are passed straight to Math.NET,
    /// which presumably throws <see cref="ArgumentException"/>; confirm whether NaN would be preferable.
    /// </remarks>
    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
      Func<double, double, double> ssFunc = null,
      Func<double, DoubleVector, DoubleVector> svFunc = null,
      Func<DoubleVector, double, DoubleVector> vsFunc = null,
      Func<DoubleVector, DoubleVector, DoubleVector> vvFunc = null) {
      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
      if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Applies an element-wise unary function: <paramref name="sFunc"/> to scalars,
    /// <paramref name="vFunc"/> to vectors. Returns NaN for NaN input or a missing delegate.
    /// </summary>
    private static EvaluationResult FunctionApply(EvaluationResult val,
      Func<double, double> sFunc = null,
      Func<DoubleVector, DoubleVector> vFunc = null) {
      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Applies an aggregation that maps a scalar or a vector to a scalar.
    /// Returns NaN for NaN input or a missing delegate.
    /// </summary>
    private static EvaluationResult AggregateApply(EvaluationResult val,
      Func<double, double> sFunc = null,
      Func<DoubleVector, double> vFunc = null) {
      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Reads the next instruction from <paramref name="state"/> and recursively evaluates its
    /// arguments for the given row.
    /// </summary>
    /// <param name="dataset">Dataset the variable instructions read from.</param>
    /// <param name="row">Row index; no case in this class modifies it.</param>
    /// <param name="state">Interpreter state holding the compiled code and the instruction pointer.</param>
    /// <returns>The (scalar, vector or NaN) value of the subtree.</returns>
    /// <exception cref="NotSupportedException">An opcode or variable data type is not handled here.</exception>
    public virtual EvaluationResult Evaluate(IDataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            var cur = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 + s2,
                (s1, v2) => s1 + v2,
                (v1, s2) => v1 + s2,
                (v1, v2) => v1 + v2);
            }
            return cur;
          }
        case OpCodes.Sub: {
            var cur = Evaluate(dataset, ref row, state);
            if (currentInstr.nArguments == 1) {
              // Unary subtraction is negation (-x); the n-ary loop below would return x unchanged.
              return FunctionApply(cur, s => -s, v => -v);
            }
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 - s2,
                (s1, v2) => s1 - v2,
                (v1, s2) => v1 - s2,
                (v1, v2) => v1 - v2);
            }
            return cur;
          }
        case OpCodes.Mul: {
            var cur = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 * s2,
                (s1, v2) => s1 * v2,
                (v1, s2) => v1 * s2,
                // Element-wise product; the * operator on two vectors would be the dot product.
                (v1, v2) => v1.PointwiseMultiply(v2));
            }
            return cur;
          }
        case OpCodes.Div: {
            var cur = Evaluate(dataset, ref row, state);
            if (currentInstr.nArguments == 1) {
              // Unary division is the reciprocal (1/x); the n-ary loop below would return x unchanged.
              return FunctionApply(cur, s => 1.0 / s, v => 1.0 / v);
            }
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 / s2,
                (s1, v2) => s1 / v2,
                (v1, s2) => v1 / s2,
                (v1, v2) => v1 / v2);
            }
            return cur;
          }
        case OpCodes.Absolute: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Abs, DoubleVector.Abs);
          }
        case OpCodes.Tanh: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Tanh, DoubleVector.Tanh);
          }
        case OpCodes.Cos: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Cos, DoubleVector.Cos);
          }
        case OpCodes.Sin: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Sin, DoubleVector.Sin);
          }
        case OpCodes.Tan: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Tan, DoubleVector.Tan);
          }
        case OpCodes.Square: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Pow(s, 2),
              v => v.PointwisePower(2));
          }
        case OpCodes.Cube: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Pow(s, 3),
              v => v.PointwisePower(3));
          }
        case OpCodes.Power: {
            // The exponent is rounded to the nearest integer.
            var x = Evaluate(dataset, ref row, state);
            var y = Evaluate(dataset, ref row, state);
            return ArithmeticApply(x, y,
              (s1, s2) => Math.Pow(s1, Math.Round(s2)),
              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(DoubleVector.Round(v2)),
              (v1, s2) => v1.PointwisePower(Math.Round(s2)),
              (v1, v2) => v1.PointwisePower(DoubleVector.Round(v2)));
          }
        case OpCodes.SquareRoot: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Sqrt(s),
              v => DoubleVector.Sqrt(v));
          }
        case OpCodes.CubeRoot: {
            // Real cube root: negative inputs map to negative outputs instead of NaN.
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0),
              v => v.Map(s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0)));
          }
        case OpCodes.Root: {
            // The root degree is rounded to the nearest integer.
            var x = Evaluate(dataset, ref row, state);
            var y = Evaluate(dataset, ref row, state);
            return ArithmeticApply(x, y,
              (s1, s2) => Math.Pow(s1, 1.0 / Math.Round(s2)),
              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(1.0 / DoubleVector.Round(v2)),
              (v1, s2) => v1.PointwisePower(1.0 / Math.Round(s2)),
              (v1, v2) => v1.PointwisePower(1.0 / DoubleVector.Round(v2)));
          }
        case OpCodes.Exp: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Exp(s),
              v => DoubleVector.Exp(v));
          }
        case OpCodes.Log: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Log(s),
              v => DoubleVector.Log(v));
          }
        case OpCodes.Sum: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => s,
              v => v.Sum());
          }
        case OpCodes.Mean: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => s,
              v => v.Mean());
          }
        case OpCodes.StandardDeviation: {
            // A scalar or a single-element vector has no spread, so report 0.
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => 0,
              v => v.Count > 1 ? Statistics.StandardDeviation(v) : 0);
          }
        case OpCodes.Variable: {
            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            if (currentInstr.data is IList<double> doubleList)
              return new EvaluationResult(doubleList[row] * variableTreeNode.Weight);
            if (currentInstr.data is IList<DoubleVector> doubleVectorList)
              return new EvaluationResult(doubleVectorList[row] * variableTreeNode.Weight);
            throw new NotSupportedException($"Unsupported type of variable: {currentInstr.data.GetType().GetPrettyName()}");
          }
        case OpCodes.BinaryFactorVariable: {
            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
            var factorVarTreeNode = (BinaryFactorVariableTreeNode)currentInstr.dynamicNode;
            return new EvaluationResult(((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0);
          }
        case OpCodes.FactorVariable: {
            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
            var factorVarTreeNode = (FactorVariableTreeNode)currentInstr.dynamicNode;
            return new EvaluationResult(factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]));
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return new EvaluationResult(constTreeNode.Value);
          }

        default:
          throw new NotSupportedException($"Unsupported OpCode: {currentInstr.opCode}");
      }
    }
  }
}