【Bert + BiLSTM + CRF】Named entity recognition with minimal code, simple and easy to use

Named entity recognition (NER) implemented with Transformers, pytorch-crf and torch.nn. This article is only meant for readers who are just getting into NLP; veterans, feel free to skip it…

Enough talk, straight to the code…

Code


'''
@author : sito
@date : 2022-02-25
@description:
Trying to build a model (Bert+BiLSTM+CRF) to solve the NER task
with very little code, with the help of transformers, torch and pytorch-crf.
The next step is to enlarge the training dataset and test it on real data.
'''
import torch
import torch.nn as nn
from transformers import BertModel, AdamW, BertTokenizer  # note: recent transformers versions drop AdamW; torch.optim.AdamW is a drop-in replacement
from torchcrf import CRF  # pip install pytorch-crf

class Model(nn.Module):

    def __init__(self, tag_num, max_length):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        config = self.bert.config
        self.lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=config.hidden_size, hidden_size=config.hidden_size//2, batch_first=True)
        # batch_first=True so the CRF sees (batch, seq_len, tag_num), matching the LSTM output
        self.crf = CRF(tag_num, batch_first=True)
        self.fc = nn.Linear(config.hidden_size, tag_num)

    def forward(self, x, y):
        # BERT is frozen here; only the BiLSTM, the linear layer and the CRF are trained
        with torch.no_grad():
            bert_output = self.bert(input_ids=x.input_ids, attention_mask=x.attention_mask, token_type_ids=x.token_type_ids)[0]
        lstm_output, _ = self.lstm(bert_output)
        fc_output = self.fc(lstm_output)

        # the CRF forward pass returns the log-likelihood, so negate it to get a loss
        loss = -self.crf(fc_output, y)
        tag = self.crf.decode(fc_output)
        return loss, tag

if __name__ == '__main__':

    epochs = 50
    max_length = 30

    x = ["我 和 小 明 今 天 去 了 北 京".split(),"普 京 在  昨 天 进 攻 了 乌 克 拉 , 造 成 了 大 量 人 员 的 伤 亡".split()]
    y = ["O O B-PER I-PER O O O O B-LOC I-LOC".split(), "B-PER I-PER O O O O O O B-LOC I-LOC I-LOC O O O O O O O O O O O".split()]

    tag_to_ix = {"B-PER": 0, "I-PER": 1, "O": 2, "[CLS]": 3, "[SEP]": 4, "B-LOC":5, "I-LOC":6}

    # convert the tag strings to ids and pad with 'O' up to max_length
    labels = []
    for label in y:
        r = [tag_to_ix[t] for t in label]
        if len(r) < max_length:
            r += [tag_to_ix['O']] * (max_length - len(r))
        labels.append(r)

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

    # note: encode_plus adds [CLS] and [SEP], so the tokens end up shifted by one
    # position relative to the label list; this toy demo ignores that offset
    tokenizer_result = tokenizer.encode_plus(x[0], return_token_type_ids=True, return_attention_mask=True, return_tensors='pt',
                                             padding='max_length', max_length=max_length)

    model = Model(len(tag_to_ix),max_length)
    optimizer = AdamW(model.parameters(), lr=5e-4)
    model.train()
    for i in range(epochs):
        loss, _ = model(tokenizer_result, torch.tensor(labels[0]).unsqueeze(dim=0))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f'loss : {loss}')

    model.eval()
    with torch.no_grad():
        _, tag = model(tokenizer_result, torch.tensor(labels[0]).unsqueeze(dim=0))
        print(f' ori tag: {labels[0]} \n predict tag : {tag}')
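One caveat before looking at the output: because the tokenizer inserts [CLS] and [SEP], the labels are offset by one position, and the CRF above also scores the padded positions. A possible refinement, not part of the original code (the helper name align_labels is my own), is to align the labels with the special tokens and thread the attention mask through to the CRF:

# sketch only: align labels with the [CLS]/[SEP] tokens the tokenizer adds
def align_labels(tags, tag_to_ix, max_length):
    r = [tag_to_ix['[CLS]']] + [tag_to_ix[t] for t in tags] + [tag_to_ix['[SEP]']]
    r += [tag_to_ix['O']] * (max_length - len(r))   # pad the remainder with 'O'
    return r[:max_length]

# and inside Model.forward, the mask would be passed through like this:
#   mask = x.attention_mask.bool()
#   loss = -self.crf(fc_output, y, mask=mask)      # padding no longer contributes
#   tag = self.crf.decode(fc_output, mask=mask)    # decode only the real positions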

Running the script above prints:

loss : 58.82263946533203
loss : 38.57122039794922
loss : 23.51917839050293
loss : 17.443994522094727
loss : 19.788169860839844
loss : 20.11433219909668
loss : 18.492122650146484
loss : 15.683046340942383
loss : 13.58675765991211
loss : 13.220212936401367
loss : 13.955292701721191
loss : 14.226445198059082
loss : 13.392374992370605
loss : 11.734152793884277
loss : 10.527318954467773
loss : 10.421426773071289
loss : 9.423626899719238
loss : 8.633639335632324
loss : 8.24331283569336
loss : 7.25379753112793
loss : 6.560595512390137
loss : 5.79638147354126
loss : 5.452062606811523
loss : 4.555328369140625
loss : 4.393014430999756
loss : 3.9612410068511963
loss : 2.97975754737854
loss : 2.843627691268921
loss : 2.211019515991211
loss : 1.8970086574554443
loss : 2.2976162433624268
loss : 1.3155405521392822
loss : 2.394059658050537
loss : 2.8929264545440674
loss : 0.859898567199707
loss : 1.008394718170166
loss : 0.8743772506713867
loss : 0.567021369934082
loss : 0.6397604942321777
loss : 0.38392019271850586
loss : 0.36254167556762695
loss : 0.25980567932128906
loss : 0.2933225631713867
loss : 0.1776103973388672
loss : 0.16666841506958008
loss : 0.2049698829650879
loss : 0.15804100036621094
loss : 0.1316676139831543
loss : 0.09202098846435547
loss : 0.0849909782409668
 ori tag: [2, 2, 0, 1, 2, 2, 2, 2, 5, 6, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
 predict tag : [[2, 2, 0, 1, 2, 2, 2, 2, 5, 6, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]
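To read the prediction back as tag strings, the index map can be inverted; this little snippet is my own addition, not part of the original post:

ix_to_tag = {v: k for k, v in tag_to_ix.items()}
print([ix_to_tag[i] for i in tag[0]])
# e.g. ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', ...]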

The loss starts out fairly high, but it keeps dropping as training goes on. I haven't run this on a large dataset yet; I may later wrap things with DataLoader and Dataset utilities to support training on big datasets (a rough sketch follows below). This demo is for reference only. If you're interested, feel free to message or follow me!
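As a rough sketch of what that Dataset/DataLoader wrapping could look like (the class name NerDataset is mine; it assumes the same x, labels, tokenizer and max_length as above, and Model.forward would then index the batch dict as x['input_ids'] instead of attribute access):

import torch
from torch.utils.data import Dataset, DataLoader

class NerDataset(Dataset):
    # pairs pre-tokenized sentences with their padded tag-id lists
    def __init__(self, sentences, labels, tokenizer, max_length):
        self.sentences, self.labels = sentences, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        enc = self.tokenizer.encode_plus(self.sentences[idx],
                                         return_token_type_ids=True, return_attention_mask=True,
                                         return_tensors='pt', padding='max_length',
                                         max_length=self.max_length, truncation=True)
        enc = {k: v.squeeze(0) for k, v in enc.items()}  # drop the batch dim added by return_tensors='pt'
        return enc, torch.tensor(self.labels[idx])

loader = DataLoader(NerDataset(x, labels, tokenizer, max_length), batch_size=2, shuffle=True)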

The Chinese BERT pretrained files can be downloaded here:

Link: https://huggingface.co/bert-base-chinese/tree/main
You only need to download three files: config.json, pytorch_model.bin and vocab.txt.
After downloading, save them in a folder named bert-base-chinese.

Like this:

[screenshot from the original post: the downloaded files inside the bert-base-chinese folder]
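With the files in place, from_pretrained can also be pointed at the local folder instead of the hub name; a quick check, assuming the folder sits next to model.py:

from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
model = BertModel.from_pretrained('./bert-base-chinese')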
Then run python model.py, or just run the code above directly.
model.py is simply the code pasted above.

This is my first article so I'm a little nervous. Thanks everyone for the encouragement and support!!

Original: https://blog.csdn.net/m0_37576959/article/details/123135281
Author: Sito_zz
Title: 【Bert + BiLSTM + CRF】Named entity recognition with minimal code, simple and easy to use
