0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

手把手教你,用BERT完成一个Kaggle竞赛

电子工程师 来源:lq 2019-02-20 10:59 次阅读

过去的一年,深度神经网络的应用开启了自然语言处理的新时代。预训练模型在研究领域的应用已经令许多NLP项目的最新成果产生了巨大的飞跃,例如文本分类,自然语言推理和问答

ELMo,ULMFiT 和OpenAI Transformer是其中几个关键的里程碑。所有这些算法都允许我们在大型数据库(例如所有维基百科文章)上预先训练无监督语言模型,然后在下游任务上对这些预先训练的模型进行微调。

这一年里,在这一领域中最激动人心的事件恐怕要数BERT的发布,这是一种基于多语言转换器的模型,它已经在各种NLP项目中取得了令人瞩目的成果。BERT是一种基于transformer架构的双向模型,它以一种速度更快的基于Attention的方法取代了RNN(LSTM和GRU)的sequential属性。

该模型还在两个无监督任务(“遮蔽语言模型”和“下一句预测”)上进行了预训练。这让我们可以通过对下游特定任务(例如情绪分类,意图检测,问答等)进行微调来使用预先训练的BERT模型。

本文将手把手教你,用BERT完成一个Kaggle竞赛。

在本文中,我们将重点介绍BERT在多标签文本分类问题中的应用。传统的分类问题假定每个文档都分配给一个且只分配给一个类别,即标签。这有时也被称为多元分类,比如类别数量是2的话,就叫做二元分类。

而多标签分类假设文档可以同时独立地分配给多个标签或类别。多标签分类具有许多实际应用,例如业务分类或为电影分配多个类型。在客户服务领域,此技术可用于识别客户电子邮件的多种意图。

我们将使用Kaggle的“恶意评论分类挑战”来衡量BERT在多标签文本分类中的表现。

在本次竞赛中,我们将尝试构建一个能够将给文本片段分配给同恶评类别的模型。我们设定了恶意评论类别作为模型的目标标签,它们包括普通恶评、严重恶评、污言秽语、威胁、侮辱和身份仇视。

比赛链接:

https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge

从哪开始?

Google Research最近公开了BERT 的tensorflow部署代码,并发布了以下预训练模型:

BERT-Base, Uncased: 12层,768个隐藏单元,自注意力的 head数为12,110M参数

BERT-Large, Uncased:24层,1024个隐藏单元,自注意力的 head数为16,340M参数

BERT-Base, Cased:12层,768个隐藏单元,自注意力的 head数为12,110M参数

BERT-Large, Cased:24层,1024个隐藏单元,自注意力的 head数为16,340M参数

BERT-Base, Multilingual Cased (最新推荐):104种语言,12层,768个隐藏单元,自注意力的 head数为12,110M参数

BERT-Base, Chinese:中文(简体和繁体),12层,768个隐藏单元,自注意力的 head数为12,110M参数

编者注:这里cased和uncased的意思是在进行WordPiece分词之前是否区分大小写。uncased表示全部会调整成小写,且剔除所有的重音标记;cased则表示文本的真实情况和重音标记都会保留下来。

我们将使用较小的Bert-Base,uncased模型来完成此任务。Bert-Base模型有12个attention层,所有文本都将由标记器转换为小写。我们在亚马逊云 p3.8xlarge EC2实例上运行此模型,该实例包含4个Tesla V100 GPU,GPU内存总共64 GB。

因为我个人更喜欢在TensorFlow上使用PyTorch,所以我们将使用来自HuggingFace的BERT模型PyTorch端口,这可从https://github.com/huggingface/pytorch-pretrained-BERT下载。我们已经用HuggingFace的repo脚本将预先训练的TensorFlow检查点(checkpoints)转换为PyTorch权重。

我们的实现很大程度上是以BERT原始实现中提供的run_classifier示例为基础的。

数据展示

数据用类InputExample来表示。

text_a:文本评论

text_b:未使用

标签:来自训练数据集的评论标签列表(很明显,测试数据集的标签将为空)

class InputExample(object): """A single training/test example for sequence classification.""" def __init__(self, guid, text_a, text_b=None, labels=None): """Constructs a InputExample. Args: guid: Unique id for the example. text_a: string. The untokenized text of the first sequence. For single sequence tasks, only this sequence must be specified. text_b: (Optional) string. The untokenized text of the second sequence. Only must be specified for sequence pair tasks. labels: (Optional) [string]. The label of the example. This should be specified for train and dev examples, but not for test examples. """ self.guid = guid self.text_a = text_a self.text_b = text_b self.labels = labels

