电子发烧友App

硬声App

0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示
电子发烧友网>电子资料下载>电子资料>PyTorch教程9.6之递归神经网络的简洁实现

PyTorch教程9.6之递归神经网络的简洁实现

2023-06-05 | pdf | 0.20 MB | 次下载 | 免费

资料介绍

与我们大多数从头开始的实施一样, 第 9.5 节旨在深入了解每个组件的工作原理但是,当您每天使用 RNN 或编写生产代码时,您会希望更多地依赖于减少实现时间(通过为通用模型和函数提供库代码)和计算时间(通过优化这些库实现)。本节将向您展示如何使用深度学习框架提供的高级 API 更有效地实现相同的语言模型。和以前一样,我们首先加载时间机器数据集。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

9.6.1. 定义模型

我们使用由高级 API 实现的 RNN 定义以下类。

class RNN(d2l.Module): #@save
  """The RNN model implemented with high-level APIs."""
  def __init__(self, num_inputs, num_hiddens):
    super().__init__()
    self.save_hyperparameters()
    self.rnn = nn.RNN(num_inputs, num_hiddens)

  def forward(self, inputs, H=None):
    return self.rnn(inputs, H)

Specifically, to initialize the hidden state, we invoke the member method begin_state. This returns a list that contains an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list will also contain other information.

class RNN(d2l.Module): #@save
  """The RNN model implemented with high-level APIs."""
  def __init__(self, num_hiddens):
    super().__init__()
    self.save_hyperparameters()
    self.rnn = rnn.RNN(num_hiddens)

  def forward(self, inputs, H=None):
    if H is None:
      H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
    outputs, (H, ) = self.rnn(inputs, (H, ))
    return outputs, H

Flax does not provide an RNNCell for concise implementation of Vanilla RNNs as of today. There are more advanced variants of RNNs like LSTMs and GRUs which are available in the Flax linen API.

class RNN(nn.Module): #@save
  """The RNN model implemented with high-level APIs."""
  num_hiddens: int

  @nn.compact
  def __call__(self, inputs, H=None):
    raise NotImplementedError
class RNN(d2l.Module): #@save
  """The RNN model implemented with high-level APIs."""
  def __init__(self, num_hiddens):
    super().__init__()
    self.save_hyperparameters()
    self.rnn = tf.keras.layers.SimpleRNN(
      num_hiddens, return_sequences=True, return_state=True,
      time_major=True)

  def forward(self, inputs, H=None):
    outputs, H = self.rnn(inputs, H)
    return outputs, H

继承自9.5 节RNNLMScratch中的类 ,下面的类定义了一个完整的基于 RNN 的语言模型。请注意,我们需要创建一个单独的全连接输出层。RNNLM

class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  def init_params(self):
    self.linear = nn.LazyLinear(self.vocab_size)

  def output_layer(self, hiddens):
    return self.linear(hiddens).swapaxes(0, 1)
class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  def init_params(self):
    self.linear = nn.Dense(self.vocab_size, flatten=False)
    self.initialize()
  def output_layer(self, hiddens):
    return self.linear(hiddens).swapaxes(0, 1)
class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  training: bool = True

  def setup(self):
    self.linear = nn.Dense(self.vocab_size)

  def output_layer(self, hiddens):
    return self.linear(hiddens).swapaxes(0, 1)

  def forward(self, X, state=None):
    embs = self.one_hot(X)
    rnn_outputs, _ = self.rnn(embs, state, self.training)
    return self.output_layer(rnn_outputs)
class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  def init_params(self):
    self.linear = tf.keras.layers.Dense(self.vocab_size)

  def output_layer(self, hiddens):
    return tf.transpose(self.linear(hiddens), (1, 0, 2))

9.6.2. 训练和预测

在训练模型之前,让我们使用随机权重初始化的模型进行预测。鉴于我们还没有训练网络,它会产生无意义的预测。

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasgggggggggggggggggggg'
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasxlxlxlxlxlxlxlxlxlxl'
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasnvjdtagwbcsxvcjwuyby'

