电子发烧友App

硬声App

0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示
电子发烧友网>电子资料下载>电子资料>PyTorch教程15.9之预训练BERT的数据集

PyTorch教程15.9之预训练BERT的数据集

2023-06-05 | pdf | 0.14 MB | 次下载 | 免费

资料介绍

为了预训练第 15.8 节中实现的 BERT 模型,我们需要以理想的格式生成数据集,以促进两项预训练任务:掩码语言建模和下一句预测。一方面,原始的 BERT 模型是在两个巨大的语料库 BookCorpus 和英文维基百科(参见第15.8.5 节)的串联上进行预训练的,这使得本书的大多数读者难以运行。另一方面,现成的预训练 BERT 模型可能不适合医学等特定领域的应用。因此,在自定义数据集上预训练 BERT 变得越来越流行。为了便于演示 BERT 预训练,我们使用较小的语料库 WikiText-2 ( Merity et al. , 2016 )

15.3节用于预训练word2vec的PTB数据集相比,WikiText-2(i)保留了原有的标点符号,适合下一句预测;(ii) 保留原始案例和编号;(iii) 大两倍以上。

import os
import random
import torch
from d2l import torch as d2l
import os
import random
from mxnet import gluon, np, npx
from d2l import mxnet as d2l

npx.set_np()

在 WikiText-2 数据集中,每一行代表一个段落,其中在任何标点符号及其前面的标记之间插入空格。保留至少两句话的段落。为了简单起见,为了拆分句子,我们只使用句点作为分隔符。我们将在本节末尾的练习中讨论更复杂的句子拆分技术。

#@save
d2l.DATA_HUB['wikitext-2'] = (
  'https://s3.amazonaws.com/research.metamind.io/wikitext/'
  'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

#@save
def _read_wiki(data_dir):
  file_name = os.path.join(data_dir, 'wiki.train.tokens')
  with open(file_name, 'r') as f:
    lines = f.readlines()
  # Uppercase letters are converted to lowercase ones
  paragraphs = [line.strip().lower().split(' . ')
         for line in lines if len(line.split(' . ')) >= 2]
  random.shuffle(paragraphs)
  return paragraphs
#@save
d2l.DATA_HUB['wikitext-2'] = (
  'https://s3.amazonaws.com/research.metamind.io/wikitext/'
  'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

#@save
def _read_wiki(data_dir):
  file_name = os.path.join(data_dir, 'wiki.train.tokens')
  with open(file_name, 'r') as f:
    lines = f.readlines()
  # Uppercase letters are converted to lowercase ones
  paragraphs = [line.strip().lower().split(' . ')
         for line in lines if len(line.split(' . ')) >= 2]
  random.shuffle(paragraphs)
  return paragraphs

15.9.1。为预训练任务定义辅助函数

下面,我们首先为两个 BERT 预训练任务实现辅助函数:下一句预测和掩码语言建模。这些辅助函数将在稍后将原始文本语料库转换为理想格式的数据集以预训练 BERT 时调用。

15.9.1.1。生成下一句预测任务

根据15.8.5.2 节的描述,该 _get_next_sentence函数为二元分类任务生成一个训练样例。

#@save
def _get_next_sentence(sentence, next_sentence, paragraphs):
  if random.random() < 0.5:
    is_next = True
  else:
    # `paragraphs` is a list of lists of lists
    next_sentence = random.choice(random.choice(paragraphs))
    is_next = False
  return sentence, next_sentence, is_next
#@save
def _get_next_sentence(sentence, next_sentence, paragraphs):
  if random.random() < 0.5:
    is_next = True
  else:
    # `paragraphs` is a list of lists of lists
    next_sentence = random.choice(random.choice(paragraphs))
    is_next = False
  return sentence, next_sentence, is_next

以下函数paragraph通过调用该 _get_next_sentence函数从输入生成用于下一句预测的训练示例。paragraph是一个句子列表,其中每个句子都是一个标记列表。参数 max_len指定预训练期间 BERT 输入序列的最大长度。

#@save
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
  nsp_data_from_paragraph = []
  for i in range(len(paragraph) - 1):
    tokens_a, tokens_b, is_next = _get_next_sentence(
      paragraph[i], paragraph[i + 1], paragraphs)
    # Consider 1 '' token and 2 '' tokens
    if len(tokens_a) + len(tokens_b) + 3 > max_len:
      continue
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    nsp_data_from_paragraph.append((tokens, segments, is_next))
  return nsp_data_from_paragraph
#@save
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
  nsp_data_from_paragraph = []
  for i in range(len(paragraph) - 1):
    tokens_a, tokens_b, is_next = _get_next_sentence(
      paragraph[i], paragraph[i + 1], paragraphs)
    # Consider 1 '' token and 2 '' tokens
    if len(tokens_a) + len(tokens_b) + 3 > max_len:
      continue
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    nsp_data_from_paragraph.append((tokens, segments, is_next))
  return nsp_data_from_paragraph

15.9.1.2。生成掩码语言建模任务

为了从 BERT 输入序列为掩码语言建模任务生成训练示例,我们定义了以下 _replace_mlm_tokens函数。在它的输入中,tokens是代表BERT输入序列的token列表,candidate_pred_positions 是BERT输入序列的token索引列表,不包括特殊token(masked语言建模任务中不预测特殊token),num_mlm_preds表示预测(召回 15% 的随机标记来预测)。遵循第 15.8.5.1 节中屏蔽语言建模任务的定义 ,在每个预测位置,输入可能被特殊的“”标记或随机标记替换,或者保持不变。最后,该函数返回可能替换后的输入标记、发生预测的标记索引以及这些预测的标签

#@save
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
            vocab):
  # For the input of a masked language model, make a new copy of tokens and
  # replace some of them by '' or random tokens
  mlm_input_tokens = [token for token in tokens]
  pred_positions_and_labels = []
  # Shuffle for getting 15% random tokens for prediction in the masked
  # language modeling task
  random.shuffle(candidate_pred_positions)
  for mlm_pred_position in candidate_pred_positions:
    if len(pred_positions_and_labels) >= num_mlm_preds:
      break
    masked_token = None
    # 80% of the time: replace the word with the '' token
    if random.random() < 0.8:
      masked_token = ''
    else:
      # 10% of the time: keep the word unchanged
      if random.random() < 0.5:
        masked_token = tokens[mlm_pred_position]
      # 10% of the time: replace the word with a random word
      else:
        masked_token = random.choice(vocab.idx_to_token)
    mlm_input_tokens[mlm_pred_position] = masked_token
    pred_positions_and_labels.append(
      (mlm_pred_position, tokens[mlm_pred_position]))
  return mlm_input_tokens, pred_positions_and_labels
#@save
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
            vocab):
  # For the input of a masked language model, make a new copy of tokens and
  # replace some of them by '' or random tokens
  mlm_input_tokens = [token for token in tokens]
  pred_positions_and_labels = []
  # Shuffle for getting 15% random tokens for prediction in the masked
  # language modeling task
  random.shuffle(candidate_pred_positions)
  for mlm_pred_position in candidate_pred_positions:
    if len(pred_positions_and_labels) >= num_mlm_preds:
      break
    masked_token = None
    # 80% of the time: replace the word with the '' token
    if random.random() < 0.8:
      masked_token = ''
    else:
      # 10% of the time: keep the word unchanged
      if random.random() < 0.5:
        masked_token = tokens[mlm_pred_position]
      # 10% of the time: replace the word with a random word
      else:
        masked_token = random.choice(vocab.idx_to_token)
    mlm_input_tokens[mlm_pred_position] = masked_token
    pred_positions_and_labels.append(
      (mlm_pred_position, tokens[mlm_pred_position]))
  return mlm_input_tokens, pred_positions_and_labels

