Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/17/16 15:41:33 (7 years ago)
Author:
gkronber
Message:

#2697: reverse merge of r14378, r14390, r14391, r14393, r14394, r14396

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r14390 r14400  
    2323using System.Collections.Generic;
    2424using System.Linq;
     25using AutoDiff;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
     
    152153    }
    153154
     155    #region derivations of functions
     156    // create function factory for arctangent
     157    private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
     158      eval: Math.Atan,
     159      diff: x => 1 / (1 + x * x));
     160    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
     161      eval: Math.Sin,
     162      diff: Math.Cos);
     163    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
     164       eval: Math.Cos,
     165       diff: x => -Math.Sin(x));
     166    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
     167      eval: Math.Tan,
     168      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
     169    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
     170      eval: alglib.errorfunction,
     171      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
     172    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
     173      eval: alglib.normaldistribution,
     174      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
     175    #endregion
     176
     177
    154178    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
    155179
    156       string[] variableNames;
    157       int[] lags;
    158       double[] constants;
    159 
    160       TreeToAutoDiffTermConverter.ParametricFunction func;
    161       TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad;
    162       if (!TreeToAutoDiffTermConverter.TryTransformToAutoDiff(tree, updateVariableWeights, out variableNames, out lags, out constants, out func, out func_grad))
     180      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
     181      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
     182      List<string> variableNames = new List<string>();
     183      List<int> lags = new List<int>();
     184
     185      AutoDiff.Term func;
     186      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out func))
    163187        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    164       if (variableNames.Length == 0) return 0.0;
     188      if (variableNames.Count == 0) return 0.0;
     189
     190      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());
     191
     192      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null;
     193      if (updateVariableWeights)
     194        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
     195      else
     196        terminalNodes = new List<SymbolicExpressionTreeTerminalNode>(tree.Root.IterateNodesPrefix().OfType<ConstantTreeNode>());
    165197
    166198      //extract inital constants
    167       double[] c = new double[constants.Length + 2];
    168       c[0] = 0.0;
    169       c[1] = 1.0;
    170       Array.Copy(constants, 0, c, 2, constants.Length);
     199      double[] c = new double[variables.Count];
     200      {
     201        c[0] = 0.0;
     202        c[1] = 1.0;
     203        int i = 2;
     204        foreach (var node in terminalNodes) {
     205          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
     206          VariableTreeNode variableTreeNode = node as VariableTreeNode;
     207          if (constantTreeNode != null)
     208            c[i++] = constantTreeNode.Value;
     209          else if (updateVariableWeights && variableTreeNode != null)
     210            c[i++] = variableTreeNode.Weight;
     211        }
     212      }
    171213      double[] originalConstants = (double[])c.Clone();
    172214      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
     
    176218      int info;
    177219
    178       // TODO: refactor
    179220      IDataset ds = problemData.Dataset;
    180       double[,] x = new double[rows.Count(), variableNames.Length];
     221      double[,] x = new double[rows.Count(), variableNames.Count];
    181222      int row = 0;
    182223      foreach (var r in rows) {
    183         for (int col = 0; col < variableNames.Length; col++) {
     224        for (int col = 0; col < variableNames.Count; col++) {
    184225          int lag = lags[col];
    185226          x[row, col] = ds.GetDoubleValue(variableNames[col], r + lag);
     
    192233      int k = c.Length;
    193234
    194       alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(func);
    195       alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(func_grad);
     235      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
     236      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
    196237
    197238      try {
     
    231272    }
    232273
    233     private static alglib.ndimensional_pfunc CreatePFunc(TreeToAutoDiffTermConverter.ParametricFunction func) {
    234       return (double[] c, double[] x, ref double fx, object o) => {
    235         fx = func(c, x);
     274    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
     275      return (double[] c, double[] x, ref double func, object o) => {
     276        func = compiledFunc.Evaluate(c, x);
    236277      };
    237278    }
    238279
    239     private static alglib.ndimensional_pgrad CreatePGrad(TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad) {
    240       return (double[] c, double[] x, ref double fx, double[] grad, object o) => {
    241         var tupel = func_grad(c, x);
    242         fx = tupel.Item2;
     280    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
     281      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
     282        var tupel = compiledFunc.Differentiate(c, x);
     283        func = tupel.Item2;
    243284        Array.Copy(tupel.Item1, grad, grad.Length);
    244285      };
    245286    }
    246287
     288    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, List<int> lags, bool updateVariableWeights, out AutoDiff.Term term) {
     289      if (node.Symbol is Constant) {
     290        var var = new AutoDiff.Variable();
     291        variables.Add(var);
     292        term = var;
     293        return true;
     294      }
     295      if (node.Symbol is Variable) {
     296        var varNode = node as VariableTreeNode;
     297        var par = new AutoDiff.Variable();
     298        parameters.Add(par);
     299        variableNames.Add(varNode.VariableName);
     300        lags.Add(0);
     301
     302        if (updateVariableWeights) {
     303          var w = new AutoDiff.Variable();
     304          variables.Add(w);
     305          term = AutoDiff.TermBuilder.Product(w, par);
     306        } else {
     307          term = varNode.Weight * par;
     308        }
     309        return true;
     310      }
     311      if (node.Symbol is LaggedVariable) {
     312        var varNode = node as LaggedVariableTreeNode;
     313        var par = new AutoDiff.Variable();
     314        parameters.Add(par);
     315        variableNames.Add(varNode.VariableName);
     316        lags.Add(varNode.Lag);
     317
     318        if (updateVariableWeights) {
     319          var w = new AutoDiff.Variable();
     320          variables.Add(w);
     321          term = AutoDiff.TermBuilder.Product(w, par);
     322        } else {
     323          term = varNode.Weight * par;
     324        }
     325        return true;
     326      }
     327      if (node.Symbol is Addition) {
     328        List<AutoDiff.Term> terms = new List<Term>();
     329        foreach (var subTree in node.Subtrees) {
     330          AutoDiff.Term t;
     331          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     332            term = null;
     333            return false;
     334          }
     335          terms.Add(t);
     336        }
     337        term = AutoDiff.TermBuilder.Sum(terms);
     338        return true;
     339      }
     340      if (node.Symbol is Subtraction) {
     341        List<AutoDiff.Term> terms = new List<Term>();
     342        for (int i = 0; i < node.SubtreeCount; i++) {
     343          AutoDiff.Term t;
     344          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     345            term = null;
     346            return false;
     347          }
     348          if (i > 0) t = -t;
     349          terms.Add(t);
     350        }
     351        if (terms.Count == 1) term = -terms[0];
     352        else term = AutoDiff.TermBuilder.Sum(terms);
     353        return true;
     354      }
     355      if (node.Symbol is Multiplication) {
     356        List<AutoDiff.Term> terms = new List<Term>();
     357        foreach (var subTree in node.Subtrees) {
     358          AutoDiff.Term t;
     359          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     360            term = null;
     361            return false;
     362          }
     363          terms.Add(t);
     364        }
     365        if (terms.Count == 1) term = terms[0];
     366        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
     367        return true;
     368
     369      }
     370      if (node.Symbol is Division) {
     371        List<AutoDiff.Term> terms = new List<Term>();
     372        foreach (var subTree in node.Subtrees) {
     373          AutoDiff.Term t;
     374          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     375            term = null;
     376            return false;
     377          }
     378          terms.Add(t);
     379        }
     380        if (terms.Count == 1) term = 1.0 / terms[0];
     381        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
     382        return true;
     383      }
     384      if (node.Symbol is Logarithm) {
     385        AutoDiff.Term t;
     386        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     387          term = null;
     388          return false;
     389        } else {
     390          term = AutoDiff.TermBuilder.Log(t);
     391          return true;
     392        }
     393      }
     394      if (node.Symbol is Exponential) {
     395        AutoDiff.Term t;
     396        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     397          term = null;
     398          return false;
     399        } else {
     400          term = AutoDiff.TermBuilder.Exp(t);
     401          return true;
     402        }
     403      }
     404      if (node.Symbol is Square) {
     405        AutoDiff.Term t;
     406        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     407          term = null;
     408          return false;
     409        } else {
     410          term = AutoDiff.TermBuilder.Power(t, 2.0);
     411          return true;
     412        }
     413      }
     414      if (node.Symbol is SquareRoot) {
     415        AutoDiff.Term t;
     416        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     417          term = null;
     418          return false;
     419        } else {
     420          term = AutoDiff.TermBuilder.Power(t, 0.5);
     421          return true;
     422        }
     423      }
     424      if (node.Symbol is Sine) {
     425        AutoDiff.Term t;
     426        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     427          term = null;
     428          return false;
     429        } else {
     430          term = sin(t);
     431          return true;
     432        }
     433      }
     434      if (node.Symbol is Cosine) {
     435        AutoDiff.Term t;
     436        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     437          term = null;
     438          return false;
     439        } else {
     440          term = cos(t);
     441          return true;
     442        }
     443      }
     444      if (node.Symbol is Tangent) {
     445        AutoDiff.Term t;
     446        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     447          term = null;
     448          return false;
     449        } else {
     450          term = tan(t);
     451          return true;
     452        }
     453      }
     454      if (node.Symbol is Erf) {
     455        AutoDiff.Term t;
     456        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     457          term = null;
     458          return false;
     459        } else {
     460          term = erf(t);
     461          return true;
     462        }
     463      }
     464      if (node.Symbol is Norm) {
     465        AutoDiff.Term t;
     466        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     467          term = null;
     468          return false;
     469        } else {
     470          term = norm(t);
     471          return true;
     472        }
     473      }
     474      if (node.Symbol is StartSymbol) {
     475        var alpha = new AutoDiff.Variable();
     476        var beta = new AutoDiff.Variable();
     477        variables.Add(beta);
     478        variables.Add(alpha);
     479        AutoDiff.Term branchTerm;
     480        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out branchTerm)) {
     481          term = branchTerm * alpha + beta;
     482          return true;
     483        } else {
     484          term = null;
     485          return false;
     486        }
     487      }
     488      term = null;
     489      return false;
     490    }
     491
    247492    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
    248       return TreeToAutoDiffTermConverter.IsCompatible(tree);
     493      var containsUnknownSymbol = (
     494        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
     495        where
     496         !(n.Symbol is Variable) &&
     497         !(n.Symbol is LaggedVariable) &&
     498         !(n.Symbol is Constant) &&
     499         !(n.Symbol is Addition) &&
     500         !(n.Symbol is Subtraction) &&
     501         !(n.Symbol is Multiplication) &&
     502         !(n.Symbol is Division) &&
     503         !(n.Symbol is Logarithm) &&
     504         !(n.Symbol is Exponential) &&
     505         !(n.Symbol is SquareRoot) &&
     506         !(n.Symbol is Square) &&
     507         !(n.Symbol is Sine) &&
     508         !(n.Symbol is Cosine) &&
     509         !(n.Symbol is Tangent) &&
     510         !(n.Symbol is Erf) &&
     511         !(n.Symbol is Norm) &&
     512         !(n.Symbol is StartSymbol)
     513        select n).
     514      Any();
     515      return !containsUnknownSymbol;
    249516    }
    250517  }
Note: See TracChangeset for help on using the changeset viewer.