Debugging TensorFlow seq2seq's attention_decoder method
Here is a test case for `attention_decoder` so we can step through the attention mechanism's implementation in a debugger:
import tensorflow as tf
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib

with tf.Session() as sess:
    batch_size = 16
    step1 = 20       # encoder time steps
    step2 = 10       # decoder time steps
    input_size = 50
    output_size = 40
    gru_hidden = 30
    cell_fn = lambda: rnn_cell.GRUCell(gru_hidden)
    cell = cell_fn()

    # Encoder: step1 identical dummy inputs of shape [batch_size, input_size]
    inp = [tf.constant(0.8, shape=[batch_size, input_size])] * step1
    enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=tf.float32)

    # Stack encoder outputs into [batch_size, step1, gru_hidden] for attention
    attn_states = tf.concat([
        tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
    ], 1)

    # Decoder: step2 dummy inputs of shape [batch_size, output_size]
    dec_inp = [tf.constant(0.3, shape=[batch_size, output_size])] * step2
    dec, mem = seq2seq_lib.attention_decoder(
        dec_inp, enc_state, attn_states, cell_fn(), output_size=7)

    sess.run([tf.global_variables_initializer()])
    res = sess.run(dec)
    print(len(res))        # step2 output tensors, one per decoder step
    print(res[0].shape)    # [batch_size, 7], the projected output size
    res = sess.run([mem])
    print(len(res))
    print(res[0].shape)    # final decoder state: [batch_size, gru_hidden]
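For reference while stepping through the debugger: inside `attention_decoder`, the score for each encoder step is computed in Bahdanau (additive) style, then softmaxed into weights that form a context vector. Below is a minimal NumPy sketch of that computation; the parameter names (`W_k`, `W_q`, `v`) and the random initialization are illustrative assumptions, not the library's actual variable names.

```python
import numpy as np

def attention_sketch(query, attn_states):
    """Additive (Bahdanau-style) attention, as used inside attention_decoder.

    query:       (batch, hidden)        current decoder state
    attn_states: (batch, steps, hidden) stacked encoder outputs
    """
    batch, steps, hidden = attn_states.shape
    rng = np.random.default_rng(0)
    # Learned parameters (random here, purely for illustration):
    W_k = rng.standard_normal((hidden, hidden)) * 0.1  # projects encoder states
    W_q = rng.standard_normal((hidden, hidden)) * 0.1  # projects the query
    v = rng.standard_normal(hidden) * 0.1

    keys = attn_states @ W_k                 # (batch, steps, hidden)
    q = (query @ W_q)[:, None, :]            # (batch, 1, hidden), broadcast over steps
    scores = np.tanh(keys + q) @ v           # (batch, steps)

    # Softmax over the encoder steps
    e = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights = e / e.sum(axis=1, keepdims=True)             # (batch, steps)

    # Context: attention-weighted sum of encoder states
    context = (weights[:, :, None] * attn_states).sum(axis=1)  # (batch, hidden)
    return weights, context

# Shapes matching the test case above: batch=16, step1=20, gru_hidden=30
weights, context = attention_sketch(
    np.ones((16, 30)), np.full((16, 20, 30), 0.5))
print(weights.shape, context.shape)  # (16, 20) (16, 30)
```

The context vector is what `attention_decoder` concatenates with the cell output before the final projection to `output_size`.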
Original: https://blog.51cto.com/guotong1988/5485758
Author: TechOnly