简单seq2seq代码 使用tensorflow的LSTMCell构造循环decoder

好多预测模型的论文都是用seq2seq实现的,具体是LSTM_encoder将输入序列编码为一个tensor(又叫output、H或Y),同时保留序列状态state(又叫w或c);
LSTM_decoder继承encoder的状态,将上层的output作为输入,得到的每个输出到embeding中找对应的词向量,然后再次调用LSTM_decoder刚才的输出作为这次的输入。一直循环,直到输出EOS为止。

tensorflow中并没有循环网络(可能有,我不知道)。因此决定用LSTMCell循环实现。
代码如下:

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.encodeLSTM = layers.LSTM(1, return_state=True)
        self.decodeLSTM = layers.LSTMCell(1)

    def call(self, inputs):
        x, memory_state, carry_state = self.encodeLSTM(inputs)
        pred = tf.constant([], shape=(16, 0), dtype=tf.float32)
        for i in range(8):
            x,[memory_state, carry_state] = self.decodeLSTM(x,[memory_state, carry_state])
            pred = tf.concat((pred, x),axis=1 )
        return pred

model = MyModel()
model.build((None, 10, 1))
model.summary()

简单seq2seq代码 使用tensorflow的LSTMCell构造循环decoder
LSTMCell一般与RNN组合使用,例
cell=[layer.LSTMCell(10) , layers.LSTMCell(5)]
layers.RNN(cell)

单独使用时请注意几点:

[En]

Note a few points when using it alone:

①LSTMCell帮助文档中没有关于状态的参数,需要从**kwargs传入。
②LSTMCell的状态不能保留,因此它每一次运算都会返回当前状态,以便下一次继续使用。
③LSTMCell由于不处理时间序列time_seq,它的输入格式为(batch_size,units)和输出格式相同。对比LSTM输入(batch_size,time_seq,units)输出(batch_size,units)

感谢官方文档和github教我的用法
https://tensorflow.google.cn/api_docs/python/tf/keras/layers/LSTMCell
https://github.com/search?q=tensorflow+layers.LSTMCell&type=Code

Original: https://blog.csdn.net/Loutre_star/article/details/124015331
Author: Loutre_star
Title: 简单seq2seq代码 使用tensorflow的LSTMCell构造循环decoder

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/508807/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球