“BERT” 这个词相信大家已经不在陌生了, 发布至今,BERT 已成为 NLP 实验中无处不在的基线。这里稍微扯一下什么是BERT毕竟不是今天的重点,BERT在模型架构方面沿用了Transformer的Encoder端(不知道什么是transformer的小伙伴们可以去阅读论文:),它是一个预训练模型,模型训练时两个任务分别是预测句子中被掩盖的词以及判断输入的两个句子是不是上下句。在预训练好的BERT模型后面根据特定任务加上相应的网络,可以完成NLP的下游任务,比如文本分类、机器翻译等。说的简单点核心就是通过上下文去增强对目标词的表达。
今天主要是想和大家扒一扒这两个预训练任务的源码,预估你的收获是:1)熟系BERT预训练代码,如果条件允许的话可以自己进行预训练 ;2)最近大火的Prompt范式,可以使用BERT源码实现。
一、Mask Launage Model
随机掩盖掉一些单词,然后通过上下文预测该单词。BERT中有15%的子词(BERT是以 wordpiece token为最小单位)会被随机掩盖,这15%的token中有80%的概率会被mask, 10%的概率用随机其他词来替换 (使得模型具有一定纠错能力)还有10%的概率不做操作(和下游任务统一)。那么这一部分具体是怎么操作的呢,接下来带着大家看看源码是如何实现的。
1.2 mlm源码
<br>def create_masked_lm_predictions(tokens, masked_lm_prob,<br>                                 max_predictions_per_seq, vocab_words, rng):<br>  <span class="hljs-string">""</span><span class="hljs-string">"<br>  tokens:输入文本<br>  masked_lm_prob:掩码语言模型的掩码概率<br>  max_predictions_per_seq:每个序列最大预测数目<br>  vocab_words:每个列表的最大预测数目<br>  rng: 随机数生成器<br>  <br>  "</span><span class="hljs-string">""</span><br><br>  cand_indexes = []  <br>  <span class="hljs-keyword">for</span> (i, token) <span class="hljs-keyword">in</span> enumerate(tokens):<br>   <br>    <span class="hljs-keyword">if</span> token == <span class="hljs-string">"[CLS]"</span> or token == <span class="hljs-string">"[SEP]"</span>:<br>      <span class="hljs-built_in">continue</span><br>    <br>    <span class="hljs-keyword">if</span> (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and<br>        token.startswith(<span class="hljs-string">"##"</span>)):<br>      cand_indexes[-1].append(i)<br>    <span class="hljs-keyword">else</span>:<br>      cand_indexes.append([i])<br> <br>  rng.shuffle(cand_indexes)<br>  <br>  output_tokens = list(tokens)<br>  <br>  num_to_predict = min(max_predictions_per_seq,<br>                       max(1, int(round(len(tokens) * masked_lm_prob))))<br><br>  masked_lms = []<br>  covered_indexes = <span class="hljs-built_in">set</span>() <br>  <span class="hljs-keyword">for</span> index_set <span class="hljs-keyword">in</span> cand_indexes:<br>    <span class="hljs-keyword">if</span> len(masked_lms) >= num_to_predict:<br>      <span class="hljs-built_in">break</span><br>    <br>    <br>    <span class="hljs-keyword">if</span> len(masked_lms) + len(index_set) > num_to_predict:<br>      <span class="hljs-built_in">continue</span><br>    is_any_index_covered = False<br>    <span class="hljs-keyword">for</span> index <span class="hljs-keyword">in</span> index_set:<br>      <span class="hljs-keyword">if</span> index <span class="hljs-keyword">in</span> covered_indexes:<br>        is_any_index_covered = True<br>        <span class="hljs-built_in">break</span><br>    <span class="hljs-keyword">if</span> is_any_index_covered:<br>      <span class="hljs-built_in">continue</span><br>    <span class="hljs-keyword">for</span> index <span class="hljs-keyword">in</span> index_set:<br>      covered_indexes.add(index)<br><br>      masked_token = None<br>      <br>      <span class="hljs-keyword">if</span> rng.random() < 0.8:<br>        masked_token = <span class="hljs-string">"[MASK]"</span><br>      <span class="hljs-keyword">else</span>:<br>        <br>        <span class="hljs-keyword">if</span> rng.random() < 0.5:<br>          masked_token = tokens[index]<br>        <br>        <span class="hljs-keyword">else</span>:<br>          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]<br><br>      output_tokens[index] = masked_token <br>  <br>      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) <br>  assert len(masked_lms) <= num_to_predict<br>  masked_lms = sorted(masked_lms, key=lambda x: x.index)<br><br>  masked_lm_positions = [] <br>  masked_lm_labels = [] <br>  <span class="hljs-keyword">for</span> p <span class="hljs-keyword">in</span> masked_lms:<br>    masked_lm_positions.append(p.index)<br>    masked_lm_labels.append(p.label)<br><br>  <span class="hljs-built_in">return</span> (output_tokens, masked_lm_positions, masked_lm_labels)<br></= num_to_predict<br></ 0.5:<br></ 0.8:
以上就是创建训练数据的源码,接下来讲解下模型如何训练得到masked LM loss
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,<br>                         label_ids, label_weights):<br>  <span class="hljs-string">""</span><span class="hljs-string">"Get loss and log probs for the masked LM."</span><span class="hljs-string">""</span><br>  <br>  <br>  <br>  <br>  input_tensor = gather_indexes(input_tensor, positions)<br><br>  with tf.variable_scope(<span class="hljs-string">"cls/predictions"</span>):<br>    <br>    <br>    with tf.variable_scope(<span class="hljs-string">"transform"</span>):<br>      <br>      input_tensor = tf.layers.dense(<br>          input_tensor,<br>          units=bert_config.hidden_size,<br>          activation=modeling.get_activation(bert_config.hidden_act),<br>          kernel_initializer=modeling.create_initializer(<br>              bert_config.initializer_range))<br>      <br>      input_tensor = modeling.layer_norm(input_tensor)<br><br>    <br>    <br>    output_bias = tf.get_variable(<br>        <span class="hljs-string">"output_bias"</span>,<br>        shape=[bert_config.vocab_size],<br>        initializer=tf.zeros_initializer())<br>    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)<br>    logits = tf.nn.bias_add(logits, output_bias)<br><br>    log_probs = tf.nn.log_softmax(logits, axis=-1)<br><br>    label_ids = tf.reshape(label_ids, [-1])<br>    label_weights = tf.reshape(label_weights, [-1])<br><br>    one_hot_labels = tf.one_hot(<br>        label_ids, depth=bert_config.vocab_size, dtype=tf.float32)<br><br>    <br>    <br>    <br>    <br>    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])<br>    numerator = tf.reduce_sum(label_weights * per_example_loss)<br>    denominator = tf.reduce_sum(label_weights) + 1e-5<br>    loss = numerator / denominator<br><br>  <span class="hljs-built_in">return</span> (loss, per_example_loss, log_probs)
其实看下来和fine-tune差不多,如果想要实现Prompt无非是把fine-tune时create_model 函数替换为get_masked_lm_output,输入输出得改变下。
使用WordPiece的时候一个单词可能会被拆分成两部分,比如 loving 会被拆分成 lov ##ing 如果mask的时候可能只mask两者之一,那么如果只mask一部分的话很容易被模型预测到,比如”我很喜欢吃苹[MASK]”,模型很容易根据”苹”预测出果,那么我们希望mask整个单词,其实新版bert已经支持英文的整词mask了,中文整词mask需要先进行分词。
二、Next Sentence prediction
该任务其实就是分类任务,输入[CLS]a[SEP]b[SEP],预测b是否为a的下一句,即二分类问题。
原文中50%的概率两个句子来自于同一个文档中的上下文(正样本),50%的概率来自不同文档的句子(负样本)
def get_next_sentence_output(bert_config, input_tensor, labels):<br>  <span class="hljs-string">""</span><span class="hljs-string">"Get loss and log probs for the next sentence prediction."</span><span class="hljs-string">""</span><br><br>  <br>  <br>  with tf.variable_scope(<span class="hljs-string">"cls/seq_relationship"</span>):<br>    output_weights = tf.get_variable(<br>        <span class="hljs-string">"output_weights"</span>,<br>        shape=[2, bert_config.hidden_size],<br>        initializer=modeling.create_initializer(bert_config.initializer_range))<br>    output_bias = tf.get_variable(<br>        <span class="hljs-string">"output_bias"</span>, shape=[2], initializer=tf.zeros_initializer())<br>  <br>    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)<br>    logits = tf.nn.bias_add(logits, output_bias)<br>    log_probs = tf.nn.log_softmax(logits, axis=-1)<br>    labels = tf.reshape(labels, [-1])<br>    one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)<br>    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)<br>    loss = tf.reduce_mean(per_example_loss)<br>    <span class="hljs-built_in">return</span> (loss, per_example_loss, log_probs)
源码这里也是相当的简单呀,不是就是拿cls位的向量,经过一次线下变换,输入softmax得到一个概率(0-1),判断是否是上下文。
三、总结
- 效果好,横扫了11项NLP任务。bert之后基本全面拥抱transformer。微调下游任务的时候,即使数据集非常小(比如小于5000个标注样本),模型性能也有不错的提升。
- [MASK]标记在实际预测中不会出现,训练时用过多[MASK]影响模型表现
- 每个batch只有15%的token被预测,所以BERT收敛得比left-to-right模型要慢(它们会预测每个token)
- BERT的预训练任务MLM使得能够借助上下文对序列进行编码,但同时也使得其预训练过程与中的数据与微调的数据不匹配,难以适应生成式任务
- BERT没有考虑预测[MASK]之间的相关性,是对语言模型联合概率的有偏估计
- 由于最大输入长度的限制,适合句子和段落级别的任务,不适用于文档级别的任务(如长文本分类)
Original: https://blog.csdn.net/justorderman/article/details/122144438
Author: CReep~
Title: 浅谈BERT预训练源码
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/542823/
转载文章受原作者版权保护。转载请注明原作者出处!