Free cookie consent management tool by TermsFeed Policy Generator

Changeset 18237


Ignore:
Timestamp:
03/14/22 14:15:52 (2 years ago)
Author:
pfleck
Message:

#3040 Added sub-vector, std dev and variance support for Tree to Tensor converter.

Location:
branches/3040_VectorBasedGP
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/VectorUnrollingNonlinearLeastSquaresConstantOptimizationEvaluator.cs

    r18234 r18237  
    182182      int i = 0;
    183183      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    184         ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    185         VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
    186         FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    187         if (constantTreeNode != null)
     184        if (node is ConstantTreeNode constantTreeNode) {
    188185          constantTreeNode.Value = constants[i++];
    189         else if (updateVariableWeights && variableTreeNodeBase != null)
     186        } else if (updateVariableWeights && node is VariableTreeNodeBase variableTreeNodeBase) {
    190187          variableTreeNodeBase.Weight = constants[i++];
    191         else if (factorVarTreeNode != null) {
     188        } else if (node is FactorVariableTreeNode factorVarTreeNode) {
    192189          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
    193190            factorVarTreeNode.Weights[j] = constants[i++];
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs

    r17554 r18237  
    120120      if (node.Symbol is Addition) {
    121121        var terms = node.Subtrees.Select(ConvertNode).ToList();
     122        if (terms.Count == 1) return terms[0];
    122123        return terms.Aggregate((a, b) => a + b);
    123124      }
     
    131132      if (node.Symbol is Multiplication) {
    132133        var terms = node.Subtrees.Select(ConvertNode).ToList();
     134        if (terms.Count == 1) return terms[0];
    133135        return terms.Aggregate((a, b) => a * b);
    134136      }
     
    152154
    153155      if (node.Symbol is Logarithm) {
    154         return math_ops.log(
     156        return tf.log(
    155157          ConvertNode(node.GetSubtree(0)));
    156158      }
    157159
    158160      if (node.Symbol is Exponential) {
    159         return math_ops.pow(
     161        return tf.pow(
    160162          (float)Math.E,
    161163          ConvertNode(node.GetSubtree(0)));
     
    168170
    169171      if (node.Symbol is SquareRoot) {
    170         return math_ops.sqrt(
     172        return tf.sqrt(
    171173          ConvertNode(node.GetSubtree(0)));
    172174      }
    173175
    174176      if (node.Symbol is Cube) {
    175         return math_ops.pow(
     177        return tf.pow(
    176178          ConvertNode(node.GetSubtree(0)), 3.0f);
    177179      }
    178180
    179181      if (node.Symbol is CubeRoot) {
    180         return math_ops.pow(
     182        return tf.pow(
    181183          ConvertNode(node.GetSubtree(0)), 1.0f / 3.0f);
    182184        // TODO
     
    207209      }
    208210
    209       //if (node.Symbol is StandardDeviation) {
    210       //  return tf.reduce_std(
    211       //    ConvertNode(node.GetSubtree(0)),
    212       //    axis: new [] { 1 }
    213       // );
    214       //}
     211      if (node.Symbol is StandardDeviation) {
     212        return reduce_std(
     213          ConvertNode(node.GetSubtree(0)),
     214          axis: new[] { 1 },
     215          keepdims: true
     216       );
     217      }
     218
     219      if (node.Symbol is Variance) {
     220        return reduce_var(
     221          ConvertNode(node.GetSubtree(0)),
     222          axis: new[] { 1 } ,
     223          keepdims: true
     224        );
     225      }
    215226
    216227      if (node.Symbol is Sum) {
     
    220231          keepdims: true);
    221232      }
     233
     234      if (node.Symbol is SubVector) {
     235        var tensor = ConvertNode(node.GetSubtree(0));
     236        int rows = tensor.shape[0], vectorLength = tensor.shape[1];
     237        var windowedNode = (IWindowedSymbolTreeNode)node;
     238        int startIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Offset, vectorLength);
     239        int endIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Length, vectorLength);
     240        var slices = SymbolicDataAnalysisExpressionTreeVectorInterpreter.GetVectorSlices(startIdx, endIdx, vectorLength);
     241
     242        var segments = new List<Tensor>();
     243        foreach (var (start, count) in slices) {
     244          segments.Add(tensor[new Slice(), new Slice(start, start + count)]);
     245        }
     246        return tf.concat(segments.ToArray(), axis: 1);
     247      }
     248
    222249
    223250      if (node.Symbol is StartSymbol) {
     
    237264        }
    238265
    239         return tf.reduce_sum(prediction, axis: new[] { 1 });
     266        return tf.reshape(prediction, shape: new[] { -1 });
    240267      }
    241268
    242269      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
     270    }
     271
     272    private static Tensor reduce_var(Tensor input_tensor,  int[] axis = null, bool keepdims = false) {
     273      var means = tf.reduce_mean(input_tensor, axis, true);
     274      var squared_deviation = tf.square(input_tensor - means);
     275      return tf.reduce_mean(squared_deviation, axis, keepdims);
     276    }
     277    private static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false) {
     278      return tf.sqrt(reduce_var(input_tensor, axis, keepdims));
    243279    }
    244280
     
    271307          !(n.Symbol is CubeRoot) &&
    272308          !(n.Symbol is Mean) &&
    273           //!(n.Symbol is StandardDeviation) &&
    274           !(n.Symbol is Sum)
     309          !(n.Symbol is StandardDeviation) &&
     310          !(n.Symbol is Variance) &&
     311          !(n.Symbol is Sum) &&
     312          !(n.Symbol is SubVector)
    275313        select n).Any();
    276314      return !containsUnknownSymbol;
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/OpCodes.cs

    r18060 r18237  
    339339      { typeof(NumberPeaksOfSize), OpCodes.NumberPeaksOfSize },
    340340      { typeof(LargeNumberOfPeaks), OpCodes.LargeNumberOfPeaks },
    341       { typeof(TimeReversalAsymmetryStatistic), OpCodes.TimeReversalAsymmetryStatistic },             
     341      { typeof(TimeReversalAsymmetryStatistic), OpCodes.TimeReversalAsymmetryStatistic },
    342342      #endregion
    343343    };
Note: See TracChangeset for help on using the changeset viewer.