电子发烧友App

硬声App

0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示
电子发烧友网>电子资料下载>电子资料>PyTorch教程15.10之预训练BERT

PyTorch教程15.10之预训练BERT

2023-06-05 | pdf | 0.15 MB | 次下载 | 免费

资料介绍

借助15.8 节中实现的 BERT 模型和15.9 节中从 WikiText-2 数据集生成的预训练示例 ,我们将在本节中在 WikiText-2 数据集上预训练 BERT。

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import autograd, gluon, init, np, npx
from d2l import mxnet as d2l

npx.set_np()

首先,我们将 WikiText-2 数据集加载为用于屏蔽语言建模和下一句预测的小批量预训练示例。批量大小为 512,BERT 输入序列的最大长度为 64。请注意,在原始 BERT 模型中,最大长度为 512。

batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...

15.10.1。预训练 BERT

原始 BERT 有两个不同模型大小的版本 Devlin et al. , 2018基础模型(BERTBASE) 使用 12 层(Transformer 编码器块),具有 768 个隐藏单元(隐藏大小)和 12 个自注意力头。大模型(BERTLARGE) 使用 24 层,有 1024 个隐藏单元和 16 个自注意力头。值得注意的是,前者有 1.1 亿个参数,而后者有 3.4 亿个参数。为了便于演示,我们定义了一个小型 BERT,使用 2 层、128 个隐藏单元和 2 个自注意力头。

net = d2l.BERTModel(len(vocab), num_hiddens=128,
          ffn_num_hiddens=256, num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
          num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=devices)
loss = gluon.loss.SoftmaxCELoss()

在定义训练循环之前,我们定义了一个辅助函数 _get_batch_loss_bert给定训练示例的碎片,此函数计算掩码语言建模和下一句预测任务的损失。请注意,BERT 预训练的最终损失只是掩码语言建模损失和下一句预测损失的总和。

#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
             segments_X, valid_lens_x,
             pred_positions_X, mlm_weights_X,
             mlm_Y, nsp_y):
  # Forward pass
  _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                 valid_lens_x.reshape(-1),
                 pred_positions_X)
  # Compute masked language model loss
  mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
  mlm_weights_X.reshape(-1, 1)
  mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
  # Compute next sentence prediction loss
  nsp_l = loss(nsp_Y_hat, nsp_y)
  l = mlm_l + nsp_l
  return mlm_l, nsp_l, l
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
             segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards,
             mlm_Y_shards, nsp_y_shards):
  mlm_ls, nsp_ls, ls = [], [], []
  for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
     pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
     nsp_y_shard) in zip(
    tokens_X_shards, segments_X_shards, valid_lens_x_shards,
    pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
    nsp_y_shards):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(
      tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
      pred_positions_X_shard)
    # Compute masked language model loss
    mlm_l = loss(
      mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
      mlm_weights_X_shard.reshape((-1, 1)))
    mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y_shard)
    nsp_l = nsp_l.mean()
    mlm_ls.append(mlm_l)
    nsp_ls.append(nsp_l)
    ls.append(mlm_l + nsp_l)
    npx.waitall()
  return mlm_ls, nsp_ls, ls

