CMeKG Code Walkthrough (A Project-Oriented Introduction to Knowledge Graphs from Scratch) (Part 5)

On to a new Python file. Keep it up!

medical_ner.py

The medical_ner class and its methods:

from_input():

from_txt():

split_entity_input():

predict_sentence():

predict_file():

import codecs
import torch
from torch.autograd import Variable
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import json
from utils import load_vocab
from ner_constant import *

from model_ner import BERT_LSTM_CRF
import os

As the naming suggests, this file performs medical named entity recognition (NER).

The medical_ner class:

class medical_ner(object):
    def __init__(self):
        self.NEWPATH = '/Users/yangyf/workplace/model/medical_ner/model.pkl'
        self.vocab = load_vocab('/Users/yangyf/workplace/model/medical_ner/vocab.txt')
        self.vocab_reverse = {v: k for k, v in self.vocab.items()}

        self.model = BERT_LSTM_CRF('/Users/yangyf/workplace/model/medical_ner', tagset_size, 768, 200, 2,
                              dropout_ratio=0.5, dropout1=0.5, use_cuda=use_cuda)

        if use_cuda:
            self.model.to(device)

As with the medical text segmentation covered earlier, the constructor first loads the vocabulary for named entity recognition, then builds the BERT_LSTM_CRF model (its trained weights are loaded later, inside the predict methods), and finally checks whether CUDA is ready; if it is, the model is moved to the GPU.
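
We have not opened utils.py yet, but from the way self.vocab is used (characters map to integer ids, and vocab_reverse inverts that mapping) a reasonable guess is that load_vocab reads vocab.txt one token per line and returns a token-to-index dict. A minimal sketch of that assumed behavior (not the project's actual code):

def load_vocab(vocab_file):
    # assumed behavior: one token per line; a token's line number becomes its id
    with open(vocab_file, 'r', encoding='utf-8') as f:
        return {token.strip(): index for index, token in enumerate(f)}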

    def from_input(self, input_str):
        raw_text = []
        textid = []
        textmask = []
        textlength = []
        text = ['[CLS]'] + [x for x in input_str] + ['[SEP]']
        raw_text.append(text)
        cur_len = len(text)
        # the original line (commented out below) would raise a KeyError on characters
        # missing from the vocab; the active line skips them instead, though that can
        # leave textid shorter than cur_len implies
        # raw_textid = [self.vocab[x] for x in text] + [0] * (max_length - cur_len)
        raw_textid = [self.vocab[x] for x in text if self.vocab.__contains__(x)] + [0] * (max_length - cur_len)
        textid.append(raw_textid)
        raw_textmask = [1] * cur_len + [0] * (max_length - cur_len)
        textmask.append(raw_textmask)
        textlength.append([cur_len])
        textid = torch.LongTensor(textid)
        textmask = torch.LongTensor(textmask)
        textlength = torch.LongTensor(textlength)
        return raw_text, textid, textmask, textlength

This mirrors the single-sentence input handling in the medical text segmentation module; the preprocessing is essentially the same.
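
To make the shapes concrete, here is a hypothetical walk-through (max_length is defined in ner_constant and not shown here; the model paths in __init__ must exist for the constructor to run):

ner = medical_ner()
raw_text, textid, textmask, textlength = ner.from_input("患者发热")
# raw_text   == [['[CLS]', '患', '者', '发', '热', '[SEP]']]
# textid     -> shape (1, max_length): token ids, zero-padded after the 6 real tokens
# textmask   -> shape (1, max_length): six 1s followed by 0s
# textlength == tensor([[6]])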

    def from_txt(self, input_path):
        raw_text = []
        textid = []
        textmask = []
        textlength = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                if len(line.strip())==0:
                    continue
                if len(line) > 448:
                    line = line[:448]
                temptext = ['[CLS]'] + [x for x in line[:-1]] + ['[SEP]']  # line[:-1] drops the trailing newline
                cur_len = len(temptext)
                raw_text.append(temptext)

                tempid = [self.vocab[x] for x in temptext[:cur_len]] + [0] * (max_length - cur_len)  # unlike from_input, no out-of-vocab guard here
                textid.append(tempid)
                textmask.append([1] * cur_len + [0] * (max_length - cur_len))
                textlength.append([cur_len])

        textid = torch.LongTensor(textid)
        textmask = torch.LongTensor(textmask)
        textlength = torch.LongTensor(textlength)
        return raw_text, textid, textmask, textlength

This again matches the multi-line input handler in the medical text segmentation article, except that the text handled here can be longer: lines are truncated at 448 characters. (I suspect a typo somewhere, since predict_sentence below prints an error message that says 148; 1 and 4 sit close together on the keyboard.)
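
For reference, from_txt expects a plain UTF-8 file with one sentence per line; a hypothetical input file might look like the two lines below (blank lines are skipped, and anything past 448 characters is cut off):

患者发热三天,伴有咳嗽。
糖尿病病史五年,长期服用二甲双胍。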

    def split_entity_input(self,label_seq):
        entity_mark = dict()
        entity_pointer = None
        for index, label in enumerate(label_seq):
            #print(f"before: {label_seq}")
            if label.split('-')[-1]=='B':
                category = label.split('-')[0]
                entity_pointer = (index, category)
                entity_mark.setdefault(entity_pointer, [label])
            elif label.split('-')[-1]=='M':
                if entity_pointer is None: continue
                if entity_pointer[1] != label.split('-')[0]: continue
                entity_mark[entity_pointer].append(label)
            elif label.split('-')[-1]=='E':
                if entity_pointer is None: continue
                if entity_pointer[1] != label.split('-')[0]: continue
                entity_mark[entity_pointer].append(label)
            else:
                entity_pointer = None
            # print(entity_mark)
        return entity_mark

The input label sequence is first enumerated, and each label is checked against 'B', 'M', and 'E' (B stands for beginning, marking the start of a named entity; M stands for middle, marking a token inside an entity; E stands for end, marking the final token of an entity). If a label ends in B, the function records the category of this opening tag in an entity pointer (start index, category) and uses that pointer as the key under which the labels are collected in a dict. If a label ends in M, the function first checks whether an entity pointer already exists: if not, the loop simply continues; if one does, it checks whether the pointer's category matches this label's category, appending the label to entity_mark under the pointer when they match and skipping it otherwise. A label ending in E is handled exactly like M, and finally the entity_mark dict is returned. From this function we can also infer how labels are stored in label_seq: the part after the hyphen indicates the token's position within an entity, while the part before it stores the entity's category.
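
To make the bookkeeping concrete, here is the same logic as a standalone function (the identical M and E branches are merged so it can be read and run without loading the model), applied to a hypothetical label sequence:

def split_entity_input(label_seq):
    entity_mark = dict()
    entity_pointer = None
    for index, label in enumerate(label_seq):
        position = label.split('-')[-1]
        if position == 'B':
            # open a new entity, keyed by (start index, category)
            entity_pointer = (index, label.split('-')[0])
            entity_mark.setdefault(entity_pointer, [label])
        elif position in ('M', 'E'):
            # extend the open entity only if its category still matches
            if entity_pointer is None or entity_pointer[1] != label.split('-')[0]:
                continue
            entity_mark[entity_pointer].append(label)
        else:
            entity_pointer = None  # e.g. an 'O' tag closes any open entity
    return entity_mark

labels = ['O', 'd-B', 'd-M', 'd-E', 'O', 's-B', 's-E']
print(split_entity_input(labels))
# {(1, 'd'): ['d-B', 'd-M', 'd-E'], (5, 's'): ['s-B', 's-E']}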

    def predict_sentence(self, sentence):
        # category codes: d=disease, b=body part, s=symptom, p=medical procedure,
        # e=medical device, y=drug, k=department, m=microorganism, i=lab test item
        tag_dic = {"d": "疾病", "b": "身体", "s": "症状", "p": "医疗程序", "e": "医疗设备", "y": "药物", "k": "科室",
                   "m": "微生物类", "i": "医学检验项目"}
        if sentence == '':
            print("输入为空!请重新输入")  # "input is empty, please try again"
            return
        if len(sentence) > 448:
            # the message below says 148, but the code actually truncates at 448
            print("输入句子过长,请输入小于148的长度字符!")  # "sentence too long"
            sentence = sentence[:448]
        raw_text, test_ids, test_masks, test_lengths = self.from_input(sentence)
        test_dataset = TensorDataset(test_ids, test_masks, test_lengths)
        test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)
        self.model.load_state_dict(torch.load(self.NEWPATH, map_location=device))
        self.model.eval()

        for i, dev_batch in enumerate(test_loader):
            sentence, masks, lengths = dev_batch
            batch_raw_text = raw_text[i]
            sentence, masks, lengths = Variable(sentence), Variable(masks), Variable(lengths)
            if use_cuda:
                sentence = sentence.to(device)
                masks = masks.to(device)

            predict_tags = self.model(sentence, masks)
            predict_tags.tolist()  # no-op: the returned list is never assigned
            predict_tags = [i2l_dic[t.item()] for t in predict_tags[0]]
            predict_tags = predict_tags[:len(batch_raw_text)]
            pred = predict_tags[1:-1]  # drop the [CLS]/[SEP] positions
            raw_text = batch_raw_text[1:-1]  # drop [CLS]/[SEP] from the text as well
            entity_mark = self.split_entity_input(pred)
            entity_list = {}
            if entity_mark is not None:
                for item, ent in entity_mark.items():
                    # print(item, ent)
                    entity = ''
                    index, tag = item[0], item[1]
                    len_entity = len(ent)

                    for i in range(index, index + len_entity):
                        entity = entity + raw_text[i]
                    entity_list[tag_dic[tag]] = entity
            # print(entity_list)
        return entity_list

The first half of this function combines the predict_sentence() and recover_to_text() functions from the previous article, so it works the same way. The predicted pred list is then passed to the split_entity_input() function above for splitting and grouping, and if the resulting dict is non-empty its contents are unpacked. This function also confirms what the previous one returns: the entity's position in the text, its type, and (after reassembly) the entity string itself. Note that entity_list is keyed by the category name, so if a sentence contains several entities of the same type, later ones overwrite earlier ones.
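
A hypothetical call (the hard-coded paths in __init__ must point to your own copies of the model files):

ner = medical_ner()
print(ner.predict_sentence("患者出现发热症状"))
# e.g. {'症状': '发热'} -- one entry per entity category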

    def predict_file(self, input_file, output_file):
        tag_dic = {"d": "疾病", "b": "身体", "s": "症状", "p": "医疗程序", "e": "医疗设备", "y": "药物", "k": "科室",
                   "m": "微生物类", "i": "医学检验项目"}
        raw_text, test_ids, test_masks, test_lengths = self.from_txt(input_file)
        test_dataset = TensorDataset(test_ids, test_masks, test_lengths)
        test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)
        self.model.load_state_dict(torch.load(self.NEWPATH, map_location=device))
        self.model.eval()
        op_file = codecs.open(output_file, 'w', 'utf-8')
        for i, dev_batch in enumerate(test_loader):
            sentence, masks, lengths = dev_batch
            batch_raw_text = raw_text[i]
            sentence, masks, lengths = Variable(sentence), Variable(masks), Variable(lengths)
            if use_cuda:
                sentence = sentence.to(device)
                masks = masks.to(device)

            predict_tags = self.model(sentence, masks)
            predict_tags.tolist()
            predict_tags = [i2l_dic[t.item()] for t in predict_tags[0]]
            predict_tags = predict_tags[:len(batch_raw_text)]
            pred = predict_tags[1:-1]
            raw_text = batch_raw_text[1:-1]

            entity_mark = self.split_entity_input(pred)
            entity_list = {}
            if entity_mark is not None:
                for item, ent in entity_mark.items():
                    entity = ''
                    index, tag = item[0], item[1]
                    len_entity = len(ent)
                    for i in range(index, index + len_entity):
                        entity = entity + raw_text[i]
                    entity_list[tag_dic[tag]] = entity
            op_file.write("".join(raw_text))
            op_file.write("\n")
            op_file.write(json.dumps(entity_list, ensure_ascii=False))
            op_file.write("\n")

        op_file.close()
        print('处理完成!')
        print("结果保存至 {}".format(output_file))

This function handles multi-line input. It is similar to the multi-line processing in the previous article, and the approach matches the single-sentence handling above; the end of the function adds some file-writing operations for convenience.
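
Usage follows the same pattern. For a hypothetical input.txt with one sentence per line, each sentence in output.txt is followed by its entity dict serialized as JSON:

ner = medical_ner()
ner.predict_file("input.txt", "output.txt")
# output.txt then contains, for every input sentence:
#   <the sentence itself>
#   {"疾病": "...", "症状": "..."}   (ensure_ascii=False keeps the Chinese readable)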

With the medical text segmentation as groundwork, the medical entity recognition here is noticeably easier to follow. The crucial part, though, is the model behind both tasks, which we have not read yet; that is the key task of this project. If you have made it this far, I hope you will leave a free like. It helps me a lot and keeps me motivated.

Original: https://blog.csdn.net/chen_nnn/article/details/122873902
Author: chen_nnn
Title: CMeKG Code Walkthrough (A Project-Oriented Introduction to Knowledge Graphs from Scratch) (Part 5)
