Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/09/17 20:08:11 (8 years ago)
Author:
gkronber
Message:

#2697: code improvement in TreeToAutoDiffTermConverter

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs

    r14851 r14950  
    2323using System.Collections.Generic;
    2424using System.Linq;
     25using System.Runtime.Serialization;
    2526using AutoDiff;
    2627using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     
    2930  public class TreeToAutoDiffTermConverter {
    3031    public delegate double ParametricFunction(double[] vars, double[] @params);
     32
    3133    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
    3234
     
    6264      eval: Math.Atan,
    6365      diff: x => 1 / (1 + x * x));
     66
    6467    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    6568      eval: Math.Sin,
    6669      diff: Math.Cos);
     70
    6771    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    68        eval: Math.Cos,
    69        diff: x => -Math.Sin(x));
     72      eval: Math.Cos,
     73      diff: x => -Math.Sin(x));
     74
    7075    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    7176      eval: Math.Tan,
    7277      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
     78
    7379    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    7480      eval: alglib.errorfunction,
    7581      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
     82
    7683    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    7784      eval: alglib.normaldistribution,
     
    8895      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
    8996      AutoDiff.Term term;
    90       var success = transformator.TryConvertToAutoDiff(tree.Root.GetSubtree(0), out term);
    91       if (success) {
     97      try {
     98        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
    9299        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
    93         var compiledTerm = term.Compile(transformator.variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
     100        var compiledTerm = term.Compile(transformator.variables.ToArray(),
     101          parameterEntries.Select(kvp => kvp.Value).ToArray());
    94102        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
    95103        initialConstants = transformator.initialConstants.ToArray();
    96104        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
    97105        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
    98       } else {
     106        return true;
     107      } catch (ConversionException) {
    99108        func = null;
    100109        func_grad = null;
     
    102111        initialConstants = null;
    103112      }
    104       return success;
     113      return false;
    105114    }
    106115
    107116    // state for recursive transformation of trees
    108     private readonly List<double> initialConstants;
     117    private readonly
     118    List<double> initialConstants;
    109119    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
    110120    private readonly List<AutoDiff.Variable> variables;
     
    118128    }
    119129
    120     private bool TryConvertToAutoDiff(ISymbolicExpressionTreeNode node, out AutoDiff.Term term) {
     130    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
    121131      if (node.Symbol is Constant) {
    122132        initialConstants.Add(((ConstantTreeNode)node).Value);
    123133        var var = new AutoDiff.Variable();
    124134        variables.Add(var);
    125         term = var;
    126         return true;
     135        return var;
    127136      }
    128137      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
     
    137146          var w = new AutoDiff.Variable();
    138147          variables.Add(w);
    139           term = AutoDiff.TermBuilder.Product(w, par);
     148          return AutoDiff.TermBuilder.Product(w, par);
    140149        } else {
    141           term = varNode.Weight * par;
    142         }
    143         return true;
     150          return varNode.Weight * par;
     151        }
    144152      }
    145153      if (node.Symbol is FactorVariable) {
     
    155163          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
    156164        }
    157         term = AutoDiff.TermBuilder.Sum(products);
    158         return true;
     165        return AutoDiff.TermBuilder.Sum(products);
    159166      }
    160167      if (node.Symbol is LaggedVariable) {
     
    166173          var w = new AutoDiff.Variable();
    167174          variables.Add(w);
    168           term = AutoDiff.TermBuilder.Product(w, par);
     175          return AutoDiff.TermBuilder.Product(w, par);
    169176        } else {
    170           term = varNode.Weight * par;
    171         }
    172         return true;
     177          return varNode.Weight * par;
     178        }
    173179      }
    174180      if (node.Symbol is Addition) {
    175181        List<AutoDiff.Term> terms = new List<Term>();
    176182        foreach (var subTree in node.Subtrees) {
    177           AutoDiff.Term t;
    178           if (!TryConvertToAutoDiff(subTree, out t)) {
    179             term = null;
    180             return false;
    181           }
    182           terms.Add(t);
    183         }
    184         term = AutoDiff.TermBuilder.Sum(terms);
    185         return true;
     183          terms.Add(ConvertToAutoDiff(subTree));
     184        }
     185        return AutoDiff.TermBuilder.Sum(terms);
    186186      }
    187187      if (node.Symbol is Subtraction) {
    188188        List<AutoDiff.Term> terms = new List<Term>();
    189189        for (int i = 0; i < node.SubtreeCount; i++) {
    190           AutoDiff.Term t;
    191           if (!TryConvertToAutoDiff(node.GetSubtree(i), out t)) {
    192             term = null;
    193             return false;
    194           }
     190          AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
    195191          if (i > 0) t = -t;
    196192          terms.Add(t);
    197193        }
    198         if (terms.Count == 1) term = -terms[0];
    199         else term = AutoDiff.TermBuilder.Sum(terms);
    200         return true;
     194        if (terms.Count == 1) return -terms[0];
     195        else return AutoDiff.TermBuilder.Sum(terms);
    201196      }
    202197      if (node.Symbol is Multiplication) {
    203198        List<AutoDiff.Term> terms = new List<Term>();
    204199        foreach (var subTree in node.Subtrees) {
    205           AutoDiff.Term t;
    206           if (!TryConvertToAutoDiff(subTree, out t)) {
    207             term = null;
    208             return false;
    209           }
    210           terms.Add(t);
    211         }
    212         if (terms.Count == 1) term = terms[0];
    213         else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
    214         return true;
    215 
     200          terms.Add(ConvertToAutoDiff(subTree));
     201        }
     202        if (terms.Count == 1) return terms[0];
     203        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
    216204      }
    217205      if (node.Symbol is Division) {
    218206        List<AutoDiff.Term> terms = new List<Term>();
    219207        foreach (var subTree in node.Subtrees) {
    220           AutoDiff.Term t;
    221           if (!TryConvertToAutoDiff(subTree, out t)) {
    222             term = null;
    223             return false;
    224           }
    225           terms.Add(t);
    226         }
    227         if (terms.Count == 1) term = 1.0 / terms[0];
    228         else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
    229         return true;
     208          terms.Add(ConvertToAutoDiff(subTree));
     209        }
     210        if (terms.Count == 1) return 1.0 / terms[0];
     211        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
    230212      }
    231213      if (node.Symbol is Logarithm) {
    232         AutoDiff.Term t;
    233         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    234           term = null;
    235           return false;
    236         } else {
    237           term = AutoDiff.TermBuilder.Log(t);
    238           return true;
    239         }
     214        return AutoDiff.TermBuilder.Log(
     215          ConvertToAutoDiff(node.GetSubtree(0)));
    240216      }
    241217      if (node.Symbol is Exponential) {
    242         AutoDiff.Term t;
    243         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    244           term = null;
    245           return false;
    246         } else {
    247           term = AutoDiff.TermBuilder.Exp(t);
    248           return true;
    249         }
     218        return AutoDiff.TermBuilder.Exp(
     219          ConvertToAutoDiff(node.GetSubtree(0)));
    250220      }
    251221      if (node.Symbol is Square) {
    252         AutoDiff.Term t;
    253         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    254           term = null;
    255           return false;
    256         } else {
    257           term = AutoDiff.TermBuilder.Power(t, 2.0);
    258           return true;
    259         }
     222        return AutoDiff.TermBuilder.Power(
     223          ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
    260224      }
    261225      if (node.Symbol is SquareRoot) {
    262         AutoDiff.Term t;
    263         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    264           term = null;
    265           return false;
    266         } else {
    267           term = AutoDiff.TermBuilder.Power(t, 0.5);
    268           return true;
    269         }
     226        return AutoDiff.TermBuilder.Power(
     227          ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
    270228      }
    271229      if (node.Symbol is Sine) {
    272         AutoDiff.Term t;
    273         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    274           term = null;
    275           return false;
    276         } else {
    277           term = sin(t);
    278           return true;
    279         }
     230        return sin(
     231          ConvertToAutoDiff(node.GetSubtree(0)));
    280232      }
    281233      if (node.Symbol is Cosine) {
    282         AutoDiff.Term t;
    283         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    284           term = null;
    285           return false;
    286         } else {
    287           term = cos(t);
    288           return true;
    289         }
     234        return cos(
     235          ConvertToAutoDiff(node.GetSubtree(0)));
    290236      }
    291237      if (node.Symbol is Tangent) {
    292         AutoDiff.Term t;
    293         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    294           term = null;
    295           return false;
    296         } else {
    297           term = tan(t);
    298           return true;
    299         }
     238        return tan(
     239          ConvertToAutoDiff(node.GetSubtree(0)));
    300240      }
    301241      if (node.Symbol is Erf) {
    302         AutoDiff.Term t;
    303         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    304           term = null;
    305           return false;
    306         } else {
    307           term = erf(t);
    308           return true;
    309         }
     242        return erf(
     243          ConvertToAutoDiff(node.GetSubtree(0)));
    310244      }
    311245      if (node.Symbol is Norm) {
    312         AutoDiff.Term t;
    313         if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) {
    314           term = null;
    315           return false;
    316         } else {
    317           term = norm(t);
    318           return true;
    319         }
     246        return norm(
     247          ConvertToAutoDiff(node.GetSubtree(0)));
    320248      }
    321249      if (node.Symbol is StartSymbol) {
     
    324252        variables.Add(beta);
    325253        variables.Add(alpha);
    326         AutoDiff.Term branchTerm;
    327         if (TryConvertToAutoDiff(node.GetSubtree(0), out branchTerm)) {
    328           term = branchTerm * alpha + beta;
    329           return true;
    330         } else {
    331           term = null;
    332           return false;
    333         }
    334       }
    335       term = null;
    336       return false;
     254        return ConvertToAutoDiff(node.GetSubtree(0)) * alpha + beta;
     255      }
     256      throw new ConversionException();
    337257    }
    338258
     
    357277        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
    358278        where
    359         !(n.Symbol is Variable) &&
    360         !(n.Symbol is BinaryFactorVariable) &&
    361         !(n.Symbol is FactorVariable) &&
    362         !(n.Symbol is LaggedVariable) &&
    363         !(n.Symbol is Constant) &&
    364         !(n.Symbol is Addition) &&
    365         !(n.Symbol is Subtraction) &&
    366         !(n.Symbol is Multiplication) &&
    367         !(n.Symbol is Division) &&
    368         !(n.Symbol is Logarithm) &&
    369         !(n.Symbol is Exponential) &&
    370         !(n.Symbol is SquareRoot) &&
    371         !(n.Symbol is Square) &&
    372         !(n.Symbol is Sine) &&
    373         !(n.Symbol is Cosine) &&
    374         !(n.Symbol is Tangent) &&
    375         !(n.Symbol is Erf) &&
    376         !(n.Symbol is Norm) &&
    377         !(n.Symbol is StartSymbol)
     279          !(n.Symbol is Variable) &&
     280          !(n.Symbol is BinaryFactorVariable) &&
     281          !(n.Symbol is FactorVariable) &&
     282          !(n.Symbol is LaggedVariable) &&
     283          !(n.Symbol is Constant) &&
     284          !(n.Symbol is Addition) &&
     285          !(n.Symbol is Subtraction) &&
     286          !(n.Symbol is Multiplication) &&
     287          !(n.Symbol is Division) &&
     288          !(n.Symbol is Logarithm) &&
     289          !(n.Symbol is Exponential) &&
     290          !(n.Symbol is SquareRoot) &&
     291          !(n.Symbol is Square) &&
     292          !(n.Symbol is Sine) &&
     293          !(n.Symbol is Cosine) &&
     294          !(n.Symbol is Tangent) &&
     295          !(n.Symbol is Erf) &&
     296          !(n.Symbol is Norm) &&
     297          !(n.Symbol is StartSymbol)
    378298        select n).Any();
    379299      return !containsUnknownSymbol;
    380300    }
     301    #region exception class
     302    [Serializable]
     303    public class ConversionException : Exception {
     304
     305      public ConversionException() {
     306      }
     307
     308      public ConversionException(string message) : base(message) {
     309      }
     310
     311      public ConversionException(string message, Exception inner) : base(message, inner) {
     312      }
     313
     314      protected ConversionException(
     315        SerializationInfo info,
     316        StreamingContext context) : base(info, context) {
     317      }
     318    }
     319    #endregion
    381320  }
    382321}
Note: See TracChangeset for help on using the changeset viewer.