Debugging TensorFlow's seq2seq attention_decoder method


Author: TechOnly | Category: TensorFlow | Tags: tensorflow, python, github

Let's write a test case for attention_decoder and step through it with a debugger to see how the attention mechanism is implemented.

import tensorflow as tf
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib

with tf.Session() as sess:
    batch_size = 16
    step1 = 20        # encoder time steps
    step2 = 10        # decoder time steps
    input_size = 50   # encoder input dimension
    output_size = 40  # decoder input dimension
    gru_hidden = 30   # GRU state size
    cell_fn = lambda: rnn_cell.GRUCell(gru_hidden)
    cell = cell_fn()

    # Encoder: step1 identical dummy inputs run through a static GRU.
    inp = [tf.constant(0.8, shape=[batch_size, input_size])] * step1
    enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=tf.float32)

    # Stack the per-step encoder outputs into [batch, step1, gru_hidden]
    # so they can serve as the attention states.
    attn_states = tf.concat([
        tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
    ], 1)

    # Decoder: step2 dummy inputs; attention_decoder attends over attn_states.
    dec_inp = [tf.constant(0.3, shape=[batch_size, output_size])] * step2
    dec, mem = seq2seq_lib.attention_decoder(
        dec_inp, enc_state, attn_states, cell_fn(), output_size=7)

    sess.run([tf.global_variables_initializer()])
    res = sess.run(dec)       # list of step2 decoder output tensors
    print(len(res))
    print(res[0].shape)
    res = sess.run([mem])     # final decoder state
    print(len(res))
    print(res[0].shape)
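Running this (under TensorFlow 1.x, where tf.contrib still exists) should print 10 and (16, 7) for the decoder outputs, i.e. a list of step2 = 10 tensors of shape (batch_size, output_size) = (16, 7), then 1 and (16, 30) for the final state: sess.run([mem]) returns a one-element list holding the (batch_size, gru_hidden) state of the 30-unit GRU.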

Adapted from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
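When you step into attention_decoder, the part worth watching is the attention computation itself: at every decoder step the current cell state is projected, added to a projection of each encoder output, passed through tanh, scored against a learned vector, and the softmaxed scores weight the encoder outputs into a context vector (Bahdanau-style additive attention; the library does the key projection with a 1x1 convolution). Below is a minimal NumPy sketch of one such step; the names W_q, W_k, and v are illustrative and only loosely mirror the library's internal variables.

import numpy as np

def additive_attention(query, enc_outputs, W_q, W_k, v):
    # query:       (batch, state)       current decoder cell state
    # enc_outputs: (batch, steps, size) stacked encoder outputs (attn_states)
    # W_q, W_k:    projections of query and keys into the attention space
    # v:           (attn_vec,)          learned scoring vector
    y = query @ W_q                                # (batch, attn_vec)
    keys = enc_outputs @ W_k                       # (batch, steps, attn_vec)
    s = np.tanh(keys + y[:, None, :]) @ v          # (batch, steps) raw scores
    a = np.exp(s - s.max(axis=1, keepdims=True))
    a = a / a.sum(axis=1, keepdims=True)           # softmax over encoder steps
    d = (a[:, :, None] * enc_outputs).sum(axis=1)  # (batch, size) context vector
    return d, a

# Shapes matching the test case above: batch 16, 20 encoder steps, 30 GRU units.
rng = np.random.RandomState(0)
q = rng.randn(16, 30)
h = rng.randn(16, 20, 30)
d, a = additive_attention(q, h, rng.randn(30, 30), rng.randn(30, 30), rng.randn(30))
print(d.shape, a.shape)   # (16, 30) (16, 20)

In the debugger, the attention weights (a here) are the tensor to inspect: one distribution of shape (batch_size, step1) per decoder step, i.e. how much each of the 20 encoder positions contributes to that step's context vector.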


Original: https://blog.51cto.com/guotong1988/5485758

