Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/06/17 11:08:16 (8 years ago)
Author:
gkronber
Message:

#2697: merged r14840 from trunk to stable

File:
1 edited

Legend:

Unmodified
Added
Removed
  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r15136 r15141  
    2323using System.Collections.Generic;
    2424using System.Linq;
    25 using AutoDiff;
    2625using HeuristicLab.Common;
    2726using HeuristicLab.Core;
     
    153152    }
    154153
    155     #region derivations of functions
    156     // create function factory for arctangent
    157     private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
    158       eval: Math.Atan,
    159       diff: x => 1 / (1 + x * x));
    160     private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    161       eval: Math.Sin,
    162       diff: Math.Cos);
    163     private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    164        eval: Math.Cos,
    165        diff: x => -Math.Sin(x));
    166     private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    167       eval: Math.Tan,
    168       diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    169     private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    170       eval: alglib.errorfunction,
    171       diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    172     private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    173       eval: alglib.normaldistribution,
    174       diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
    175     #endregion
    176 
    177 
    178 
    179154    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
    180155      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
     
    188163      // variable name, variable value (for factor vars) and lag as a DataForVariable object.
    189164      // A dictionary is used to find parameters
    190       var variables = new List<AutoDiff.Variable>();
    191       var parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
    192 
    193       AutoDiff.Term func;
    194       if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, updateVariableWeights, out func))
     165      double[] initialConstants;
     166      var parameters = new List<TreeToAutoDiffTermTransformator.DataForVariable>();
     167
     168      TreeToAutoDiffTermTransformator.ParametricFunction func;
     169      TreeToAutoDiffTermTransformator.ParametricFunctionGradient func_grad;
     170      if (!TreeToAutoDiffTermTransformator.TryTransformToAutoDiff(tree, updateVariableWeights, out parameters, out initialConstants, out func, out func_grad))
    195171        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    196172      if (parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
    197173
    198174      var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
    199       AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
    200 
    201       List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants
    202       if (updateVariableWeights)
    203         terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
    204       else
    205         terminalNodes = new List<SymbolicExpressionTreeTerminalNode>
    206           (tree.Root.IterateNodesPrefix()
    207           .OfType<SymbolicExpressionTreeTerminalNode>()
    208           .Where(node => node is ConstantTreeNode || node is FactorVariableTreeNode));
    209175
    210176      //extract inital constants
    211       double[] c = new double[variables.Count];
     177      double[] c = new double[initialConstants.Length];
    212178      {
    213179        c[0] = 0.0;
    214180        c[1] = 1.0;
    215         int i = 2;
    216         foreach (var node in terminalNodes) {
    217           ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    218           VariableTreeNode variableTreeNode = node as VariableTreeNode;
    219           BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
    220           FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    221           if (constantTreeNode != null)
    222             c[i++] = constantTreeNode.Value;
    223           else if (updateVariableWeights && variableTreeNode != null)
    224             c[i++] = variableTreeNode.Weight;
    225           else if (updateVariableWeights && binFactorVarTreeNode != null)
    226             c[i++] = binFactorVarTreeNode.Weight;
    227           else if (factorVarTreeNode != null) {
    228             // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
    229             foreach (var w in factorVarTreeNode.Weights) c[i++] = w;
    230           }
    231         }
     181        Array.Copy(initialConstants, 0, c, 2, initialConstants.Length);
    232182      }
    233183      double[] originalConstants = (double[])c.Clone();
     
    243193      foreach (var r in rows) {
    244194        int col = 0;
    245         foreach (var kvp in parameterEntries) {
    246           var info = kvp.Key;
     195        foreach (var info in parameterEntries) {
    247196          if (ds.VariableHasType<double>(info.variableName)) {
    248197            x[row, col] = ds.GetDoubleValue(info.variableName, r + info.lag);
     
    259208      int k = c.Length;
    260209
    261       alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
    262       alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
     210      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(func);
     211      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(func_grad);
    263212
    264213      try {
     
    305254    }
    306255
    307     private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
    308       return (double[] c, double[] x, ref double func, object o) => {
    309         func = compiledFunc.Evaluate(c, x);
     256    private static alglib.ndimensional_pfunc CreatePFunc(TreeToAutoDiffTermTransformator.ParametricFunction func) {
     257      return (double[] c, double[] x, ref double fx, object o) => {
     258        fx = func(c, x);
    310259      };
    311260    }
    312261
    313     private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
    314       return (double[] c, double[] x, ref double func, double[] grad, object o) => {
    315         var tupel = compiledFunc.Differentiate(c, x);
    316         func = tupel.Item2;
     262    private static alglib.ndimensional_pgrad CreatePGrad(TreeToAutoDiffTermTransformator.ParametricFunctionGradient func_grad) {
     263      return (double[] c, double[] x, ref double fx, double[] grad, object o) => {
     264        var tupel = func_grad(c, x);
     265        fx = tupel.Item2;
    317266        Array.Copy(tupel.Item1, grad, grad.Length);
    318267      };
    319268    }
    320 
    321     private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node,
    322       List<AutoDiff.Variable> variables, Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    323       bool updateVariableWeights, out AutoDiff.Term term) {
    324       if (node.Symbol is Constant) {
    325         var var = new AutoDiff.Variable();
    326         variables.Add(var);
    327         term = var;
    328         return true;
    329       }
    330       if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
    331         var varNode = node as VariableTreeNodeBase;
    332         var factorVarNode = node as BinaryFactorVariableTreeNode;
    333         // factor variable values are only 0 or 1 and set in x accordingly
    334         var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
    335         var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
    336 
    337         if (updateVariableWeights) {
    338           var w = new AutoDiff.Variable();
    339           variables.Add(w);
    340           term = AutoDiff.TermBuilder.Product(w, par);
    341         } else {
    342           term = varNode.Weight * par;
    343         }
    344         return true;
    345       }
    346       if (node.Symbol is FactorVariable) {
    347         var factorVarNode = node as FactorVariableTreeNode;
    348         var products = new List<Term>();
    349         foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
    350           var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
    351 
    352           var wVar = new AutoDiff.Variable();
    353           variables.Add(wVar);
    354 
    355           products.Add(AutoDiff.TermBuilder.Product(wVar, par));
    356         }
    357         term = AutoDiff.TermBuilder.Sum(products);
    358         return true;
    359       }
    360       if (node.Symbol is LaggedVariable) {
    361         var varNode = node as LaggedVariableTreeNode;
    362         var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
    363 
    364         if (updateVariableWeights) {
    365           var w = new AutoDiff.Variable();
    366           variables.Add(w);
    367           term = AutoDiff.TermBuilder.Product(w, par);
    368         } else {
    369           term = varNode.Weight * par;
    370         }
    371         return true;
    372       }
    373       if (node.Symbol is Addition) {
    374         List<AutoDiff.Term> terms = new List<Term>();
    375         foreach (var subTree in node.Subtrees) {
    376           AutoDiff.Term t;
    377           if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    378             term = null;
    379             return false;
    380           }
    381           terms.Add(t);
    382         }
    383         term = AutoDiff.TermBuilder.Sum(terms);
    384         return true;
    385       }
    386       if (node.Symbol is Subtraction) {
    387         List<AutoDiff.Term> terms = new List<Term>();
    388         for (int i = 0; i < node.SubtreeCount; i++) {
    389           AutoDiff.Term t;
    390           if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, updateVariableWeights, out t)) {
    391             term = null;
    392             return false;
    393           }
    394           if (i > 0) t = -t;
    395           terms.Add(t);
    396         }
    397         if (terms.Count == 1) term = -terms[0];
    398         else term = AutoDiff.TermBuilder.Sum(terms);
    399         return true;
    400       }
    401       if (node.Symbol is Multiplication) {
    402         List<AutoDiff.Term> terms = new List<Term>();
    403         foreach (var subTree in node.Subtrees) {
    404           AutoDiff.Term t;
    405           if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    406             term = null;
    407             return false;
    408           }
    409           terms.Add(t);
    410         }
    411         if (terms.Count == 1) term = terms[0];
    412         else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
    413         return true;
    414 
    415       }
    416       if (node.Symbol is Division) {
    417         List<AutoDiff.Term> terms = new List<Term>();
    418         foreach (var subTree in node.Subtrees) {
    419           AutoDiff.Term t;
    420           if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    421             term = null;
    422             return false;
    423           }
    424           terms.Add(t);
    425         }
    426         if (terms.Count == 1) term = 1.0 / terms[0];
    427         else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
    428         return true;
    429       }
    430       if (node.Symbol is Logarithm) {
    431         AutoDiff.Term t;
    432         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    433           term = null;
    434           return false;
    435         } else {
    436           term = AutoDiff.TermBuilder.Log(t);
    437           return true;
    438         }
    439       }
    440       if (node.Symbol is Exponential) {
    441         AutoDiff.Term t;
    442         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    443           term = null;
    444           return false;
    445         } else {
    446           term = AutoDiff.TermBuilder.Exp(t);
    447           return true;
    448         }
    449       }
    450       if (node.Symbol is Square) {
    451         AutoDiff.Term t;
    452         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    453           term = null;
    454           return false;
    455         } else {
    456           term = AutoDiff.TermBuilder.Power(t, 2.0);
    457           return true;
    458         }
    459       }
    460       if (node.Symbol is SquareRoot) {
    461         AutoDiff.Term t;
    462         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    463           term = null;
    464           return false;
    465         } else {
    466           term = AutoDiff.TermBuilder.Power(t, 0.5);
    467           return true;
    468         }
    469       }
    470       if (node.Symbol is Sine) {
    471         AutoDiff.Term t;
    472         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    473           term = null;
    474           return false;
    475         } else {
    476           term = sin(t);
    477           return true;
    478         }
    479       }
    480       if (node.Symbol is Cosine) {
    481         AutoDiff.Term t;
    482         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    483           term = null;
    484           return false;
    485         } else {
    486           term = cos(t);
    487           return true;
    488         }
    489       }
    490       if (node.Symbol is Tangent) {
    491         AutoDiff.Term t;
    492         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    493           term = null;
    494           return false;
    495         } else {
    496           term = tan(t);
    497           return true;
    498         }
    499       }
    500       if (node.Symbol is Erf) {
    501         AutoDiff.Term t;
    502         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    503           term = null;
    504           return false;
    505         } else {
    506           term = erf(t);
    507           return true;
    508         }
    509       }
    510       if (node.Symbol is Norm) {
    511         AutoDiff.Term t;
    512         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    513           term = null;
    514           return false;
    515         } else {
    516           term = norm(t);
    517           return true;
    518         }
    519       }
    520       if (node.Symbol is StartSymbol) {
    521         var alpha = new AutoDiff.Variable();
    522         var beta = new AutoDiff.Variable();
    523         variables.Add(beta);
    524         variables.Add(alpha);
    525         AutoDiff.Term branchTerm;
    526         if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out branchTerm)) {
    527           term = branchTerm * alpha + beta;
    528           return true;
    529         } else {
    530           term = null;
    531           return false;
    532         }
    533       }
    534       term = null;
    535       return false;
    536     }
    537 
    538     // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
    539     // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
    540     private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    541       string varName, string varValue = "", int lag = 0) {
    542       var data = new DataForVariable(varName, varValue, lag);
    543 
    544       AutoDiff.Variable par = null;
    545       if (!parameters.TryGetValue(data, out par)) {
    546         // not found -> create new parameter and entries in names and values lists
    547         par = new AutoDiff.Variable();
    548         parameters.Add(data, par);
    549       }
    550       return par;
    551     }
    552 
    553269    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
    554       var containsUnknownSymbol = (
    555         from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
    556         where
    557          !(n.Symbol is Variable) &&
    558          !(n.Symbol is BinaryFactorVariable) &&
    559          !(n.Symbol is FactorVariable) &&
    560          !(n.Symbol is LaggedVariable) &&
    561          !(n.Symbol is Constant) &&
    562          !(n.Symbol is Addition) &&
    563          !(n.Symbol is Subtraction) &&
    564          !(n.Symbol is Multiplication) &&
    565          !(n.Symbol is Division) &&
    566          !(n.Symbol is Logarithm) &&
    567          !(n.Symbol is Exponential) &&
    568          !(n.Symbol is SquareRoot) &&
    569          !(n.Symbol is Square) &&
    570          !(n.Symbol is Sine) &&
    571          !(n.Symbol is Cosine) &&
    572          !(n.Symbol is Tangent) &&
    573          !(n.Symbol is Erf) &&
    574          !(n.Symbol is Norm) &&
    575          !(n.Symbol is StartSymbol)
    576         select n).
    577       Any();
    578       return !containsUnknownSymbol;
    579     }
    580 
    581 
    582     #region helper class
    583     private class DataForVariable {
    584       public readonly string variableName;
    585       public readonly string variableValue; // for factor vars
    586       public readonly int lag;
    587 
    588       public DataForVariable(string varName, string varValue, int lag) {
    589         this.variableName = varName;
    590         this.variableValue = varValue;
    591         this.lag = lag;
    592       }
    593 
    594       public override bool Equals(object obj) {
    595         var other = obj as DataForVariable;
    596         if (other == null) return false;
    597         return other.variableName.Equals(this.variableName) &&
    598                other.variableValue.Equals(this.variableValue) &&
    599                other.lag == this.lag;
    600       }
    601 
    602       public override int GetHashCode() {
    603         return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
    604       }
    605     }
    606     #endregion
     270      return TreeToAutoDiffTermTransformator.IsCompatible(tree);
     271    }
    607272  }
    608273}
Note: See TracChangeset for help on using the changeset viewer.