wenet/utils/mask.py代码理解

原理:后续补充链接

代码位置: https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py

函数定义:主要包括如下几个函数

subsequent_chunk_mask:

参数:size,chunk_size, num_left_chunks,device

作用:创建chunk格式的mask。如输入size=4,chunk_size=2, num_left_chunks=-1,device=”cpu”,如下图

wenet/utils/mask.py代码理解

make_pad_mask:

作用:使mask补0的位置为1,其余位置为0

wenet/utils/mask.py代码理解

make_finished_scores:

参数:score, flag

作用:当一个sequence完成后,只允许保留一个alive branch。这个函数就是将一个branch的score置为0,而其他branch的score置为-inf

score: 一个实数数组,大小为(batch_size*beam_size,beam_size)

flag: bool型数组,大小为(batch_size,1)

在识别的解码阶段被调用,https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/asr_model.py

其主要作用应该使当预测完后就用eos进行mask

flag: bool型数组,大小为(batch_sizebeam_size,1)
return: (batch_size
beam_size)

Original: https://blog.csdn.net/shaoyou223/article/details/122271121
Author: 少游223
Title: wenet/utils/mask.py代码理解

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/512936/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球