TensorFlow Keras LSTM 输出解释

参考文章:
What does Tensorflow LSTM return?
Tensorflow RNN LSTM output explanation

>>> inputs = tf.random.normal([32, 10, 8])
>>> lstm = tf.keras.layers.LSTM(4)
>>> output = lstm(inputs)
>>> print(output.shape)
(32, 4)
>>> lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
>>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
>>> print(whole_seq_output.shape)
(32, 10, 4)
>>> print(final_memory_state.shape)
(32, 4)
>>> print(final_carry_state.shape)
(32, 4)

TensorFlow Keras LSTM 输出解释
其中图里上方的输出h t h_t h t ​可以视为o t o_t o t ​

在Keras中如果 return_state=True则LSTM单元有三个输出,分别为

  • 一个输出状态(output state)o t o_t o t ​
  • 一个隐藏状态(hidden state)h t h_t h t ​
  • 一个单元状态(cell state)c t c_t c t ​

keras 文档中给出的写法如下:

whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)

在文档中,他们不使用隐藏和单元状态这些术语。他们使用memory state表示短期记忆,即上面提到的隐藏状态。用carry state 通过所有LSTM单元,即上面提到的单元状态。

以下是前向传播的部分源代码

[En]

The following is part of the source code for forward propagation

def step(cell_inputs, cell_states):
    """Step function that will be used by Keras RNN backend."""
    h_tm1 = cell_states[0]
    c_tm1 = cell_states[2]

    z = backend.dot(cell_inputs, kernel)
    z += backend.dot(h_tm1, recurrent_kernel)
    z = backend.bias_add(z, bias)

    z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)

    i = nn.sigmoid(z0)
    f = nn.sigmoid(z1)
    c = f * c_tm1 + i * nn.tanh(z2)
    o = nn.sigmoid(z3)

    h = o * nn.tanh(c)
    return h, [h, c]

从源码中可以看出,第一个和第二个输出是output/hidden state,第三个输出是cell state。并且从注释中可以看出,将hidden state 命名为 memory state ;将cell state 命名为 carry state。

return_sequences=True时,whole_seq_output是整个序列的输出,维度为(batch_size,seq_length,units)。
return_sequences=False时,whole_seq_output是最后一个单元的输出,维度为(batch_size,units),此时与第二个输出相同。

Original: https://blog.csdn.net/qq_50710984/article/details/124075216
Author: 深海里的鱼(・ω<)★
Title: TensorFlow Keras LSTM 输出解释

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/497488/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球