调用上述两个辅助函数,以下 函数定义了在 WikiText-2 ( ) 数据集上train_bert预训练 BERT ( ) 的过程训练 BERT 可能需要很长时间。与在函数中指定训练的时期数不同 (参见第 14.1 节),以下函数的输入指定训练的迭代步数。nettrain_itertrain_ch13num_steps

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
  net(*next(iter(train_iter))[:4])
  net = nn.DataParallel(net, device_ids=devices).to(devices[0])
  trainer = torch.optim.Adam(net.parameters(), lr=0.01)
  step, timer = 0, d2l.Timer()
  animator = d2l.Animator(xlabel='step', ylabel='loss',
              xlim=[1, num_steps], legend=['mlm', 'nsp'])
  # Sum of masked language modeling losses, sum of next sentence prediction
  # losses, no. of sentence pairs, count
  metric = d2l.Accumulator(4)
  num_steps_reached = False
  while step < num_steps and not num_steps_reached:
    for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
      mlm_weights_X, mlm_Y, nsp_y in train_iter:
      tokens_X = tokens_X.to(devices[0])
      segments_X = segments_X.to(devices[0])
      valid_lens_x = valid_lens_x.to(devices[0])
      pred_positions_X = pred_positions_X.to(devices[0])
      mlm_weights_X = mlm_weights_X.to(devices[0])
      mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
      trainer.zero_grad()
      timer.start()
      mlm_l, nsp_l, l = _get_batch_loss_bert(
        net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
        pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
      l.backward()
      trainer.step()
      metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
      timer.stop()
      animator.add(step + 1,
             (metric[0] / metric[3], metric[1] / metric[3]))
      step += 1
      if step == num_steps:
        num_steps_reached = True
        break

  print(f'MLM loss {metric[0] / metric[3]:.3f}, '
     f'NSP loss {metric[1] / metric[3]:.3f}')
  print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
     f'{str(devices)}')
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
  trainer = gluon.Trainer(net.collect_params(), 'adam',
              {'learning_rate': 0.01})
  step, timer = 0, d2l.Timer()
  animator = d2l.Animator(xlabel='step', ylabel='loss',
              xlim=[1, num_steps], legend=['mlm', 'nsp'])
  # Sum of masked language modeling losses, sum of next sentence prediction
  # losses, no. of sentence pairs, count
  metric = d2l.

下载该资料的人也在下载 下载该资料的人还在阅读
更多 >

评论

查看更多

下载排行

本周

  1. 1山景DSP芯片AP8248A2数据手册
  2. 1.06 MB  |  532次下载  |  免费
  3. 2RK3399完整板原理图(支持平板,盒子VR)
  4. 3.28 MB  |  339次下载  |  免费
  5. 3TC358743XBG评估板参考手册
  6. 1.36 MB  |  330次下载  |  免费
  7. 4DFM软件使用教程
  8. 0.84 MB  |  295次下载  |  免费
  9. 5元宇宙深度解析—未来的未来-风口还是泡沫
  10. 6.40 MB  |  227次下载  |  免费
  11. 6迪文DGUS开发指南
  12. 31.67 MB  |  194次下载  |  免费
  13. 7元宇宙底层硬件系列报告
  14. 13.42 MB  |  182次下载  |  免费
  15. 8FP5207XR-G1中文应用手册
  16. 1.09 MB  |  178次下载  |  免费

本月

  1. 1OrCAD10.5下载OrCAD10.5中文版软件
  2. 0.00 MB  |  234315次下载  |  免费
  3. 2555集成电路应用800例(新编版)
  4. 0.00 MB  |  33566次下载  |  免费
  5. 3接口电路图大全
  6. 未知  |  30323次下载  |  免费
  7. 4开关电源设计实例指南
  8. 未知  |  21549次下载  |  免费
  9. 5电气工程师手册免费下载(新编第二版pdf电子书)
  10. 0.00 MB  |  15349次下载  |  免费
  11. 6数字电路基础pdf(下载)
  12. 未知  |  13750次下载  |  免费
  13. 7电子制作实例集锦 下载
  14. 未知  |  8113次下载  |  免费
  15. 8《LED驱动电路设计》 温德尔著
  16. 0.00 MB  |  6656次下载  |  免费

总榜

  1. 1matlab软件下载入口
  2. 未知  |  935054次下载  |  免费
  3. 2protel99se软件下载(可英文版转中文版)
  4. 78.1 MB  |  537798次下载  |  免费
  5. 3MATLAB 7.1 下载 (含软件介绍)
  6. 未知  |  420027次下载  |  免费
  7. 4OrCAD10.5下载OrCAD10.5中文版软件
  8. 0.00 MB  |  234315次下载  |  免费
  9. 5Altium DXP2002下载入口
  10. 未知  |  233046次下载  |  免费
  11. 6电路仿真软件multisim 10.0免费下载
  12. 340992  |  191187次下载  |  免费
  13. 7十天学会AVR单片机与C语言视频教程 下载
  14. 158M  |  183279次下载  |  免费
  15. 8proe5.0野火版下载(中文版免费下载)
  16. 未知  |  138040次下载  |  免费