tensorflow获取动态shape

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

如果需要根据上一层的动态shape计算当前层的shape,该如何做呢

x=tf.placeholder(tf.float32, shape=[None, 227,227,3] )

但在运行的时候想知道 None到底是多少,这时候,只能通过 tf.shape(x)[0]这种方式来获得.输出的shape,例如:

out_shape=[tf.shape(x)[0],tf.shape(x)[1],tf.shape(x)[2]tf.shape(x)[3]]

但是 tf.shape(x)[0]的值是None,这样是无法直接指定shape的,程序会崩

由于返回的时tensor,所以我们可以使用其他tensorflow节点操作进行处理,如下面的转置卷积中,使用stack来合并各个shape的分量,

def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
with tf.variable_scope( 'conv_transpose' ):
shape  = [kernel, kernel, output_filters, input_filters]
weight  = tf.Variable(tf.truncated_normal(shape, stddev = 0.1 ), name = 'weight' )
batch_size  = tf.shape(x)[ 0 ]
height  = tf.shape(x)[ 1 ]  * strides
width  = tf.shape(x)[ 2 ]  * strides
output_shape  = tf.stack([batch_size, height, width, output_filters])

return tf.nn.conv2d_transpose(x, weight, output_shape, strides = [ 1 , strides, strides,  1 ], name = 'conv_transpose' )

这样就可以将数值和tensor合并转为Tensor,程序就不会报错啦

Original: https://blog.csdn.net/ab0902cd/article/details/123872696
Author: ab0902cd
Title: tensorflow获取动态shape

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/508999/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球