接下来,我们利用高级 API 训练我们的模型。

trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
https://file.elecfans.com/web2/M00/A9/C8/poYBAGR9NrKAA2V1ABG9IJKp_s8858.svg
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
https://file.elecfans.com/web2/M00/A9/C8/poYBAGR9NrmAC0QYABHpbt_PvZk929.svg
with d2l.try_gpu():
  trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
trainer.fit(model, data)
https://file.elecfans.com/web2/M00/A9/C8/poYBAGR9NsGAZ5qbABHCG7mYLzs874.svg

第 9.5 节相比,该模型实现了相当的困惑度,但由于实现优化,运行速度更快。和以前一样,我们可以在指定的前缀字符串之后生成预测标记。

 

下载该资料的人也在下载 下载该资料的人还在阅读
更多 >

评论

查看更多

下载排行

本周

  1. 1人工智能+消费:技术赋能与芯片驱动未来
  2. 15.25 MB  |  4次下载  |  免费
  3. 2⼯业电源&模块电源产品⼿册
  4. 15.40 MB   |  1次下载  |  免费
  5. 379M15 TO-252三端稳压IC规格书
  6. 0.86 MB   |  次下载  |  免费
  7. 4MBRD20150CT TO-252肖特基二极管规格书
  8. 0.54 MB   |  次下载  |  免费
  9. 5自动锁螺丝运动控制系统用户手册
  10. 6.65 MB   |  次下载  |  5 积分
  11. 6奥特光耦产品手册
  12. 4.83 MB  |  次下载  |  免费
  13. 7SMA系列10MHz~40GHz同轴检波器
  14. 559.60 KB  |  次下载  |  免费
  15. 8CD7388CZ:7W×4 四通道音频功率放大电路技术手册
  16. 0.39 MB   |  次下载  |  10 积分

本月

  1. 1元宇宙深度解析—未来的未来-风口还是泡沫
  2. 6.40 MB  |  241次下载  |  免费
  3. 2元宇宙底层硬件系列报告
  4. 13.42 MB  |  184次下载  |  免费
  5. 32022 年展望報告 – 半導體產業
  6. 1.43 MB  |  136次下载  |  免费
  7. 4晶振与滤波器应用电路《电子工程师必备:元器件应用宝典》
  8. 1.57 MB  |  90次下载  |  免费
  9. 5汽车电子行业深度解析:智能化与电动化方兴未艾
  10. 6.47 MB  |  71次下载  |  免费
  11. 6中国DPU行业白皮书
  12. 23.80 MB  |  37次下载  |  免费
  13. 7晶科鑫代理线-微盟电子2021年度产品目录选型手册
  14. 14.75 MB  |  27次下载  |  免费
  15. 8SJK晶振产品目录-简化版-2022
  16. 13.77 MB  |  20次下载  |  免费

总榜

  1. 1matlab软件下载入口
  2. 未知  |  935134次下载  |  10 积分
  3. 2开源硬件-PMP21529.1-4 开关降压/升压双向直流/直流转换器 PCB layout 设计
  4. 1.48MB  |  420064次下载  |  10 积分
  5. 3Altium DXP2002下载入口
  6. 未知  |  233089次下载  |  10 积分
  7. 4电路仿真软件multisim 10.0免费下载
  8. 340992  |  191425次下载  |  10 积分
  9. 5十天学会AVR单片机与C语言视频教程 下载
  10. 158M  |  183352次下载  |  10 积分
  11. 6labview8.5下载
  12. 未知  |  81602次下载  |  10 积分
  13. 7Keil工具MDK-Arm免费下载
  14. 0.02 MB  |  73822次下载  |  10 积分
  15. 8LabVIEW 8.6下载
  16. 未知  |  65991次下载  |  10 积分