Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/26/20 16:43:25 (4 years ago)
Author:
pfleck
Message:

#3040 Added a constant opt evaluator for vectors that uses the existing AutoDiff library by unrolling all vector operations.

File:
1 copied

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/VectorUnrollingTreeToAutoDiffTermConverter.cs

    r17725 r17726  
    2525using System.Runtime.Serialization;
    2626using AutoDiff;
     27using HeuristicLab.Common;
    2728using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2829
    2930namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
    30   public class TreeToAutoDiffTermConverter {
     31  public class VectorUnrollingTreeToAutoDiffTermConverter {
    3132    public delegate double ParametricFunction(double[] vars, double[] @params);
    3233
     
    3839      public readonly string variableValue; // for factor vars
    3940      public readonly int lag;
    40 
    41       public DataForVariable(string varName, string varValue, int lag) {
     41      public readonly int index; // for vectors
     42
     43      public DataForVariable(string varName, string varValue, int lag, int index) {
    4244        this.variableName = varName;
    4345        this.variableValue = varValue;
    4446        this.lag = lag;
     47        this.index = index;
    4548      }
    4649
     
    5053        return other.variableName.Equals(this.variableName) &&
    5154               other.variableValue.Equals(this.variableValue) &&
    52                other.lag == this.lag;
     55               other.lag == this.lag &&
     56               other.index == this.index;
    5357      }
    5458
    5559      public override int GetHashCode() {
    56         return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
     60        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag ^ index;
    5761      }
    5862    }
     
    101105    #endregion
    102106
    103     public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
     107    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree,
     108      IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace,
     109      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
    104110      out List<DataForVariable> parameters, out double[] initialConstants,
    105111      out ParametricFunction func,
     
    107113
    108114      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
    109       var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, addLinearScalingTerms);
    110       AutoDiff.Term term;
     115      var transformator = new VectorUnrollingTreeToAutoDiffTermConverter(evaluationTrace,
     116        makeVariableWeightsVariable, addLinearScalingTerms);
     117      Term term;
    111118      try {
    112         term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
     119        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)).Single();
    113120        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
    114121        var compiledTerm = term.Compile(transformator.variables.ToArray(),
     
    128135    }
    129136
     137    private readonly IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace;
    130138    // state for recursive transformation of trees
    131     private readonly
    132     List<double> initialConstants;
     139    private readonly List<double> initialConstants;
    133140    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
    134141    private readonly List<AutoDiff.Variable> variables;
     
    136143    private readonly bool addLinearScalingTerms;
    137144
    138     private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
     145    private VectorUnrollingTreeToAutoDiffTermConverter(IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace,
     146      bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
     147      this.evaluationTrace = evaluationTrace;
    139148      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
    140149      this.addLinearScalingTerms = addLinearScalingTerms;
     
    144153    }
    145154
    146     private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
    147       if (node.Symbol is Constant) {
     155    private IList<AutoDiff.Term> ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
     156      IList<Term> BinaryOp(Func<Term, Term, Term> binaryOp, Func<Term, Term> singleElementOp, params IList<Term>[] terms) {
     157        if (terms.Length == 1) return terms[0].Select(singleElementOp).ToList();
     158        return terms.Aggregate((acc, vectorizedTerm) => acc.Zip(vectorizedTerm, binaryOp).ToList());
     159      }
     160      IList<Term> BinaryOp2(Func<Term, Term, Term> binaryOp, params IList<Term>[] terms) {
     161        return terms.Aggregate((acc, vectorizedTerm) => acc.Zip(vectorizedTerm, binaryOp).ToList());
     162      }
     163      IList<Term> UnaryOp(Func<Term, Term> unaryOp, IList<Term> term) {
     164        return term.Select(unaryOp).ToList();
     165      }
     166
     167      var evaluationResult = evaluationTrace[node];
     168
     169      if (node.Symbol is Constant) { // assume scalar constant
    148170        initialConstants.Add(((ConstantTreeNode)node).Value);
    149171        var var = new AutoDiff.Variable();
    150172        variables.Add(var);
    151         return var;
     173        return new Term[] { var };
    152174      }
    153175      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
     
    156178        // factor variable values are only 0 or 1 and set in x accordingly
    157179        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
    158         var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
     180        var pars = evaluationResult.IsVector
     181          ? Enumerable.Range(0, evaluationResult.Vector.Count).Select(i => FindOrCreateParameter(parameters, varNode.VariableName, varValue, index: i))
     182          : FindOrCreateParameter(parameters, varNode.VariableName, varValue).ToEnumerable();
    159183
    160184        if (makeVariableWeightsVariable) {
     
    162186          var w = new AutoDiff.Variable();
    163187          variables.Add(w);
    164           return AutoDiff.TermBuilder.Product(w, par);
     188          return pars.Select(par => AutoDiff.TermBuilder.Product(w, par)).ToList();
    165189        } else {
    166           return varNode.Weight * par;
     190          return pars.Select(par => varNode.Weight * par).ToList();
    167191        }
    168192      }
     
    179203          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
    180204        }
    181         return AutoDiff.TermBuilder.Sum(products);
    182       }
    183       if (node.Symbol is LaggedVariable) {
    184         var varNode = node as LaggedVariableTreeNode;
    185         var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
    186 
    187         if (makeVariableWeightsVariable) {
    188           initialConstants.Add(varNode.Weight);
    189           var w = new AutoDiff.Variable();
    190           variables.Add(w);
    191           return AutoDiff.TermBuilder.Product(w, par);
    192         } else {
    193           return varNode.Weight * par;
    194         }
    195       }
     205        return new[] { AutoDiff.TermBuilder.Sum(products) };
     206      }
     207      //if (node.Symbol is LaggedVariable) {
     208      //  var varNode = node as LaggedVariableTreeNode;
     209      //  var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
     210
     211      //  if (makeVariableWeightsVariable) {
     212      //    initialConstants.Add(varNode.Weight);
     213      //    var w = new AutoDiff.Variable();
     214      //    variables.Add(w);
     215      //    return AutoDiff.TermBuilder.Product(w, par);
     216      //  } else {
     217      //    return varNode.Weight * par;
     218      //  }
     219      //}
    196220      if (node.Symbol is Addition) {
    197         List<AutoDiff.Term> terms = new List<Term>();
    198         foreach (var subTree in node.Subtrees) {
    199           terms.Add(ConvertToAutoDiff(subTree));
    200         }
    201         return AutoDiff.TermBuilder.Sum(terms);
     221        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
     222        return BinaryOp((a, b) => a + b, a => a, terms);
    202223      }
    203224      if (node.Symbol is Subtraction) {
    204         List<AutoDiff.Term> terms = new List<Term>();
    205         for (int i = 0; i < node.SubtreeCount; i++) {
    206           AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
    207           if (i > 0) t = -t;
    208           terms.Add(t);
    209         }
    210         if (terms.Count == 1) return -terms[0];
    211         else return AutoDiff.TermBuilder.Sum(terms);
     225        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
     226        return BinaryOp((a, b) => a - b, a => -a, terms);
    212227      }
    213228      if (node.Symbol is Multiplication) {
    214         List<AutoDiff.Term> terms = new List<Term>();
    215         foreach (var subTree in node.Subtrees) {
    216           terms.Add(ConvertToAutoDiff(subTree));
    217         }
    218         if (terms.Count == 1) return terms[0];
    219         else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
     229        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
     230        return BinaryOp((a, b) => a * b, a => a, terms);
    220231      }
    221232      if (node.Symbol is Division) {
    222         List<AutoDiff.Term> terms = new List<Term>();
    223         foreach (var subTree in node.Subtrees) {
    224           terms.Add(ConvertToAutoDiff(subTree));
    225         }
    226         if (terms.Count == 1) return 1.0 / terms[0];
    227         else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
     233        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
     234        return BinaryOp((a, b) => a / b, a => 1.0 / a, terms);
    228235      }
    229236      if (node.Symbol is Absolute) {
    230         var x1 = ConvertToAutoDiff(node.GetSubtree(0));
    231         return abs(x1);
    232       }
    233       if (node.Symbol is AnalyticQuotient) {
    234         var x1 = ConvertToAutoDiff(node.GetSubtree(0));
    235         var x2 = ConvertToAutoDiff(node.GetSubtree(1));
    236         return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
    237       }
     237        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     238        return UnaryOp(abs, term);
     239      }
     240      //if (node.Symbol is AnalyticQuotient) {
     241      //  var x1 = ConvertToAutoDiff(node.GetSubtree(0));
     242      //  var x2 = ConvertToAutoDiff(node.GetSubtree(1));
     243      //  return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
     244      //}
    238245      if (node.Symbol is Logarithm) {
    239         return AutoDiff.TermBuilder.Log(
    240           ConvertToAutoDiff(node.GetSubtree(0)));
     246        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     247        return UnaryOp(TermBuilder.Log, term);
    241248      }
    242249      if (node.Symbol is Exponential) {
    243         return AutoDiff.TermBuilder.Exp(
    244           ConvertToAutoDiff(node.GetSubtree(0)));
     250        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     251        return UnaryOp(TermBuilder.Exp, term);
    245252      }
    246253      if (node.Symbol is Square) {
    247         return AutoDiff.TermBuilder.Power(
    248           ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
     254        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     255        return UnaryOp(t => TermBuilder.Power(t, 2.0), term);
    249256      }
    250257      if (node.Symbol is SquareRoot) {
    251         return AutoDiff.TermBuilder.Power(
    252           ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
     258        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     259        return UnaryOp(t => TermBuilder.Power(t, 0.5), term);
    253260      }
    254261      if (node.Symbol is Cube) {
    255         return AutoDiff.TermBuilder.Power(
    256           ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
     262        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     263        return UnaryOp(t => TermBuilder.Power(t, 3.0), term);
    257264      }
    258265      if (node.Symbol is CubeRoot) {
    259         return cbrt(ConvertToAutoDiff(node.GetSubtree(0)));
     266        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     267        return UnaryOp(cbrt, term);
    260268      }
    261269      if (node.Symbol is Sine) {
    262         return sin(
    263           ConvertToAutoDiff(node.GetSubtree(0)));
     270        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     271        return UnaryOp(sin, term);
    264272      }
    265273      if (node.Symbol is Cosine) {
    266         return cos(
    267           ConvertToAutoDiff(node.GetSubtree(0)));
     274        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     275        return UnaryOp(cos, term);
    268276      }
    269277      if (node.Symbol is Tangent) {
    270         return tan(
    271           ConvertToAutoDiff(node.GetSubtree(0)));
     278        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     279        return UnaryOp(tan, term);
    272280      }
    273281      if (node.Symbol is HyperbolicTangent) {
    274         return tanh(
    275           ConvertToAutoDiff(node.GetSubtree(0)));
     282        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     283        return UnaryOp(tanh, term);
    276284      }
    277285      if (node.Symbol is Erf) {
    278         return erf(
    279           ConvertToAutoDiff(node.GetSubtree(0)));
     286        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     287        return UnaryOp(erf, term);
    280288      }
    281289      if (node.Symbol is Norm) {
    282         return norm(
    283           ConvertToAutoDiff(node.GetSubtree(0)));
     290        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     291        return UnaryOp(norm, term);
    284292      }
    285293      if (node.Symbol is StartSymbol) {
     
    291299          variables.Add(alpha);
    292300          var t = ConvertToAutoDiff(node.GetSubtree(0));
    293           return t * alpha + beta;
     301          if (t.Count > 1) throw new InvalidOperationException("Tree Result must be scalar value");
     302          return new[] { t[0] * alpha + beta };
    294303        } else return ConvertToAutoDiff(node.GetSubtree(0));
    295304      }
     305      if (node.Symbol is Sum) {
     306        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     307        return new[] { TermBuilder.Sum(term) };
     308      }
     309      if (node.Symbol is Mean) {
     310        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     311        return new[] { TermBuilder.Sum(term) / term.Count };
     312      }
     313      if (node.Symbol is StandardDeviation) {
     314        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     315        var mean = TermBuilder.Sum(term) / term.Count;
     316        var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0)));
     317        return new[] { TermBuilder.Power(ssd / term.Count, 0.5) };
     318      }
     319      if (node.Symbol is Length) {
     320        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     321        return new[] { TermBuilder.Constant(term.Count) };
     322      }
     323      //if (node.Symbol is Min) {
     324      //}
     325      //if (node.Symbol is Max) {
     326      //}
     327      if (node.Symbol is Variance) {
     328        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
     329        var mean = TermBuilder.Sum(term) / term.Count;
     330        var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0)));
     331        return new[] { ssd / term.Count };
     332      }
     333      //if (node.Symbol is Skewness) {
     334      //}
     335      //if (node.Symbol is Kurtosis) {
     336      //}
     337      //if (node.Symbol is EuclideanDistance) {
     338      //}
     339      //if (node.Symbol is Covariance) {
     340      //}
     341
     342
    296343      throw new ConversionException();
    297344    }
     
    301348    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
    302349    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    303       string varName, string varValue = "", int lag = 0) {
    304       var data = new DataForVariable(varName, varValue, lag);
     350      string varName, string varValue = "", int lag = 0, int index = -1) {
     351      var data = new DataForVariable(varName, varValue, lag, index);
    305352
    306353      AutoDiff.Variable par = null;
     
    319366          !(n.Symbol is Variable) &&
    320367          !(n.Symbol is BinaryFactorVariable) &&
    321           !(n.Symbol is FactorVariable) &&
    322           !(n.Symbol is LaggedVariable) &&
     368          //!(n.Symbol is FactorVariable) &&
     369          //!(n.Symbol is LaggedVariable) &&
    323370          !(n.Symbol is Constant) &&
    324371          !(n.Symbol is Addition) &&
     
    338385          !(n.Symbol is StartSymbol) &&
    339386          !(n.Symbol is Absolute) &&
    340           !(n.Symbol is AnalyticQuotient) &&
     387          //!(n.Symbol is AnalyticQuotient) &&
    341388          !(n.Symbol is Cube) &&
    342           !(n.Symbol is CubeRoot)
     389          !(n.Symbol is CubeRoot) &&
     390          !(n.Symbol is Sum) &&
     391          !(n.Symbol is Mean) &&
     392          !(n.Symbol is StandardDeviation) &&
     393          !(n.Symbol is Length) &&
     394          //!(n.Symbol is Min) &&
     395          //!(n.Symbol is Max) &&
     396          !(n.Symbol is Variance)
     397        //!(n.Symbol is Skewness) &&
     398        //!(n.Symbol is Kurtosis) &&
     399        //!(n.Symbol is EuclideanDistance) &&
     400        //!(n.Symbol is Covariance)
    343401        select n).Any();
    344402      return !containsUnknownSymbol;
Note: See TracChangeset for help on using the changeset viewer.