bert主要的实现是基于transformer的encoder部分,参数维度不同的地方是1)输入多了一项segment embedding,2)中间维度基本是768,以及多头注意力以及前向网络重复了12次。
在统计bert参数的时候,一共要考虑5部分。
1)第一部分:输入层包含三项
token embedding词表大小768position embmax_len(512768)segment emb两个取值0,1(2*768)
2)第二部分:多头注意力
12个头,其中每个头包括Q\K\V三组参数
768(原始维度)768/12(每个头的q\k\v的维度)3*12(头的个数)
然后concat起来所有输出,再变换一下 768*768+768
3)第三部分:Add and Norm
add不需要参数,norm有两个参数需要学习:shift和scale(2*768)
4)第四部分:前向网络
两层全连接网络(W,b):第一层是768*3072(4H)+3072
第二层是3072*768+768
5)第五部分:Add and Norm
同第三部分:2*768
总参数: 第一部分+12*(第二+第三+第四+第五部分)
Original: https://blog.csdn.net/baoyan2015/article/details/121206765
Author: samoyan
Title: bert参数统计
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/531908/
转载文章受原作者版权保护。转载请注明原作者出处!