[BERT + BiLSTM + CRF] Named Entity Recognition, Part 2: Wrapping a Dataset and DataLoader for Batch Training

Last time I introduced a simple [BERT + BiLSTM + CRF] named-entity-recognition application that only ran on a single example. This post continues from there: the data is now wrapped in a Dataset and trained in batches. If you want the annotated dataset used in this project, send me a private message. The whole pipeline runs end to end without errors!

Project structure:

bert-base-chinese: the BERT model, vocab.txt and config.json
data: the annotated data
output: training logs and saved model files
dataSet.py: data preprocessing code
main.py: training and evaluation code

The code comes first; the explanation follows.

dataSet.py

from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer
import torch
import warnings
import os
import json
import sys
import re
warnings.filterwarnings('ignore')

def collect_data(path, original_value, result_value, a, b, c, d, e, f):
    # Read one annotated JSON file and append its fields to the running lists.
    # a..f collect the six 'classify' sub-fields; a missing field becomes " ".
    with open(path, 'r', encoding='utf-8') as file:
        s = json.load(file)

        try:
            for i, k in enumerate(s):
                if k == 'originalValue':
                    original_value.append(s['originalValue'])
                if k == 'resultValue' and s['resultValue'] != '':
                    result_value.append(s['resultValue'])
                if k == 'classify':
                    classify_data = s[k]
                    a.append(classify_data.get('组织学分型', ' '))
                    b.append(classify_data.get('癌结节', ' '))
                    c.append(classify_data.get('两侧切缘是否有癌浸润', ' '))
                    d.append(classify_data.get('pCRM', ' '))
                    e.append(classify_data.get('脉管', ' '))
                    f.append(classify_data.get('神经', ' '))
        except Exception:
            print(f'Errors occurred at path : {path}, key : "{k}", with reasons : {sys.exc_info()}')
    return original_value, result_value, a, b, c, d, e, f
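
# For reference, each file under data/ is a JSON object shaped roughly like
# the sketch below. The keys are the ones collect_data reads; the values are
# invented placeholders:
#
# {
#   "originalValue": "...",
#   "resultValue": "...text annotated with [entity]/tag markers...",
#   "classify": {"组织学分型": "...", "癌结节": "...", "两侧切缘是否有癌浸润": "...",
#                "pCRM": "...", "脉管": "...", "神经": "..."}
# }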

def fun4Word(data):
    # Convert a list of annotated sentences into 'text //label' lines.
    # The per-sentence logic is identical to label_process (defined below),
    # so it is reused here instead of being duplicated.
    output = ''
    for i in data:
        word, label = label_process(i)
        if word != '':
            output += word + ' //' + label + '\n'
    return output

# Annotation markers as they appear in the data, mapped to the short names
# used in the B/M/E/W label scheme.
TAG2LABEL = {
    'aj_lcjl': 'ajl', 'aj_hzjl': 'ajh',
    'lbj_z': 'lbjz', 'lbj_y': 'lbjy', 'lbj_fz': 'lbjfz',
    'mlh1': 'mlh1', 'msh2': 'msh2', 'msh6': 'msh6',
    'pms2': 'pms2', 'ki67': 'ki67', 'p53': 'p53',
}

ENTITY_RE = re.compile(r'(\[[^\]]+\]/(?:' + '|'.join(TAG2LABEL) + r'))')

def bmew(name, length):
    # B/M/E tags for a multi-character entity, W for a single character.
    if length > 1:
        return f'B_{name} ' + (length - 2) * f'M_{name} ' + f'E_{name} '
    return f'W_{name} '

def label_process(data):
    # Split an annotated sentence into (plain text, per-character label string).
    word = ''
    label = ''
    for f in ENTITY_RE.split(data):
        m = re.fullmatch(r'\[([^\]]+)\]/(\w+)', f)
        if m is None or m.group(2) not in TAG2LABEL:
            # plain text between entities: every character is labelled O
            word += f
            label += len(f) * 'O '
            continue
        word_index, name = m.group(1), TAG2LABEL[m.group(2)]
        if name in ('ajl', 'ajh'):
            # distance measurements: a trailing 'cm' or 'c' unit is labelled O
            if 'cm' in word_index:
                entity_len, tail = len(word_index) - 2, 'O ' * 2
            elif 'c' in word_index:
                # (the original hand-expanded branch emitted one label too many here)
                entity_len, tail = len(word_index) - 1, 'O '
            else:
                entity_len, tail = len(word_index), ''
            label_index = bmew(name, entity_len) + tail
        else:
            label_index = bmew(name, len(word_index))
        word += word_index
        label += label_index
    return word, label
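
# Quick sanity check (the fragment is invented, but it uses the same
# [text]/tag markup as the real annotations):
#   word, label = label_process('肿瘤距上切缘[3cm]/aj_lcjl')
#   word  -> '肿瘤距上切缘3cm'
#   label -> 'O O O O O O W_ajl O O '   (the trailing 'cm' is labelled O)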

def my_collate(data):
    # Batch the tokenizer outputs and label lists into tensors. Each dataset
    # item is (BertTokenizer output, label id list); torch.tensor cannot stack
    # the tokenizer outputs directly, so only input_ids are stacked here for
    # the quick shape check in __main__ below.
    input_ids, labels = [], []
    for input, label in data:
        input_ids.append(input.input_ids.squeeze(0))
        labels.append(label)
    return torch.stack(input_ids), torch.tensor(labels)

class MyDataSet(Dataset):

    def __init__(self,max_length = 512):

        labels = ['B_lbjy', 'M_lbjy', 'E_lbjy', 'W_lbjy', 'B_lbjz', 'M_lbjz', 'E_lbjz', 'W_lbjz', 'B_lbjfz', 'M_lbjfz',
                  'E_lbjfz', 'W_lbjfz',
                  'B_ajl', 'M_ajl', 'E_ajl', 'W_ajl', 'B_ajh', 'M_ajh', 'E_ajh', 'W_ajh', 'B_mlh1', 'M_mlh1', 'E_mlh1',
                  'W_mlh1', 'B_msh2', 'M_msh2', 'E_msh2', 'W_msh2',
                  'B_msh6', 'M_msh6', 'E_msh6', 'W_msh6', 'B_pms2', 'M_pms2', 'E_pms2', 'W_pms2', 'B_ki67', 'M_ki67',
                  'E_ki67', 'W_ki67', 'B_p53', 'M_p53', 'E_p53', 'W_p53', 'O']
        self.tag_num = len(labels)
        original_value, result_value, a, b, c, d, e, f = [], [], [], [], [], [], [], []
        count = 0
        root = 'data'
        tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

        root_path = os.listdir(root)
        for path in root_path:
            father_path = os.path.join(root, path)
            child_paths = os.listdir(os.path.join(root, path))
            for child_path in child_paths:
                count += 1
                original_value, result_value, a, b, c, d, e, f = collect_data(os.path.join(father_path, child_path),
                                                                              original_value, result_value, a, b, c, d,
                                                                              e, f)
                # collect_data skips empty resultValue fields, so a shortfall
                # here means some file was missing one
                if len(result_value) != count: print(f'result_value null: count :{count}, path:{os.path.join(father_path, child_path)}')
        print(f'Data Collection Info : original_value : {len(original_value)} result_value : {len(result_value)} '
              f'a : {len(a)} b : {len(b)} c :{len(c)} d : {len(d)} e : {len(e)} f : {len(f)} final count : {count}')

        tokenized_data = []
        encoded_labels = []
        label2id = {k: v for v, k in enumerate(labels)}  # build the tag->id map once
        for i, sentence in enumerate(result_value):
            word, label = label_process(sentence)

            if len(word) > max_length:
                word = word[:max_length]

            # truncation=True guards against sequences that still exceed
            # max_length after wordpiece tokenization
            s = tokenizer.encode_plus(word, return_token_type_ids=True, return_attention_mask=True,
                                      return_tensors='pt', padding='max_length',
                                      truncation=True, max_length=max_length)
            tokenized_data.append(s)

            # pad / truncate the character-level label sequence to max_length
            label = label.strip().split(' ')
            if len(label) > max_length:
                label = label[:max_length]
            if len(label) < max_length:
                label += ['O'] * (max_length - len(label))

            encoded_label = [label2id[k] for k in label]
            encoded_labels.append(encoded_label)
            if s.input_ids.shape[1] > max_length or s.attention_mask.shape[1] > max_length or s.token_type_ids.shape[1] > max_length:
                print(f'len data:{s.input_ids.shape} {s.attention_mask.shape} {s.token_type_ids.shape} len label:{len(encoded_label)}')
        self.data = tokenized_data
        self.label = encoded_labels

    def __getitem__(self, index):
        return self.data[index],self.label[index]

    def __len__(self):
        return len(self.data)

if __name__ == '__main__':
    # quick smoke test: iterate the whole dataset once and print batch shapes
    dataset = MyDataSet()
    batch_count = 0
    data_loader = DataLoader(dataset=dataset, shuffle=False, batch_size=10, collate_fn=my_collate)
    for i, data in enumerate(data_loader):
        inputs, labels = data
        print(f'inputs_size:{inputs.shape}\t labels_size:{labels.shape}')
        batch_count += 1
    print(f'batch_count:{batch_count}')

main.py


'''
@author : sito
@date : 2022-02-25
@description:
Building a model (Bert+BiLSTM+CRF) to solve the problem of NER,
with a small amount of code, on top of transformers, torch and pytorch-crf.
'''
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from dataSet import MyDataSet
from torch.utils.data import DataLoader
from transformers import BertModel
from torchcrf import CRF
import time
import warnings
import logging
import sys
warnings.filterwarnings('ignore')

logger = logging.getLogger('training log')
logger.setLevel(logging.INFO)

rf_handler = logging.StreamHandler(sys.stderr)
rf_handler.setLevel(logging.INFO)
rf_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(message)s"))

f_handler = logging.FileHandler('output/training.log')
f_handler.setLevel(logging.INFO)
f_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(filename)s[:%(lineno)d] - %(message)s"))
logger.addHandler(rf_handler)
logger.addHandler(f_handler)

def my_collate(data):
    # Unpack each (tokenizer output, label) pair and batch the three BERT
    # inputs and the labels into tensors on the training device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_ids, attention_mask, token_type_ids, labels = [], [], [], []
    for input, label in data:
        input_ids.append(input.input_ids.squeeze().tolist())
        attention_mask.append(input.attention_mask.squeeze().tolist())
        token_type_ids.append(input.token_type_ids.squeeze().tolist())
        labels.append(label)
    return {'input_ids': torch.tensor(input_ids).to(device), 'attention_mask': torch.tensor(attention_mask).to(device),
            'token_type_ids': torch.tensor(token_type_ids).to(device)}, torch.tensor(labels).to(device)

class Model(nn.Module):

    def __init__(self, tag_num):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        config = self.bert.config
        # BiLSTM on top of BERT; hidden_size//2 per direction keeps the
        # concatenated output at hidden_size.
        self.lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=config.hidden_size,
                            hidden_size=config.hidden_size // 2, batch_first=True)
        # batch_first=True so the CRF reads the (batch, seq, tags) emissions
        # the LSTM and fc produce; without it the CRF would treat the batch
        # dimension as time.
        self.crf = CRF(tag_num, batch_first=True)
        self.fc = nn.Linear(config.hidden_size, tag_num)

    def forward(self, x, y):
        # BERT is frozen here; only the BiLSTM, fc and CRF layers are trained.
        with torch.no_grad():
            bert_output = self.bert(input_ids=x['input_ids'], attention_mask=x['attention_mask'],
                                    token_type_ids=x['token_type_ids'])[0]
        lstm_output, _ = self.lstm(bert_output)
        fc_output = self.fc(lstm_output)
        # CRF.forward returns the log-likelihood, so negate it to get a loss
        loss = -self.crf(fc_output, y)
        tag = self.crf.decode(fc_output)
        return loss, tag
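
# For orientation, the tensor shapes through forward(), assuming a batch of 2,
# max_length 512 and the 45-tag label set (bert-base-chinese hidden_size is 768):
#   bert_output : (2, 512, 768)   last hidden state of bert-base-chinese
#   lstm_output : (2, 512, 768)   2 directions * (768 // 2) features
#   fc_output   : (2, 512, 45)    per-token emission scores for the CRF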

if __name__ == '__main__':

    epochs = 50
    max_length = 512
    batch_size = 64
    lr = 0.0001

    dataset = MyDataSet(max_length)
    tag_num = dataset.tag_num
    data_loader = DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size, collate_fn=my_collate)

    logger.info(f'>>> Training Start!')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model(tag_num).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    for e in range(epochs):

        epoch_end_loss = 0
        model.train()
        for i, data in enumerate(data_loader):
            optimizer.zero_grad()
            inputs, labels = data
            # the model already returns -log_likelihood, so no abs() is needed
            loss, _ = model(inputs, labels)
            loss.backward()
            optimizer.step()
            epoch_end_loss = loss
            if i % 10 == 0:
                logger.info(f'>>> epoch {e} <<< step {i} : loss : {loss}')
        # step the cosine schedule once per epoch, so T_max=50 spans the 50 epochs
        scheduler.step()
        logger.info(f'{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} epoch {e} training loss : {epoch_end_loss}')

        if e % 10 == 0 and e != 0:
            model.eval()
            logger.info(f'{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} epoch {e} Start Evaluation!')
            step_end_accuracy = []
            with torch.no_grad():
                for i, data in enumerate(data_loader):
                    inputs, labels = data
                    _, tag = model(inputs, labels)
                    # with batch_first=True, decode already returns one
                    # max_length sequence per batch element
                    tag = np.array(tag)

                    for pre_y, real_y in zip(tag, labels):
                        assert pre_y.shape[0] == real_y.shape[0] == max_length, \
                            f'length not match pre_y.shape[0]:{pre_y.shape[0]} real_y.shape[0]:{real_y.shape[0]}  max_length:{max_length}'
                        # token-level accuracy over all positions, padding included
                        total = pre_y.shape[0]
                        real_y_numpy = real_y.cpu().numpy()
                        count = np.count_nonzero(pre_y == real_y_numpy)
                        step_end_accuracy.append(count / total)
            epoch_end_accuracy = np.mean(step_end_accuracy)
            logger.info(f'{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} epoch {e} evaluation accuracy : {epoch_end_accuracy}')

            torch.save(model.state_dict(), f'output/model_p_{epoch_end_accuracy}.pt')

Since the dataset is annotated, it is not convenient to post it publicly. Interested readers can send me a private message.

The training procedure is straightforward: 50 epochs in total, with an evaluation every 10 epochs and a model checkpoint saved each time. The optimizer is Adam, and the learning rate follows a cosine-annealing schedule: a larger learning rate helps at the start, while later on, once the model has absorbed more information, a smaller learning rate actually improves robustness and final performance. See the ALBERT paper for this and many other training tricks.
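
As a minimal standalone sketch (not part of main.py), this is how CosineAnnealingLR decays the learning rate when the scheduler is stepped once per epoch:

import torch
import torch.optim as optim

# stand-in parameter; in main.py this is model.parameters()
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.Adam(params, lr=0.0001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

for epoch in range(50):
    optimizer.step()                              # the actual training step goes here
    scheduler.step()
    if epoch % 10 == 0:
        print(epoch, scheduler.get_last_lr())     # decays from 1e-4 toward 0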

Everyone is welcome to learn and exchange ideas together. I am interested in both computer vision and NLP, and I will keep posting useful articles from time to time, so feel free to follow me.

Finally, the training output.

(screenshot of the training log)
As you can see, by the 10th epoch the accuracy already reaches 0.96051. BERT really is impressive! (Note that this is token-level accuracy over all positions, padding included, so it reads high.)
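
For a stricter number, one option (a sketch, not in the original code) is to mask out the padded positions inside the evaluation loop, reusing the batch's attention_mask:

# inside the evaluation loop, after tag = np.array(tag)
mask = inputs['attention_mask'].cpu().numpy().astype(bool)        # (batch, 512)
correct = np.count_nonzero((tag == labels.cpu().numpy()) & mask)
masked_accuracy = correct / mask.sum()                            # padding excluded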

Original: https://blog.csdn.net/m0_37576959/article/details/123233758
Author: Sito_zz
Title: [BERT + BiLSTM + CRF] Named Entity Recognition, Part 2: Wrapping a Dataset and DataLoader for Batch Training
