原理:后续补充链接
代码位置: https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
函数定义:主要包括如下几个函数
subsequent_chunk_mask:
参数:size,chunk_size, num_left_chunks,device
作用:创建chunk格式的mask。如输入size=4,chunk_size=2, num_left_chunks=-1,device=”cpu”,如下图
make_pad_mask:
作用:使mask补0的位置为1,其余位置为0
make_finished_scores:
参数:score, flag
作用:当一个sequence完成后,只允许保留一个alive branch。这个函数就是将一个branch的score置为0,而其他branch的score置为-inf
score: 一个实数数组,大小为(batch_size*beam_size,beam_size)
flag: bool型数组,大小为(batch_size,1)
在识别的解码阶段被调用,https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/asr_model.py
其主要作用应该使当预测完后就用eos进行mask
flag: bool型数组,大小为(batch_sizebeam_size,1)
return: (batch_sizebeam_size)
Original: https://blog.csdn.net/shaoyou223/article/details/122271121
Author: 少游223
Title: wenet/utils/mask.py代码理解
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/512936/
转载文章受原作者版权保护。转载请注明原作者出处!