class InputFeatures(object): """A single set of features of data.""" def __init__(self, input_ids, input_mask, segment_ids, label_ids): self.input_ids = input_ids self.input_mask = input_mask self.segment_ids = segment_ids self.label_ids = label_ids

我们将InputExample转换为BERT能理解的特征,该特征用类InputFeatures来表示。

input_ids:标记化文本的数字id列表

input_mask:对于真实标记将设置为1,对于填充标记将设置为0

segment_ids:对于我们的情况,这将被设置为全1的列表

label_ids:文本的one-hot编码标签

标记化(Tokenisation)

BERT-Base,uncased模型使用包含30,522个单词的词汇表。标记化过程涉及将输入文本拆分为词汇表中可用的标记列表。为了处理不在词汇表中的单词,BERT使用一种称为基于双字节编码(BPE,Byte-Pair Encoding)的WordPiece标记化技术。

这种方法将不在词汇表之中的词一步步分解成子词。因为子词是词汇表的一部分,模型已经学习了这些子词在上下文中的表示,并且该词的上下文仅仅是子词的上下文的组合,因此这个词就可以由一组子词表示。要了解关于此方法的更多详细信息,请参阅文章《使用子词单位的稀有单词的神经网络机器翻译》。

文章链接:

https://arxiv.org/pdf/1508.07909

在我看来,这与BERT本身一样都是一种突破。

模型架构

我们将改写BertForSequenceClassification类以使其满足多标签分类的要求。

class BertForMultiLabelSequenceClassification(PreTrainedBertModel): """BERT model for classification. This module is composed of the BERT model with a linear layer on top of the pooled output. """ def __init__(self, config, num_labels=2): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) return loss else: return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True

这里主要的改动是用logits作为二进制交叉熵的损失函数(BCEWithLogitsLoss),取代用于多元分类的vanilla交叉熵损失函数(CrossEntropyLoss)。二进制交叉熵损失可以让我们的模型为标签分配独立的概率。

下面的模型摘要说明了模型的各个层及其维度。

BertForMultiLabelSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(28996, 768) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1) ) (encoder): BertEncoder( (layer): ModuleList(# 12 BertLayers (11): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) ) (dropout): Dropout(p=0.1) (classifier): Linear(in_features=768, out_features=6, bias=True))

BertEmbeddings:输入嵌入层

BertEncoder: 12个BERT模型attention层

分类器:我们的多标签分类器,out_features = 6,每个分类符对应6个标签

模型训练

训练循环与原始BERT实现中提供的run_classifier.py里的循环相同。我们的模型训练了4个epoch(一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一个 epoch),每批数据大小为32,序列长度为512,即预训练模型的最大可能性。根据原始论文的建议,学习率保持在3e-5。

因为有机会使用多个GPU,所以我们将Pytorch模型封装在DataParallel模块中,这使我们能够在所有可用的GPU上进行训练。

我们没有使用半精度FP16技术,因为使用logits 损失函数的二进制交叉熵不支持FP16处理。但这并不会影响最终结果,只是需要更长的时间训练。

评估指标

def accuracy_thresh(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True): "Compute accuracy when `y_pred` and `y_true` are the same size." if sigmoid: y_pred = y_pred.sigmoid() return np.mean(((y_pred>thresh)==y_true.byte()).float().cpu().numpy(), axis=1).sum()

from sklearn.metrics import roc_curve, auc# Compute ROC curve and ROC area for each classfpr = dict()tpr = dict()roc_auc = dict()for i in range(num_labels): fpr[i], tpr[i], _ = roc_curve(all_labels[:, i], all_logits[:, i]) roc_auc[i] = auc(fpr[i], tpr[i])# Compute micro-average ROC curve and ROC areafpr["micro"], tpr["micro"], _ = roc_curve(all_labels.ravel(), all_logits.ravel())roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

我们为精度度量函数增加了一个阈值,默认设置为0.5。

对于多标签分类,更重要的指标是ROC-AUC曲线。这也是Kaggle比赛的评分指标。我们分别计算每个标签的ROC-AUC,并对单个标签的roc-auc分数进行微平均。

如果想深入了解roc-auc曲线,这里有一篇很不错的博客。

博客链接:

https://towardsdatascience.com/understanding-auc-roc-curve-68b2303cc9c5。

评估分数

我们重复进行了几次实验,每次都有一些输入上的变化,但都得到了类似的结果,如下所示:

训练损失:0.022,验证损失:0.018,验证准确度:99.31%。

各个标签的ROC-AUC分数:

普通恶评:0.9988

严重恶评:0.9935

污言秽语:0.9988

威胁:0.9989

侮辱:0.9975

身份仇视:0.9988

微观平均ROC-AUC得分:0.9987

这样的结果似乎非常令人鼓舞,因为我们看上去已经创建了一个近乎完美的模型来检测文本评论的恶毒程度。现在看看我们在Kaggle排行榜上的得分。

Kaggle竞赛结果

我们在Kaggle提供的测试数据集上运行推理逻辑,并将结果提交给竞赛。以下是结果:

我们的roc-auc评分达到了0.9863,在所有竞争者中排名前10%。为了使比赛结果更具说服力,这次Kaggle比赛的奖金为35000美元,而一等奖得分为0.9885。

最高分的团队由专业的高技能数据科学家和从业者组成。除了我们所做的工作之外,他们还使用各种技术来进行数据集成,数据增强(data augmentation)和测试时增强(test-time augmentation)。

结论和后续

我们使用强大的BERT预训练模型实现了多标签分类模型。正如我们所展示的那样,模型在已熟知的公开数据集上得到了相当不错的结果。我们能够建立一个世界级的模型生产应用于各行业,尤其是客户服务领域。

对于我们来说,下一步将是使用“遮蔽语言模型”和“下一句预测”对下游任务的文本语料库来微调预训练的语言模型。这将是一项无监督的任务,希望该模型能够学习一些我们自定义的上下文和术语,这和ULMFiT使用的技术类似。

资料链接:

https://nbviewer.jupyter.org/github/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb

https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4562

    浏览量

    98644
  • 数据库
    +关注

    关注

    7

    文章

    3584

    浏览量

    63344
  • Transformer
    +关注

    关注

    0

    文章

    130

    浏览量

    5895

原文标题:搞定NLP领域的“变形金刚”!手把手教你用BERT进行多标签文本分类

文章出处:【微信号:BigDataDigest,微信公众号:大数据文摘】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    手把手教你使用QUARTUS

    手把手教你使用QUARTUS
    发表于 12-29 20:35

    手把手教你构建一个完整的工程

    手把手教你构建一个完整的工程
    发表于 08-03 09:54 33次下载
    <b class='flag-5'>手把手</b><b class='flag-5'>教你</b>构建一个完整的工程

    美女手把手教你如何装机(中)

    美女手把手教你如何装机(中) 再来是硬碟的部份,这款机壳还不错,可以旋转支架~
    发表于 01-27 11:14 1355次阅读

    美女手把手教你如何装机(下)

    美女手把手教你如何装机(下) 接著下来就是今天的重头戏,开核萝!~
    发表于 01-27 11:16 2867次阅读

    手把手教你学DSP28335_张卿杰

    手把手教你学DSP28335张卿杰百度云分享手把手教你学DSP28335张卿杰百度云分享
    发表于 01-11 11:45 173次下载

    手把手教你学习FPGA—LED篇

    电子专业单片机相关知识学习教材资料——手把手教你学习FPGA—LED篇
    发表于 08-08 17:19 0次下载

    手把手教你做彩铃电子时钟---第1章

    手把手教你做彩铃电子时钟全部资料
    发表于 11-14 17:15 7次下载

    手把手教你做彩铃电子时钟---第4章

    手把手教你做彩铃电子时钟
    发表于 11-14 17:09 12次下载

    手把手教你做彩铃电子时钟--第7章

    手把手教你做彩铃电子时钟
    发表于 11-14 17:05 6次下载

    手把手教你做彩铃电子时钟---第12章

    手把手教你做彩铃电子时钟
    发表于 11-14 16:59 9次下载

    手把手教你如何开始DSP编程

    手把手教你如何开始DSP编程。
    发表于 04-09 11:54 12次下载
    <b class='flag-5'>手把手</b><b class='flag-5'>教你</b>如何开始DSP编程

    手把手教你学LabVIEW视觉设计

    手把手教你学LabVIEW视觉设计手把手教你学LabVIEW视觉设计手把手教你学LabVIEW视
    发表于 03-06 01:41 2757次阅读

    手把手教你开关电源PCB排板

    手把手教你开关电源PCB排板(新型电源技术)-分享一下开关电源PCB排板的基本要点及分析,以及例子讲解。绝对的手把手
    发表于 09-18 12:27 48次下载
    <b class='flag-5'>手把手</b><b class='flag-5'>教你</b>开关电源PCB排板

    手把手教你学DSP-基于TMS320X281x

    顾卫刚手把手教你学DSP
    发表于 05-24 16:06 26次下载

    手把手教你学FPGA仿真

    电子发烧友网站提供《手把手教你学FPGA仿真.pdf》资料免费下载
    发表于 10-19 09:17 1次下载
    <b class='flag-5'>手把手</b><b class='flag-5'>教你</b>学FPGA仿真