Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/11/17 14:13:14 (7 years ago)
Author:
gkronber
Message:

#2697 applied changes from r14378 again

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r14826 r14840  
    2323using System.Collections.Generic;
    2424using System.Linq;
    25 using AutoDiff;
    2625using HeuristicLab.Common;
    2726using HeuristicLab.Core;
     
    153152    }
    154153
    155     #region derivations of functions
    156     // create function factory for arctangent
    157     private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
    158       eval: Math.Atan,
    159       diff: x => 1 / (1 + x * x));
    160     private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    161       eval: Math.Sin,
    162       diff: Math.Cos);
    163     private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    164        eval: Math.Cos,
    165        diff: x => -Math.Sin(x));
    166     private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    167       eval: Math.Tan,
    168       diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    169     private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    170       eval: alglib.errorfunction,
    171       diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    172     private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    173       eval: alglib.normaldistribution,
    174       diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
    175     #endregion
    176 
    177 
    178 
    179154    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    180155      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
     
    188163      // variable name, variable value (for factor vars) and lag as a DataForVariable object.
    189164      // A dictionary is used to find parameters
    190       var variables = new List<AutoDiff.Variable>();
    191       var parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
    192 
    193       AutoDiff.Term func;
    194       if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, updateVariableWeights, out func))
     165      double[] initialConstants;
     166      var parameters = new List<TreeToAutoDiffTermTransformator.DataForVariable>();
     167
     168      TreeToAutoDiffTermTransformator.ParametricFunction func;
     169      TreeToAutoDiffTermTransformator.ParametricFunctionGradient func_grad;
     170      if (!TreeToAutoDiffTermTransformator.TryTransformToAutoDiff(tree, updateVariableWeights, out parameters, out initialConstants, out func, out func_grad))
    195171        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    196172      if (parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
    197173
    198174      var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
    199       AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
    200 
    201       List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants
    202       if (updateVariableWeights)
    203         terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
    204       else
    205         terminalNodes = new List<SymbolicExpressionTreeTerminalNode>
    206           (tree.Root.IterateNodesPrefix()
    207           .OfType<SymbolicExpressionTreeTerminalNode>()
    208           .Where(node => node is ConstantTreeNode || node is FactorVariableTreeNode));
    209175
    210176      //extract inital constants
    211       double[] c = new double[variables.Count];
     177      double[] c = new double[initialConstants.Length];
    212178      {
    213179        c[0] = 0.0;
    214180        c[1] = 1.0;
    215         int i = 2;
    216         foreach (var node in terminalNodes) {
    217           ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    218           VariableTreeNode variableTreeNode = node as VariableTreeNode;
    219           BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
    220           FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    221           if (constantTreeNode != null)
    222             c[i++] = constantTreeNode.Value;
    223           else if (updateVariableWeights && variableTreeNode != null)
    224             c[i++] = variableTreeNode.Weight;
    225           else if (updateVariableWeights && binFactorVarTreeNode != null)
    226             c[i++] = binFactorVarTreeNode.Weight;
    227           else if (factorVarTreeNode != null) {
    228             // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
    229             foreach (var w in factorVarTreeNode.Weights) c[i++] = w;
    230           }
    231         }
     181        Array.Copy(initialConstants, 0, c, 2, initialConstants.Length);
    232182      }
    233183      double[] originalConstants = (double[])c.Clone();
     
    243193      foreach (var r in rows) {
    244194        int col = 0;
    245         foreach (var kvp in parameterEntries) {
    246           var info = kvp.Key;
     195        foreach (var info in parameterEntries) {
    247196          int lag = info.lag;
    248197          if (ds.VariableHasType<double>(info.variableName)) {
     
    260209      int k = c.Length;
    261210
    262       alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
    263       alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
     211      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(func);
     212      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(func_grad);
    264213
    265214      try {
     
    307256    }
    308257
    309     private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
    310       return (double[] c, double[] x, ref double func, object o) => {
    311         func = compiledFunc.Evaluate(c, x);
     258    private static alglib.ndimensional_pfunc CreatePFunc(TreeToAutoDiffTermTransformator.ParametricFunction func) {
     259      return (double[] c, double[] x, ref double fx, object o) => {
     260        fx = func(c, x);
    312261      };
    313262    }
    314263
    315     private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
    316       return (double[] c, double[] x, ref double func, double[] grad, object o) => {
    317         var tupel = compiledFunc.Differentiate(c, x);
    318         func = tupel.Item2;
     264    private static alglib.ndimensional_pgrad CreatePGrad(TreeToAutoDiffTermTransformator.ParametricFunctionGradient func_grad) {
     265      return (double[] c, double[] x, ref double fx, double[] grad, object o) => {
     266        var tupel = func_grad(c, x);
     267        fx = tupel.Item2;
    319268        Array.Copy(tupel.Item1, grad, grad.Length);
    320269      };
    321270    }
    322 
    323     private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node,
    324       List<AutoDiff.Variable> variables, Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    325       bool updateVariableWeights, out AutoDiff.Term term) {
    326       if (node.Symbol is Constant) {
    327         var var = new AutoDiff.Variable();
    328         variables.Add(var);
    329         term = var;
    330         return true;
    331       }
    332       if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
    333         var varNode = node as VariableTreeNodeBase;
    334         var factorVarNode = node as BinaryFactorVariableTreeNode;
    335         // factor variable values are only 0 or 1 and set in x accordingly
    336         var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
    337         var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
    338 
    339         if (updateVariableWeights) {
    340           var w = new AutoDiff.Variable();
    341           variables.Add(w);
    342           term = AutoDiff.TermBuilder.Product(w, par);
    343         } else {
    344           term = varNode.Weight * par;
    345         }
    346         return true;
    347       }
    348       if (node.Symbol is FactorVariable) {
    349         var factorVarNode = node as FactorVariableTreeNode;
    350         var products = new List<Term>();
    351         foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
    352           var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
    353 
    354           var wVar = new AutoDiff.Variable();
    355           variables.Add(wVar);
    356 
    357           products.Add(AutoDiff.TermBuilder.Product(wVar, par));
    358         }
    359         term = AutoDiff.TermBuilder.Sum(products);
    360         return true;
    361       }
    362       if (node.Symbol is LaggedVariable) {
    363         var varNode = node as LaggedVariableTreeNode;
    364         var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
    365 
    366         if (updateVariableWeights) {
    367           var w = new AutoDiff.Variable();
    368           variables.Add(w);
    369           term = AutoDiff.TermBuilder.Product(w, par);
    370         } else {
    371           term = varNode.Weight * par;
    372         }
    373         return true;
    374       }
    375       if (node.Symbol is Addition) {
    376         List<AutoDiff.Term> terms = new List<Term>();
    377         foreach (var subTree in node.Subtrees) {
    378           AutoDiff.Term t;
    379           if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    380             term = null;
    381             return false;
    382           }
    383           terms.Add(t);
    384         }
    385         term = AutoDiff.TermBuilder.Sum(terms);
    386         return true;
    387       }
    388       if (node.Symbol is Subtraction) {
    389         List<AutoDiff.Term> terms = new List<Term>();
    390         for (int i = 0; i < node.SubtreeCount; i++) {
    391           AutoDiff.Term t;
    392           if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, updateVariableWeights, out t)) {
    393             term = null;
    394             return false;
    395           }
    396           if (i > 0) t = -t;
    397           terms.Add(t);
    398         }
    399         if (terms.Count == 1) term = -terms[0];
    400         else term = AutoDiff.TermBuilder.Sum(terms);
    401         return true;
    402       }
    403       if (node.Symbol is Multiplication) {
    404         List<AutoDiff.Term> terms = new List<Term>();
    405         foreach (var subTree in node.Subtrees) {
    406           AutoDiff.Term t;
    407           if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    408             term = null;
    409             return false;
    410           }
    411           terms.Add(t);
    412         }
    413         if (terms.Count == 1) term = terms[0];
    414         else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
    415         return true;
    416 
    417       }
    418       if (node.Symbol is Division) {
    419         List<AutoDiff.Term> terms = new List<Term>();
    420         foreach (var subTree in node.Subtrees) {
    421           AutoDiff.Term t;
    422           if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    423             term = null;
    424             return false;
    425           }
    426           terms.Add(t);
    427         }
    428         if (terms.Count == 1) term = 1.0 / terms[0];
    429         else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
    430         return true;
    431       }
    432       if (node.Symbol is Logarithm) {
    433         AutoDiff.Term t;
    434         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    435           term = null;
    436           return false;
    437         } else {
    438           term = AutoDiff.TermBuilder.Log(t);
    439           return true;
    440         }
    441       }
    442       if (node.Symbol is Exponential) {
    443         AutoDiff.Term t;
    444         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    445           term = null;
    446           return false;
    447         } else {
    448           term = AutoDiff.TermBuilder.Exp(t);
    449           return true;
    450         }
    451       }
    452       if (node.Symbol is Square) {
    453         AutoDiff.Term t;
    454         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    455           term = null;
    456           return false;
    457         } else {
    458           term = AutoDiff.TermBuilder.Power(t, 2.0);
    459           return true;
    460         }
    461       }
    462       if (node.Symbol is SquareRoot) {
    463         AutoDiff.Term t;
    464         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    465           term = null;
    466           return false;
    467         } else {
    468           term = AutoDiff.TermBuilder.Power(t, 0.5);
    469           return true;
    470         }
    471       }
    472       if (node.Symbol is Sine) {
    473         AutoDiff.Term t;
    474         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    475           term = null;
    476           return false;
    477         } else {
    478           term = sin(t);
    479           return true;
    480         }
    481       }
    482       if (node.Symbol is Cosine) {
    483         AutoDiff.Term t;
    484         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    485           term = null;
    486           return false;
    487         } else {
    488           term = cos(t);
    489           return true;
    490         }
    491       }
    492       if (node.Symbol is Tangent) {
    493         AutoDiff.Term t;
    494         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    495           term = null;
    496           return false;
    497         } else {
    498           term = tan(t);
    499           return true;
    500         }
    501       }
    502       if (node.Symbol is Erf) {
    503         AutoDiff.Term t;
    504         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    505           term = null;
    506           return false;
    507         } else {
    508           term = erf(t);
    509           return true;
    510         }
    511       }
    512       if (node.Symbol is Norm) {
    513         AutoDiff.Term t;
    514         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    515           term = null;
    516           return false;
    517         } else {
    518           term = norm(t);
    519           return true;
    520         }
    521       }
    522       if (node.Symbol is StartSymbol) {
    523         var alpha = new AutoDiff.Variable();
    524         var beta = new AutoDiff.Variable();
    525         variables.Add(beta);
    526         variables.Add(alpha);
    527         AutoDiff.Term branchTerm;
    528         if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out branchTerm)) {
    529           term = branchTerm * alpha + beta;
    530           return true;
    531         } else {
    532           term = null;
    533           return false;
    534         }
    535       }
    536       term = null;
    537       return false;
    538     }
    539 
    540     // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
    541     // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
    542     private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    543       string varName, string varValue = "", int lag = 0) {
    544       var data = new DataForVariable(varName, varValue, lag);
    545 
    546       AutoDiff.Variable par = null;
    547       if (!parameters.TryGetValue(data, out par)) {
    548         // not found -> create new parameter and entries in names and values lists
    549         par = new AutoDiff.Variable();
    550         parameters.Add(data, par);
    551       }
    552       return par;
    553     }
    554 
    555271    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
    556       var containsUnknownSymbol = (
    557         from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
    558         where
    559          !(n.Symbol is Variable) &&
    560          !(n.Symbol is BinaryFactorVariable) &&
    561          !(n.Symbol is FactorVariable) &&
    562          !(n.Symbol is LaggedVariable) &&
    563          !(n.Symbol is Constant) &&
    564          !(n.Symbol is Addition) &&
    565          !(n.Symbol is Subtraction) &&
    566          !(n.Symbol is Multiplication) &&
    567          !(n.Symbol is Division) &&
    568          !(n.Symbol is Logarithm) &&
    569          !(n.Symbol is Exponential) &&
    570          !(n.Symbol is SquareRoot) &&
    571          !(n.Symbol is Square) &&
    572          !(n.Symbol is Sine) &&
    573          !(n.Symbol is Cosine) &&
    574          !(n.Symbol is Tangent) &&
    575          !(n.Symbol is Erf) &&
    576          !(n.Symbol is Norm) &&
    577          !(n.Symbol is StartSymbol)
    578         select n).
    579       Any();
    580       return !containsUnknownSymbol;
    581     }
    582 
    583 
    584     #region helper class
    585     private class DataForVariable {
    586       public readonly string variableName;
    587       public readonly string variableValue; // for factor vars
    588       public readonly int lag;
    589 
    590       public DataForVariable(string varName, string varValue, int lag) {
    591         this.variableName = varName;
    592         this.variableValue = varValue;
    593         this.lag = lag;
    594       }
    595 
    596       public override bool Equals(object obj) {
    597         var other = obj as DataForVariable;
    598         if (other == null) return false;
    599         return other.variableName.Equals(this.variableName) &&
    600                other.variableValue.Equals(this.variableValue) &&
    601                other.lag == this.lag;
    602       }
    603 
    604       public override int GetHashCode() {
    605         return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
    606       }
    607     }
    608     #endregion
     272      return TreeToAutoDiffTermTransformator.IsCompatible(tree);
     273    }
    609274  }
    610275}
Note: See TracChangeset for help on using the changeset viewer.