NLP (50): Training a Multi-Label Text Classification Model with PyTorch

It has been a long time since my last update. The epidemic situation in Shanghai has been serious recently and work has kept me busy; only now do I realize what a luxury it was to have the time and the mood to write articles.

This article describes how to train a multi-label text classification model with PyTorch.
In multi-label text classification, a text may belong to several categories at once rather than a single one. It differs from multi-class classification as follows: a multi-class model also has multiple categories, but each text belongs to exactly one of them, whereas in multi-label classification a text may belong to several categories simultaneously.
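In label terms, the difference looks like this (the six-position vectors are illustrative only; the actual topics appear in the dataset section below):

# Multi-class: exactly one label per text (a one-hot vector).
one_hot = [0, 1, 0, 0, 0, 0]    # the text belongs to class 2 only

# Multi-label: possibly several labels per text (a multi-hot vector).
multi_hot = [0, 1, 1, 0, 0, 0]  # the text belongs to classes 2 and 3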

Dataset

The dataset used in this article is a collection of English papers from https://datahack.analyticsvidhya.com/contest/janatahack-independence-day-2020-ml-hackathon (downloading it may require a proxy from mainland China; it is also available in the project's GitHub repository given later). It is actually competition data, provided for contestants to try out models. The dataset is in English; the same approach applies to Chinese with minor adjustments.
The task is to predict a paper's topics from its title (TITLE) and abstract (ABSTRACT). There are 20,972 training samples and six topics: Computer Science, Physics, Mathematics, Statistics, Quantitative Biology, Quantitative Finance. A sample record:

TITLE: Many-Body Localization: Stability and Instability
ABSTRACT: Rare regions with weak disorder (Griffiths regions) have the potential to spoil localization. We describe a non-perturbative construction of local integrals of motion (LIOMs) for a weakly interacting spin chain in one dimension, under a physically reasonable assumption on the statistics of eigenvalues. We discuss ideas about the situation in higher dimensions, where one can no longer ensure that interactions involving the Griffiths regions are much smaller than the typical energy-level spacing for such regions. We argue that ergodicity is restored in dimension d > 1, although equilibration should be extremely slow, similar to the dynamics of glasses.

TOPICS: Physics, Mathematics
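As a sketch of how such records can be fed to BERT, the following hypothetical Dataset class (the class name, DataFrame columns, and "labels" format are my assumptions, not from the original post) concatenates TITLE and ABSTRACT and tokenizes the result:

import torch
from torch.utils.data import Dataset

class PaperDataset(Dataset):
    # Hypothetical: df has TITLE and ABSTRACT columns plus a "labels"
    # column holding six 0/1 topic indicators per row.
    def __init__(self, df, tokenizer, max_len):
        self.texts = (df["TITLE"] + ". " + df["ABSTRACT"]).tolist()
        self.labels = df["labels"].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(self.texts[idx], max_length=self.max_len,
                             padding="max_length", truncation=True,
                             return_tensors="pt")
        return {
            "ids": enc["input_ids"].squeeze(0),
            "mask": enc["attention_mask"].squeeze(0),
            "token_type_ids": enc["token_type_ids"].squeeze(0),
            "targets": torch.tensor(self.labels[idx], dtype=torch.float),
        }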

Model Structure

The multi-label text classification model presented here uses a pre-trained model (BERT) with a fairly simple downstream network: a middle-of-the-road but simple and practical design. The model structure is shown below:

(Figure: model structure diagram)
The model is implemented with PyTorch and the Hugging Face transformers library; the code is as follows:

import torch
import transformers

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        # Pre-trained BERT encoder
        self.l1 = transformers.BertModel.from_pretrained(MODEL_NAME_OR_PATH)
        self.l2 = torch.nn.Dropout(0.2)
        # Map the pooled [CLS] representation to six topic logits
        self.l3 = torch.nn.Linear(HIDDEN_LAYER_SIZE, 6)

    def forward(self, ids, mask, token_type_ids):
        # return_dict=False makes BertModel return the tuple
        # (sequence_output, pooled_output), as in transformers < 4.x
        _, output_1 = self.l1(ids, attention_mask=mask,
                              token_type_ids=token_type_ids,
                              return_dict=False)
        output_2 = self.l2(output_1)
        output = self.l3(output_2)  # raw logits, no sigmoid here
        return output

The loss function is torch.nn.BCEWithLogitsLoss, so there is no need to add a sigmoid activation after the output layer.
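As a minimal sketch of why the sigmoid can be omitted: BCEWithLogitsLoss applies the sigmoid internally (in a numerically stable way) and treats each of the six outputs as an independent binary decision. The numbers below are illustrative:

import torch

loss_fn = torch.nn.BCEWithLogitsLoss()

# Raw logits for a batch of two samples over six topics (no sigmoid applied).
logits = torch.tensor([[ 2.1, -0.5,  1.3, -2.0,  0.2, -1.1],
                       [-1.4,  0.8, -0.3,  2.2, -0.9,  0.4]])

# Multi-hot targets; this loss expects float targets.
targets = torch.tensor([[1., 0., 1., 0., 0., 0.],
                        [0., 1., 0., 1., 0., 0.]])

loss = loss_fn(logits, targets)   # scalar, averaged over all 12 decisions
print(loss.item())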
During model training, the training data are randomly split into a training set and a test set in an 8:2 ratio, and the model parameters are set as follows:


MAX_LEN = 128
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-05

MODEL_NAME_OR_PATH = './bert-base-uncased'
HIDDEN_LAYER_SIZE = 768
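The original post does not show the training loop itself; the following is a hedged sketch of what these settings imply. The Adam optimizer and the train_loader DataLoader variable are my assumptions:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTClass().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.BCEWithLogitsLoss()

for epoch in range(EPOCHS):
    model.train()
    for batch in train_loader:  # assumed DataLoader over the 80% training split
        ids = batch["ids"].to(device)
        mask = batch["mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        targets = batch["targets"].to(device)

        optimizer.zero_grad()
        logits = model(ids, mask, token_type_ids)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()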

Model Performance

I trained models with both bert-base-uncased and bert-large-uncased, ran predictions on the test data, and submitted the results on the competition site (a sketch of the prediction step follows the table):

Model              | max length | batch size | private score | rank
bert-base-uncased  | 128        | 32         | 0.8320        | 107
bert-large-uncased | 128        | 16         | 0.8355        | 79
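At prediction time the logits go through a sigmoid and are thresholded; 0.5 is a common default, though the post does not state the threshold it used. A sketch (the TOPICS order and the input tensor names are assumptions):

import torch

# Assumed topic order; must match the order of the label columns.
TOPICS = ["Computer Science", "Physics", "Mathematics",
          "Statistics", "Quantitative Biology", "Quantitative Finance"]

model.eval()
with torch.no_grad():
    logits = model(ids, mask, token_type_ids)  # tensors for one tokenized paper
    probs = torch.sigmoid(logits)              # independent per-topic probabilities
    preds = (probs > 0.5).int()                # multi-hot prediction

predicted = [t for t, p in zip(TOPICS, preds[0].tolist()) if p]
print(predicted)  # e.g. ['Physics', 'Mathematics']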

I have seen a rank-17 solution: it ensembled several fine-tuned pre-trained models, with the same downstream network as mine.

Summary

This project is open source; its GitHub repository is https://github.com/percent4/pytorch_english_mltc. Next I will try this model on a Chinese multi-label text classification dataset. Thanks for reading~

References

  1. https://jovian.ai/kyawkhaung/1-titles-only-for-medium
  2. https://datahack.analyticsvidhya.com/contest/janatahack-independence-day-2020-ml-hackathon
  3. Fine-tuned BERT Model for Multi-Label Tweets Classification: https://trec.nist.gov/pubs/trec28/papers/DICE_UPB.IS.pdf

Original: https://blog.csdn.net/jclian91/article/details/123563040
Author: 山阴少年
Title: NLP(五十)使用PyTorch训练多标签文本分类模型
