bert参数统计

bert主要的实现是基于transformer的encoder部分,参数维度不同的地方是1)输入多了一项segment embedding,2)中间维度基本是768,以及多头注意力以及前向网络重复了12次。

在统计bert参数的时候,一共要考虑5部分。

1)第一部分:输入层包含三项

token embedding词表大小768position embmax_len(512768)segment emb两个取值0,1(2*768)

2)第二部分:多头注意力

12个头,其中每个头包括Q\K\V三组参数

768(原始维度)768/12(每个头的q\k\v的维度)3*12(头的个数)

然后concat起来所有输出,再变换一下 768*768+768

3)第三部分:Add and Norm

add不需要参数,norm有两个参数需要学习:shift和scale(2*768)

4)第四部分:前向网络

两层全连接网络(W,b):第一层是768*3072(4H)+3072

第二层是3072*768+768

5)第五部分:Add and Norm

同第三部分:2*768

总参数: 第一部分+12*(第二+第三+第四+第五部分)

Original: https://blog.csdn.net/baoyan2015/article/details/121206765
Author: samoyan
Title: bert参数统计

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/531908/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球