Information Extraction (5): Which Nested Named Entity Recognition Approach Works Best? A Simple Comparison Experiment

Entity matrix construction frameworks
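
All four heads below share the same idea: the encoder output for a sentence of length n is turned into an n x n score matrix per entity type, where cell (i, j) scores the span that starts at token i and ends at token j; each head finally calls the repo's add_mask_tril helper so that padding positions and invalid spans are masked out. Decoding then amounts to reading out the cells whose score passes a threshold. The helper below is a minimal sketch of that step (the function name and the zero threshold are my own illustrative choices, not code from the repository):

import torch

def decode_entities(logits, threshold=0.0):
    """Minimal span decoder for an entity score matrix.

    logits: (batch, num_types, seq_len, seq_len), assumed to be already masked
            so that padding and invalid spans carry very negative scores.
    Returns, per batch item, a list of (type_id, start, end) triples.
    Illustrative sketch only, not the decoding code from the repo.
    """
    results = []
    for sample in logits:                        # sample: (num_types, seq_len, seq_len)
        spans = []
        for t, s, e in torch.nonzero(sample > threshold):
            spans.append((int(t), int(s), int(e)))
        results.append(spans)
    return results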

GlobalPointer

import math
import torch
from torch import nn
from torch.nn import Module, Parameter

# SinusoidalPositionEmbedding, add_mask_tril and relative_position_encoding are
# helper utilities from the author's repository (linked at the end of this post).

class GlobalPointer(Module):
    """Global pointer module.
    Scores every (start, end) span of the sequence as a single unit.
    """
    def __init__(self, heads, head_size, hidden_size, RoPE=True):
        super(GlobalPointer, self).__init__()
        self.heads = heads                  # one head per entity type
        self.head_size = head_size
        self.RoPE = RoPE
        self.dense = nn.Linear(hidden_size, self.head_size * self.heads * 2)

    def forward(self, inputs, mask=None):
        inputs = self.dense(inputs)
        inputs = torch.split(inputs, self.head_size * 2, dim=-1)
        inputs = torch.stack(inputs, dim=-2)          # (batch, seq_len, heads, 2 * head_size)
        qw, kw = inputs[..., :self.head_size], inputs[..., self.head_size:]

        if self.RoPE:
            # rotary position embedding applied to the query and key halves
            pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
            cos_pos = pos[..., None, 1::2].repeat(1, 1, 1, 2)
            sin_pos = pos[..., None, ::2].repeat(1, 1, 1, 2)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 4)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 4)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos

        # bilinear score for every (start, end) pair and every head
        logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw)

        logits = add_mask_tril(logits, mask)
        return logits / self.head_size ** 0.5
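
A minimal usage sketch, assuming the GlobalPointer class above and the repo helpers are importable; the batch size, sequence length, hidden size and number of entity types below are placeholders rather than the experiment's actual settings.

import torch

heads, head_size, hidden_size = 9, 64, 768      # placeholders (e.g. 9 entity types, BERT-base width)
model = GlobalPointer(heads, head_size, hidden_size, RoPE=True)

hidden = torch.randn(2, 50, hidden_size)        # encoder output: (batch, seq_len, hidden_size)
mask = torch.ones(2, 50, dtype=torch.long)      # 1 for real tokens, 0 for padding
logits = model(hidden, mask)                    # (batch, heads, seq_len, seq_len) span scores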

TPLinker

class MutiHeadSelection(Module):

    def __init__(self, hidden_size, c_size, abPosition=False, rePosition=False, maxlen=None, max_relative=None):
        super(MutiHeadSelection, self).__init__()
        self.hidden_size = hidden_size
        self.c_size = c_size
        self.abPosition = abPosition    # add absolute (sinusoidal) position embeddings to the inputs
        self.rePosition = rePosition    # add relative position encodings to the pair features
        self.Wh = nn.Linear(hidden_size * 2, self.hidden_size)
        self.Wo = nn.Linear(self.hidden_size, self.c_size)
        if self.rePosition:
            self.relative_positions_encoding = relative_position_encoding(max_length=maxlen,
                                                                          depth=2 * hidden_size,
                                                                          max_relative_position=max_relative)

    def forward(self, inputs, mask=None):
        input_length = inputs.shape[1]
        if self.abPosition:
            inputs = SinusoidalPositionEmbedding(self.hidden_size, 'add')(inputs)
        # tile the sequence against itself to build a representation for every (i, j) pair
        x1 = torch.unsqueeze(inputs, 1)
        x2 = torch.unsqueeze(inputs, 2)
        x1 = x1.repeat(1, input_length, 1, 1)
        x2 = x2.repeat(1, 1, input_length, 1)
        concat_x = torch.cat([x2, x1], dim=-1)           # (batch, seq_len, seq_len, 2 * hidden_size)

        if self.rePosition:
            relations_keys = self.relative_positions_encoding[:input_length, :input_length, :].to(inputs.device)
            concat_x += relations_keys
        hij = torch.tanh(self.Wh(concat_x))
        logits = self.Wo(hij)
        logits = logits.permute(0, 3, 1, 2)              # (batch, c_size, seq_len, seq_len)
        logits = add_mask_tril(logits, mask)
        return logits
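
Where GlobalPointer scores a span with a bilinear q.k product, this head (in the spirit of TPLinker / multi-head selection) builds an explicit representation for every (i, j) pair by tiling the sequence against itself, concatenating the start and end vectors, and passing the pair through a small MLP. A usage sketch under the same placeholder settings as above; note that the (batch, seq_len, seq_len, 2 * hidden) pair tensor is memory hungry, which is presumably why the batch sizes in the results table below are smaller than GlobalPointer's.

import torch

model = MutiHeadSelection(hidden_size=768, c_size=9,      # placeholder sizes
                          abPosition=True,                 # add absolute sinusoidal positions
                          rePosition=False)
hidden = torch.randn(2, 50, 768)
mask = torch.ones(2, 50, dtype=torch.long)
logits = model(hidden, mask)                               # (batch, c_size, seq_len, seq_len)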

Tencent Muti-head

class TxMutihead(Module):

    def __init__(self, hidden_size, c_size, abPosition=False, rePosition=False, maxlen=None, max_relative=None):
        super(TxMutihead, self).__init__()
        self.hidden_size = hidden_size
        self.c_size = c_size
        self.abPosition = abPosition
        self.rePosition = rePosition
        self.Wh = nn.Linear(hidden_size * 4, self.hidden_size)   # pair features are 4 * hidden wide
        self.Wo = nn.Linear(self.hidden_size, self.c_size)
        if self.rePosition:
            self.relative_positions_encoding = relative_position_encoding(max_length=maxlen,
                                                                          depth=4 * hidden_size,
                                                                          max_relative_position=max_relative)

    def forward(self, inputs, mask=None):
        input_length = inputs.shape[1]
        if self.abPosition:
            inputs = SinusoidalPositionEmbedding(self.hidden_size, 'add')(inputs)
        x1 = torch.unsqueeze(inputs, 1)
        x2 = torch.unsqueeze(inputs, 2)
        x1 = x1.repeat(1, input_length, 1, 1)
        x2 = x2.repeat(1, 1, input_length, 1)
        # pair features: start vector, end vector, their difference and element-wise product
        concat_x = torch.cat([x2, x1, x2 - x1, x2.mul(x1)], dim=-1)
        if self.rePosition:
            relations_keys = self.relative_positions_encoding[:input_length, :input_length, :].to(inputs.device)
            concat_x += relations_keys
        hij = torch.tanh(self.Wh(concat_x))
        logits = self.Wo(hij)
        logits = logits.permute(0, 3, 1, 2)              # (batch, c_size, seq_len, seq_len)
        logits = add_mask_tril(logits, mask)
        return logits
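
The Tencent multi-head variant keeps the same pairwise layout but enriches the pair features: instead of just [x_i; x_j] it concatenates the two token vectors with their difference and element-wise product, so the Wh projection sees explicit interaction features (hence the 4 * hidden_size input width). A toy illustration of that feature construction (the tensors here are random, purely for shape checking):

import torch

x = torch.randn(1, 4, 8)                  # toy encoder output: (batch, seq_len, hidden)
x1 = x.unsqueeze(1).repeat(1, 4, 1, 1)    # x1[b, i, j] = x[b, j]  (end-token vector)
x2 = x.unsqueeze(2).repeat(1, 1, 4, 1)    # x2[b, i, j] = x[b, i]  (start-token vector)

pair = torch.cat([x2, x1, x2 - x1, x2 * x1], dim=-1)
print(pair.shape)                         # torch.Size([1, 4, 4, 32]) -> 4 * hidden per pair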

Deep Biaffine

class Biaffine(Module):

    def __init__(self, in_size, out_size, Position=False):
        super(Biaffine, self).__init__()
        self.out_size = out_size
        # bilinear term weight: (in_size, out_size, in_size)
        self.weight1 = Parameter(torch.Tensor(in_size, out_size, in_size))
        # linear term weight over [x_i; x_j; 1]: (2 * in_size + 1, out_size)
        self.weight2 = Parameter(torch.Tensor(2 * in_size + 1, out_size))
        self.Position = Position
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))

    def forward(self, inputs, mask=None):
        input_length = inputs.shape[1]
        hidden_size = inputs.shape[-1]
        if self.Position:
            inputs = SinusoidalPositionEmbedding(hidden_size, 'add')(inputs)
        # build [x_i; x_j; 1] for every (i, j) pair; the trailing 1 acts as a bias feature
        x1 = torch.unsqueeze(inputs, 1)
        x2 = torch.unsqueeze(inputs, 2)
        x1 = x1.repeat(1, input_length, 1, 1)
        x2 = x2.repeat(1, 1, input_length, 1)
        concat_x = torch.cat([x2, x1], dim=-1)
        concat_x = torch.cat([concat_x, torch.ones_like(concat_x[..., :1])], dim=-1)

        # bilinear term x_i^T W1 x_j plus linear term W2^T [x_i; x_j; 1]
        logits_1 = torch.einsum('bxi,ioj,byj->bxyo', inputs, self.weight1, inputs)
        logits_2 = torch.einsum('bijy,yo->bijo', concat_x, self.weight2)
        logits = logits_1 + logits_2
        logits = logits.permute(0, 3, 1, 2)              # (batch, out_size, seq_len, seq_len)
        logits = add_mask_tril(logits, mask)
        return logits
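
Deep Biaffine combines two terms for each pair: a bilinear score x_i^T W1 x_j (weight1) and a linear score over the concatenated pair plus a bias feature, W2^T [x_i; x_j; 1] (weight2). A usage sketch with the same placeholder sizes as before:

import torch

model = Biaffine(in_size=768, out_size=9, Position=True)   # placeholder sizes
hidden = torch.randn(2, 50, 768)
mask = torch.ones(2, 50, dtype=torch.long)
logits = model(hidden, mask)                                # (batch, out_size, seq_len, seq_len)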

The code is open source. Each entity-matrix construction method is written as a class, so it is easy to reproduce the experiments or call the modules directly: https://github.com/zhengyanzhao1997/NLP-model/tree/main/model/model/Torch_model/ExtractionEntities
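
For completeness, here is one way these heads might be wired to a pretrained encoder for the CMeEE runs reported below. The encoder name, head choice, class name SpanNER and sizes are illustrative assumptions on my part, not the exact training setup; the repository above contains the author's full code.

import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer   # assumption: a HuggingFace BERT-style encoder

class SpanNER(nn.Module):
    def __init__(self, encoder_name="bert-base-chinese", num_types=9):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden = self.encoder.config.hidden_size
        # any of the four heads above could be plugged in here
        self.head = GlobalPointer(num_types, 64, hidden, RoPE=True)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.head(hidden, attention_mask)

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
batch = tokenizer(["患者出现持续性头痛。"], return_tensors="pt")
model = SpanNER()
logits = model(batch["input_ids"], batch["attention_mask"])  # (1, num_types, seq_len, seq_len)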

Method            | Position | Batch_size | learning_rate | CMeEE / F1% | CMeEE / F1%
GlobalPointer     | RoPE     | 16         | 2e-5          | 73.23       |
TPLinker          | \        | 8          | 1e-5          | 80.57       | 62.69
TPLinker          | Pos      | 8          | 1e-5          | 83.21       | 63.10
TPLinker          |          | 8          | 1e-5          | 76.63       |
Tencent Muti-head | \        | 4          | 1e-5          | 83.50       | 63.74
Tencent Muti-head | Pos      | 4          | 1e-5          | 76.32       | 64.18
Tencent Muti-head |          | 4          | 1e-5          | 77.37       |
                  |          |            | 2e-5          | 68.81       |
Deep Biaffine     | \        | 8          | 1e-5          | 78.27       | 62.85
Deep Biaffine     | Pos      | 8          | 1e-5          | 77.52       | 62.66

Original: https://blog.csdn.net/weixin_45839693/article/details/116425297
Author: 是算法不是法术
Title: Information Extraction (5): Which Nested Named Entity Recognition Approach Works Best? A Simple Comparison Experiment

