Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/14/22 14:15:52 (3 years ago)
Author:
pfleck
Message:

#3040 Added sub-vector, std dev and variance support for Tree to Tensor converter.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs

    r17554 r18237  
    120120      if (node.Symbol is Addition) {
    121121        var terms = node.Subtrees.Select(ConvertNode).ToList();
     122        if (terms.Count == 1) return terms[0];
    122123        return terms.Aggregate((a, b) => a + b);
    123124      }
     
    131132      if (node.Symbol is Multiplication) {
    132133        var terms = node.Subtrees.Select(ConvertNode).ToList();
     134        if (terms.Count == 1) return terms[0];
    133135        return terms.Aggregate((a, b) => a * b);
    134136      }
     
    152154
    153155      if (node.Symbol is Logarithm) {
    154         return math_ops.log(
     156        return tf.log(
    155157          ConvertNode(node.GetSubtree(0)));
    156158      }
    157159
    158160      if (node.Symbol is Exponential) {
    159         return math_ops.pow(
     161        return tf.pow(
    160162          (float)Math.E,
    161163          ConvertNode(node.GetSubtree(0)));
     
    168170
    169171      if (node.Symbol is SquareRoot) {
    170         return math_ops.sqrt(
     172        return tf.sqrt(
    171173          ConvertNode(node.GetSubtree(0)));
    172174      }
    173175
    174176      if (node.Symbol is Cube) {
    175         return math_ops.pow(
     177        return tf.pow(
    176178          ConvertNode(node.GetSubtree(0)), 3.0f);
    177179      }
    178180
    179181      if (node.Symbol is CubeRoot) {
    180         return math_ops.pow(
     182        return tf.pow(
    181183          ConvertNode(node.GetSubtree(0)), 1.0f / 3.0f);
    182184        // TODO
     
    207209      }
    208210
    209       //if (node.Symbol is StandardDeviation) {
    210       //  return tf.reduce_std(
    211       //    ConvertNode(node.GetSubtree(0)),
    212       //    axis: new [] { 1 }
    213       // );
    214       //}
     211      if (node.Symbol is StandardDeviation) {
     212        return reduce_std(
     213          ConvertNode(node.GetSubtree(0)),
     214          axis: new[] { 1 },
     215          keepdims: true
     216       );
     217      }
     218
     219      if (node.Symbol is Variance) {
     220        return reduce_var(
     221          ConvertNode(node.GetSubtree(0)),
     222          axis: new[] { 1 } ,
     223          keepdims: true
     224        );
     225      }
    215226
    216227      if (node.Symbol is Sum) {
     
    220231          keepdims: true);
    221232      }
     233
     234      if (node.Symbol is SubVector) {
     235        var tensor = ConvertNode(node.GetSubtree(0));
     236        int rows = tensor.shape[0], vectorLength = tensor.shape[1];
     237        var windowedNode = (IWindowedSymbolTreeNode)node;
     238        int startIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Offset, vectorLength);
     239        int endIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Length, vectorLength);
     240        var slices = SymbolicDataAnalysisExpressionTreeVectorInterpreter.GetVectorSlices(startIdx, endIdx, vectorLength);
     241
     242        var segments = new List<Tensor>();
     243        foreach (var (start, count) in slices) {
     244          segments.Add(tensor[new Slice(), new Slice(start, start + count)]);
     245        }
     246        return tf.concat(segments.ToArray(), axis: 1);
     247      }
     248
    222249
    223250      if (node.Symbol is StartSymbol) {
     
    237264        }
    238265
    239         return tf.reduce_sum(prediction, axis: new[] { 1 });
     266        return tf.reshape(prediction, shape: new[] { -1 });
    240267      }
    241268
    242269      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
     270    }
     271
     272    private static Tensor reduce_var(Tensor input_tensor,  int[] axis = null, bool keepdims = false) {
     273      var means = tf.reduce_mean(input_tensor, axis, true);
     274      var squared_deviation = tf.square(input_tensor - means);
     275      return tf.reduce_mean(squared_deviation, axis, keepdims);
     276    }
     277    private static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false) {
     278      return tf.sqrt(reduce_var(input_tensor, axis, keepdims));
    243279    }
    244280
     
    271307          !(n.Symbol is CubeRoot) &&
    272308          !(n.Symbol is Mean) &&
    273           //!(n.Symbol is StandardDeviation) &&
    274           !(n.Symbol is Sum)
     309          !(n.Symbol is StandardDeviation) &&
     310          !(n.Symbol is Variance) &&
     311          !(n.Symbol is Sum) &&
     312          !(n.Symbol is SubVector)
    275313        select n).Any();
    276314      return !containsUnknownSymbol;
Note: See TracChangeset for help on using the changeset viewer.