通过调用上述_replace_mlm_tokens函数,以下函数将 BERT 输入序列 ( tokens) 作为输入并返回输入标记的索引(在可能的标记替换之后,如第15.8.5.1 节所述)、发生预测的标记索引和标签这些预测的指标。

#@save
def _get_mlm_data_from_tokens(tokens, vocab):
  candidate_pred_positions = []
  # `tokens` is a list of strings
  for i, token in enumerate(tokens):
    # Special tokens are not predicted in the masked language modeling
    # task
    if token in ['', '']:
      continue
    candidate_pred_positions.append(i)
  # 15% of random tokens are predicted in the masked language modeling task
  num_mlm_preds = max(1, round(len(tokens) * 0.15))
  mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
    tokens, candidate_pred_positions, num_mlm_preds, vocab)
  pred_positions_and_labels = sorted(pred_positions_and_labels,
                    key=lambda x: x[0])
  pred_positions <

下载该资料的人也在下载 下载该资料的人还在阅读
更多 >

评论

查看更多

下载排行

本周

  1. 1山景DSP芯片AP8248A2数据手册
  2. 1.06 MB  |  532次下载  |  免费
  3. 2RK3399完整板原理图(支持平板,盒子VR)
  4. 3.28 MB  |  339次下载  |  免费
  5. 3TC358743XBG评估板参考手册
  6. 1.36 MB  |  330次下载  |  免费
  7. 4DFM软件使用教程
  8. 0.84 MB  |  295次下载  |  免费
  9. 5元宇宙深度解析—未来的未来-风口还是泡沫
  10. 6.40 MB  |  227次下载  |  免费
  11. 6迪文DGUS开发指南
  12. 31.67 MB  |  194次下载  |  免费
  13. 7元宇宙底层硬件系列报告
  14. 13.42 MB  |  182次下载  |  免费
  15. 8FP5207XR-G1中文应用手册
  16. 1.09 MB  |  178次下载  |  免费

本月

  1. 1OrCAD10.5下载OrCAD10.5中文版软件
  2. 0.00 MB  |  234315次下载  |  免费
  3. 2555集成电路应用800例(新编版)
  4. 0.00 MB  |  33566次下载  |  免费
  5. 3接口电路图大全
  6. 未知  |  30323次下载  |  免费
  7. 4开关电源设计实例指南
  8. 未知  |  21549次下载  |  免费
  9. 5电气工程师手册免费下载(新编第二版pdf电子书)
  10. 0.00 MB  |  15349次下载  |  免费
  11. 6数字电路基础pdf(下载)
  12. 未知  |  13750次下载  |  免费
  13. 7电子制作实例集锦 下载
  14. 未知  |  8113次下载  |  免费
  15. 8《LED驱动电路设计》 温德尔著
  16. 0.00 MB  |  6656次下载  |  免费

总榜

  1. 1matlab软件下载入口
  2. 未知  |  935054次下载  |  免费
  3. 2protel99se软件下载(可英文版转中文版)
  4. 78.1 MB  |  537798次下载  |  免费
  5. 3MATLAB 7.1 下载 (含软件介绍)
  6. 未知  |  420027次下载  |  免费
  7. 4OrCAD10.5下载OrCAD10.5中文版软件
  8. 0.00 MB  |  234315次下载  |  免费
  9. 5Altium DXP2002下载入口
  10. 未知  |  233046次下载  |  免费
  11. 6电路仿真软件multisim 10.0免费下载
  12. 340992  |  191187次下载  |  免费
  13. 7十天学会AVR单片机与C语言视频教程 下载
  14. 158M  |  183279次下载  |  免费
  15. 8proe5.0野火版下载(中文版免费下载)
  16. 未知  |  138040次下载  |  免费