
Timestamp: 11/09/16 10:49:59 (7 years ago)
Author: gkronber
Message: #2697:

  • created a folder for all classes related to the transformation to and from trees
  • created a transformator that takes a tree and uses AutoDiff to produce a function and a gradient function for the tree (a usage sketch follows this list)
  • moved code from SymbolicRegressionConstantOptimizationEvaluator to TreeToAutoDiffTermTransformator to make AutoDiff for trees more accessible
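
The intended call pattern for the new class can be read off the diff below. The following is only a rough sketch, assuming the signatures visible in this changeset; the wrapper class, the namespace import for the transformator, and all local variable names are placeholders and not part of the commit.

using System;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis.Symbolic; // assumed namespace of the new transformator class

// Sketch only: illustrates how a tree is transformed once and how the resulting
// function and gradient delegates are evaluated; not part of the changeset.
public static class TransformatorUsageSketch {
  public static double EvaluateOnce(ISymbolicExpressionTree tree) {
    string[] variableNames;
    int[] lags;
    double[] constants;
    TreeToAutoDiffTermTransformator.ParametricFunction func;
    TreeToAutoDiffTermTransformator.ParametricFunctionGradient func_grad;

    // Trees with unsupported symbols are rejected by the transformator.
    if (!TreeToAutoDiffTermTransformator.IsCompatible(tree) ||
        !TreeToAutoDiffTermTransformator.TryTransformToAutoDiff(tree, true /* updateVariableWeights */,
           out variableNames, out lags, out constants, out func, out func_grad))
      return double.NaN;

    // Parameter vector layout used by OptimizeConstants in the diff below:
    // [linear scaling offset, linear scaling factor, tree constants...]
    double[] c = new double[constants.Length + 2];
    c[0] = 0.0;
    c[1] = 1.0;
    Array.Copy(constants, 0, c, 2, constants.Length);

    // One row of input values, ordered like variableNames (all zeros here, purely for illustration).
    double[] x = new double[variableNames.Length];

    double fx = func(c, x);              // function value at (c, x)
    var gradAndValue = func_grad(c, x);  // Item1: gradient w.r.t. c, Item2: function value
    double[] grad = gradAndValue.Item1;
    return fx;
  }
}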
File: 1 edited

Legend (markers used in the diff below):
  ' ' Unmodified
  '+' Added
  '-' Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

--- r14358
+++ r14378
@@ -23,5 +23,4 @@
 using System.Collections.Generic;
 using System.Linq;
-using AutoDiff;
 using HeuristicLab.Common;
 using HeuristicLab.Core;
     
@@ -153,62 +152,21 @@
     }
 
-    #region derivations of functions
-    // create function factory for arctangent
-    private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
-      eval: Math.Atan,
-      diff: x => 1 / (1 + x * x));
-    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
-      eval: Math.Sin,
-      diff: Math.Cos);
-    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
-      eval: Math.Cos,
-      diff: x => -Math.Sin(x));
-    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
-      eval: Math.Tan,
-      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
-    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
-      eval: alglib.errorfunction,
-      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
-    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
-      eval: alglib.normaldistribution,
-      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
-    #endregion
-
-
     public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
 
-      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
-      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
-      List<string> variableNames = new List<string>();
-      List<int> lags = new List<int>();
-
-      AutoDiff.Term func;
-      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out func))
+      string[] variableNames;
+      int[] lags;
+      double[] constants;
+
+      TreeToAutoDiffTermTransformator.ParametricFunction func;
+      TreeToAutoDiffTermTransformator.ParametricFunctionGradient func_grad;
+      if (!TreeToAutoDiffTermTransformator.TryTransformToAutoDiff(tree, updateVariableWeights, out variableNames, out lags, out constants, out func, out func_grad))
         throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
-      if (variableNames.Count == 0) return 0.0;
-
-      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());
-
-      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null;
-      if (updateVariableWeights)
-        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
-      else
-        terminalNodes = new List<SymbolicExpressionTreeTerminalNode>(tree.Root.IterateNodesPrefix().OfType<ConstantTreeNode>());
+      if (variableNames.Length == 0) return 0.0;
 
       //extract inital constants
-      double[] c = new double[variables.Count];
-      {
-        c[0] = 0.0;
-        c[1] = 1.0;
-        int i = 2;
-        foreach (var node in terminalNodes) {
-          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
-          VariableTreeNode variableTreeNode = node as VariableTreeNode;
-          if (constantTreeNode != null)
-            c[i++] = constantTreeNode.Value;
-          else if (updateVariableWeights && variableTreeNode != null)
-            c[i++] = variableTreeNode.Weight;
-        }
-      }
+      double[] c = new double[constants.Length + 2];
+      c[0] = 0.0;
+      c[1] = 1.0;
+      Array.Copy(constants, 0, c, 2, constants.Length);
       double[] originalConstants = (double[])c.Clone();
       double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
     
@@ -218,9 +176,10 @@
       int info;
 
+      // TODO: refactor
       IDataset ds = problemData.Dataset;
-      double[,] x = new double[rows.Count(), variableNames.Count];
+      double[,] x = new double[rows.Count(), variableNames.Length];
       int row = 0;
       foreach (var r in rows) {
-        for (int col = 0; col < variableNames.Count; col++) {
+        for (int col = 0; col < variableNames.Length; col++) {
           int lag = lags[col];
           x[row, col] = ds.GetDoubleValue(variableNames[col], r + lag);
     
@@ -233,6 +192,6 @@
       int k = c.Length;
 
-      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
-      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
+      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(func);
+      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(func_grad);
 
       try {
     
@@ -272,246 +231,20 @@
     }
 
-    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
-      return (double[] c, double[] x, ref double func, object o) => {
-        func = compiledFunc.Evaluate(c, x);
+    private static alglib.ndimensional_pfunc CreatePFunc(TreeToAutoDiffTermTransformator.ParametricFunction func) {
+      return (double[] c, double[] x, ref double fx, object o) => {
+        fx = func(c, x);
       };
     }
 
-    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
-      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
-        var tupel = compiledFunc.Differentiate(c, x);
-        func = tupel.Item2;
+    private static alglib.ndimensional_pgrad CreatePGrad(TreeToAutoDiffTermTransformator.ParametricFunctionGradient func_grad) {
+      return (double[] c, double[] x, ref double fx, double[] grad, object o) => {
+        var tupel = func_grad(c, x);
+        fx = tupel.Item2;
         Array.Copy(tupel.Item1, grad, grad.Length);
       };
     }
 
-    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, List<int> lags, bool updateVariableWeights, out AutoDiff.Term term) {
-      if (node.Symbol is Constant) {
-        var var = new AutoDiff.Variable();
-        variables.Add(var);
-        term = var;
-        return true;
-      }
-      if (node.Symbol is Variable) {
-        var varNode = node as VariableTreeNode;
-        var par = new AutoDiff.Variable();
-        parameters.Add(par);
-        variableNames.Add(varNode.VariableName);
-        lags.Add(0);
-
-        if (updateVariableWeights) {
-          var w = new AutoDiff.Variable();
-          variables.Add(w);
-          term = AutoDiff.TermBuilder.Product(w, par);
-        } else {
-          term = varNode.Weight * par;
-        }
-        return true;
-      }
-      if (node.Symbol is LaggedVariable) {
-        var varNode = node as LaggedVariableTreeNode;
-        var par = new AutoDiff.Variable();
-        parameters.Add(par);
-        variableNames.Add(varNode.VariableName);
-        lags.Add(varNode.Lag);
-
-        if (updateVariableWeights) {
-          var w = new AutoDiff.Variable();
-          variables.Add(w);
-          term = AutoDiff.TermBuilder.Product(w, par);
-        } else {
-          term = varNode.Weight * par;
-        }
-        return true;
-      }
-      if (node.Symbol is Addition) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        foreach (var subTree in node.Subtrees) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-            term = null;
-            return false;
-          }
-          terms.Add(t);
-        }
-        term = AutoDiff.TermBuilder.Sum(terms);
-        return true;
-      }
-      if (node.Symbol is Subtraction) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        for (int i = 0; i < node.SubtreeCount; i++) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-            term = null;
-            return false;
-          }
-          if (i > 0) t = -t;
-          terms.Add(t);
-        }
-        if (terms.Count == 1) term = -terms[0];
-        else term = AutoDiff.TermBuilder.Sum(terms);
-        return true;
-      }
-      if (node.Symbol is Multiplication) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        foreach (var subTree in node.Subtrees) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-            term = null;
-            return false;
-          }
-          terms.Add(t);
-        }
-        if (terms.Count == 1) term = terms[0];
-        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
-        return true;
-
-      }
-      if (node.Symbol is Division) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        foreach (var subTree in node.Subtrees) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-            term = null;
-            return false;
-          }
-          terms.Add(t);
-        }
-        if (terms.Count == 1) term = 1.0 / terms[0];
-        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
-        return true;
-      }
-      if (node.Symbol is Logarithm) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Log(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Exponential) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Exp(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Square) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Power(t, 2.0);
-          return true;
-        }
-      }
-      if (node.Symbol is SquareRoot) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Power(t, 0.5);
-          return true;
-        }
-      }
-      if (node.Symbol is Sine) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = sin(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Cosine) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = cos(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Tangent) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = tan(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Erf) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = erf(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Norm) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = norm(t);
-          return true;
-        }
-      }
-      if (node.Symbol is StartSymbol) {
-        var alpha = new AutoDiff.Variable();
-        var beta = new AutoDiff.Variable();
-        variables.Add(beta);
-        variables.Add(alpha);
-        AutoDiff.Term branchTerm;
-        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out branchTerm)) {
-          term = branchTerm * alpha + beta;
-          return true;
-        } else {
-          term = null;
-          return false;
-        }
-      }
-      term = null;
-      return false;
-    }
-
     public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
-      var containsUnknownSymbol = (
-        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
-        where
-         !(n.Symbol is Variable) &&
-         !(n.Symbol is LaggedVariable) &&
-         !(n.Symbol is Constant) &&
-         !(n.Symbol is Addition) &&
-         !(n.Symbol is Subtraction) &&
-         !(n.Symbol is Multiplication) &&
-         !(n.Symbol is Division) &&
-         !(n.Symbol is Logarithm) &&
-         !(n.Symbol is Exponential) &&
-         !(n.Symbol is SquareRoot) &&
-         !(n.Symbol is Square) &&
-         !(n.Symbol is Sine) &&
-         !(n.Symbol is Cosine) &&
-         !(n.Symbol is Tangent) &&
-         !(n.Symbol is Erf) &&
-         !(n.Symbol is Norm) &&
-         !(n.Symbol is StartSymbol)
-        select n).
-      Any();
-      return !containsUnknownSymbol;
+      return TreeToAutoDiffTermTransformator.IsCompatible(tree);
     }
   }
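
For reference, the ParametricFunction and ParametricFunctionGradient delegate types used above are declared in TreeToAutoDiffTermTransformator and are not part of this diff. Inferred only from how func(c, x) and func_grad(c, x) are called here (a plain double value, and a tuple of gradient and value, respectively), they presumably have roughly this shape:

using System;

// Hypothetical declarations, reconstructed from usage in this changeset only;
// the authoritative definitions live in TreeToAutoDiffTermTransformator.
public delegate double ParametricFunction(double[] c, double[] x);
public delegate Tuple<double[], double> ParametricFunctionGradient(double[] c, double[] x);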