0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

基于Eric Jang用158行Python代码实现该系统的思路

DPVg_AI_era 来源:lq 2019-09-13 16:13 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

最近,谷歌 DeepMInd 发表论文提出了一个用于图像生成的递归神经网络,该系统大大提高了 MNIST 上生成模型的质量。为更加深入了解 DRAW,本文作者基于 Eric Jang 用 158 行 Python 代码实现该系统的思路,详细阐述了 DRAW 的概念、架构和优势等。

递归神经网络是一种用于图像生成的神经网络结构。Draw Networks 结合了一种新的空间注意机制,该机制模拟了人眼的中心位置,采用了一个顺序变化的自动编码框架,使之对复杂图像进行迭代构造。 该系统大大提高了 MNIST 上生成模型的质量,特别是当对街景房屋编号数据集进行训练时,肉眼竟然无法将它生成的图像与真实数据区别开来。

Draw 体系结构的核心是一对递归神经网络:一个是压缩用于训练的真实图像的编码器,另一个是在接收到代码后重建图像的解码器。这一组合系统采用随机梯度下降的端到端训练,损失函数的最大值变分主要取决于对数似然函数的数据。

Draw 网络类似于其他变分自动编码器,它包含一个编码器网络,该编码器网络决定着潜在代码上的 distribution(潜在代码主要捕获有关输入数据的显著信息),解码器网络接收来自 code distribution 的样本,并利用它们来调节其自身图像的 distribution 。

DRAW 与其他自动解码器的三大区别 编码器和解码器都是 DRAW 中的递归网络,解码器的输出依次添加到 distribution 中以生成数据,而不是一步一步地生成 distribution 。动态更新的注意机制用于限制由编码器负责的输入区域和由解码器更新的输出区域 。简单地说,这一网络在每个 time-step 都能决定“读到哪里”和“写到哪里”以及“写什么”。

左:传统变分自动编码器 在生成过程中,从先前的 P(z)中提取一个样本 z ,并通过前馈译码器网络来计算给定样本的输入 P(x_z)的概率。 在推理过程中,输入 x 被传递到编码器网络,在潜在变量上产生一个近似的后验 Q(z|x) 。在训练过程中,从 Q(z|x) 中抽取 z,然后用它计算总描述长度 KL ( Q (Z|x)∣∣ P(Z)−log(P(x|z)),该长度随随机梯度的下降(https://en.wikipedia.org/wiki/Stochastic_gradient_descent)而减小至最小值。

右:DRAW网络 在每一个步骤中,都会将先前 P(z)中的一个样本 z_t 传递给递归解码器网络,该网络随后会修改 canvas matrix 的一部分。最后一个 canvas matrix cT 用于计算 P(x|z_1:t)。 在推理过程中,每个 time-step 都会读取输入,并将结果传递给编码器 RNN,然后从上一 time-step 中的 RNN 指定读取位置,编码器 RNN 的输出用于计算该 time-step 的潜在变量的近似后验值。 损失函数 最后一个 canvas matrix cT 用于确定输入数据的模型 D(X | cT)的参数。如果输入是二进制的,D 的自然选择呈伯努利分布,means 由σ(cT) 给出。重建损失 Lx 定义为 D 下 x 的负对数概率:          The latent loss 潜在distributions序列  的潜在损失 被定义为源自  

的潜在先验 P(Z_t)的简要 KL散度。 鉴于这一损失取决于由  绘制的潜在样本 z_t ,因此其反过来又决定了输入 x。如果潜在 distribution是一个 这样的 diagonal Gaussian ,P(Z_t) 便是一个均值为 0,且具有标准离差的标准 Gaussian,这种情况下方程则变为  

。 网络的总损失 L 是重建和潜在损失之和的期望值:         对于每个随机梯度下降,我们使用单个 z 样本进行优化。   L^Z 可以解释为从之前的序列向解码器传输潜在样本序列 z_1:T 所需的 NAT 数量,并且(如果 x 是离散的)L^x 是解码器重建给定 z_1:T 的 x 所需的 NAT 数量。因此,总损失等于解码器和之前数据的预期压缩量。   改善图片   正如 EricJang 在他的文章中提到的,让我们的神经网络仅仅“改善图像”而不是“一次完成图像”会更容易些。正如人类艺术家在画布上涂涂画画,并从绘画过程中推断出要修改什么,以及下一步要绘制什么。   改进图像或逐步细化只是一次又一次地破坏我们的联合 distribution P(C) ,导致潜在变量链 C1,C2,…CT−1 呈现新的变量分布 P(CT) 。       

诀窍是多次从迭代细化分布 P(Ct|Ct−1)中取样,而不是直接从 P(C) 中取样。 在 DRAW 模型中,P(Ct|Ct−1) 是所有 t 的同一 distribution,因此我们可以将其表示为以下递归关系(如果不是,那么就是Markov Chain而不是递归网络了)。

DRAW模型的实际应用 假设你正在尝试对数字 8 的图像进行编码。每个手写数字的绘制方式都不同,有的样本 8 可能看起来宽一些,有的可能长一些。如果不注意,编码器将被迫同时捕获所有这些小的差异。

但是……如果编码器可以在每一帧上选择一小段图像并一次检查数字 8 的每一部分呢?这会使工作更容易,对吧?

同样的逻辑也适用于生成数字。注意力单元将决定在哪里绘制数字 8 的下一部分-或任何其他部分-而传递的潜在矢量将决定解码器生成多大的区域。 基本上,如果我们把变分的自动编码器(VAE)中的潜在代码看作是表示整个图像的矢量,那么绘图中的潜在代码就可以看作是表示笔画的矢量。最后,这些向量的序列实现了原始图像的再现。

好吧,那么它是如何工作的呢?

在一个递归的 VAE 模型中,编码器在每一个 timestep 会接收整个输入图像。在 Draw 中,我们需要将焦点集中在它们之间的 attention gate 上,因此编码器只接收到网络认为在该 timestep 重要的图像部分。第一个 attention gate 被称为“Read”attention。 “Read”attention分为两部分: 选择图像的重要部分和裁剪图像

选择图像的重要部分 为了确定图像的哪一部分最重要,我们需要做些观察,并根据这些观察做出决定。在 DRAW中,我们使用前一个 timestep 的解码器隐藏状态。通过使用一个简单的完全连接的图层,我们可以将隐藏状态映射到三个决定方形裁剪的参数:中心 X、中心 Y 和比例。

裁剪图像 现在,我们不再对整个图像进行编码,而是对其进行裁剪,只对图像的一小部分进行编码。然后,这个编码通过系统解码成一个小补丁。 现在我们到达 attention gate 的第二部分,“write”attention,(与“read”部分的设置相同),只是“write”attention 使用当前的解码器,而不是前一个 timestep 的解码器。

虽然可以直观地将注意力机制描述为一种裁剪,但实践中使用了一种不同的方法。在上面描述的模型结构仍然精确的前提下,使用了gaussian filters矩阵,没有利用裁剪的方式。我们在DRAW 中取了一组每个 filter 的中心间距都均匀的gaussian filters 矩阵。 代码一览 我们在 Eric Jang 的代码的基础上,对其进行一些清理和注释,以便于理解.

# first we import our librariesimport tensorflow as tffrom tensorflow.examples.tutorials import mnistfrom tensorflow.examples.tutorials.mnist import input_dataimport numpy as npimport scipy.miscimport os Eric 为我们提供了一些伟大的功能,可以帮助我们构建 “read” 和 “write” 注意门径,还有过滤我们将使用的初始状态功能,但是首先,我们需要添加新的功能,来使我们能创建一个密集层并合并图像。并将它们保存到本地计算机中,以获取更新的代码。

# fully-conected layerdef dense(x, inputFeatures, outputFeatures, scope=None, with_w=False): with tf.variable_scope(scope or "Linear"): matrix = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32, tf.random_normal_initializer(stddev=0.02)) bias = tf.get_variable("bias", [outputFeatures], initializer=tf.constant_initializer(0.0)) if with_w: return tf.matmul(x, matrix) + bias, matrix, bias else: return tf.matmul(x, matrix) + bias # merge imagesdef merge(images, size): h, w = images.shape[1], images.shape[2] img = np.zeros((h * size[0], w * size[1])) for idx, image in enumerate(images): i = idx % size[1] j = idx / size[1] img[j*h:j*h+h, i*w:i*w+w] = image return img # save image on local machine def ims(name, img): # print img[:10][:10]scipy.misc.toimage(img,cmin=0,cmax=1).save(name) 现在让我们把代码放在一起以便完成。

# DRAW implementationclass draw_model(): def __init__(self): # First we download the MNIST dataset into our local machine. self.mnist = input_data.read_data_sets("data/", one_hot=True) print "------------------------------------" print "MNIST Dataset Succesufully Imported" print "------------------------------------" self.n_samples = self.mnist.train.num_examples # We set up the model parameters # ------------------------------ # image width,height self.img_size = 28 # read glimpse grid width/height self.attention_n = 5 # number of hidden units / output size in LSTM self.n_hidden = 256 # QSampler output size self.n_z = 10 # MNIST generation sequence length self.sequence_length = 10 # training minibatch size self.batch_size = 64 # workaround for variable_scope(reuse=True) self.share_parameters = False # Build our model self.images = tf.placeholder(tf.float32, [None, 784]) # input (batch_size * img_size) self.e = tf.random_normal((self.batch_size, self.n_z), mean=0, stddev=1) # Qsampler noise self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # encoder Op self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # decoder Op # Define our state variables self.cs = [0] * self.sequence_length # sequence of canvases self.mu, self.logsigma, self.sigma = [0] * self.sequence_length, [0] * self.sequence_length, [0] * self.sequence_length # Initial states h_dec_prev = tf.zeros((self.batch_size, self.n_hidden)) enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32) dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32) # Construct the unrolled computational graph x = self.images for t in range(self.sequence_length): # error image + original image c_prev = tf.zeros((self.batch_size, self.img_size**2)) if t == 0 else self.cs[t-1] x_hat = x - tf.sigmoid(c_prev) # read the image r = self.read_basic(x,x_hat,h_dec_prev) #sanity check print r.get_shape() # encode to guass distribution self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(enc_state, tf.concat(1, [r, h_dec_prev])) # sample from the distribution to get z z = self.sampleQ(self.mu[t],self.sigma[t]) #sanity check print z.get_shape() # retrieve the hidden layer of RNN h_dec, dec_state = self.decode_layer(dec_state, z) #sanity check print h_dec.get_shape() # map from hidden layer self.cs[t] = c_prev + self.write_basic(h_dec) h_dec_prev = h_dec self.share_parameters = True # from now on, share variables # Loss function self.generated_images = tf.nn.sigmoid(self.cs[-1]) self.generation_loss = tf.reduce_mean(-tf.reduce_sum(self.images * tf.log(1e-10 + self.generated_images) + (1-self.images) * tf.log(1e-10 + 1 - self.generated_images),1)) kl_terms = [0]*self.sequence_length for t in xrange(self.sequence_length): mu2 = tf.square(self.mu[t]) sigma2 = tf.square(self.sigma[t]) logsigma = self.logsigma[t] kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2*logsigma, 1) - self.sequence_length*0.5 # each kl term is (1xminibatch) self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms)) self.cost = self.generation_loss + self.latent_loss # Optimization optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5) grads = optimizer.compute_gradients(self.cost) for i,(g,v) in enumerate(grads): if g is not None: grads[i] = (tf.clip_by_norm(g,5),v) self.train_op = optimizer.apply_gradients(grads) self.sess = tf.Session() self.sess.run(tf.initialize_all_variables()) # Our training function def train(self): for i in xrange(20000): xtrain, _ = self.mnist.train.next_batch(self.batch_size) cs, gen_loss, lat_loss, _ = self.sess.run([self.cs, self.generation_loss, self.latent_loss, self.train_op], feed_dict={self.images: xtrain}) print "iter %d genloss %f latloss %f" % (i, gen_loss, lat_loss) if i % 500 == 0: cs = 1.0/(1.0+np.exp(-np.array(cs))) # x_recons=sigmoid(canvas) for cs_iter in xrange(10): results = cs[cs_iter] results_square = np.reshape(results, [-1, 28, 28]) print results_square.shape ims("results/"+str(i)+"-step-"+str(cs_iter)+".jpg",merge(results_square,[8,8])) # Eric Jang's main functions # -------------------------- # locate where to put attention filters on hidden layers def attn_window(self, scope, h_dec): with tf.variable_scope(scope, reuse=self.share_parameters): parameters = dense(h_dec, self.n_hidden, 5) # center of 2d gaussian on a scale of -1 to 1 gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(1,5,parameters) # move gx/gy to be a scale of -imgsize to +imgsize gx = (self.img_size+1)/2 * (gx_ + 1) gy = (self.img_size+1)/2 * (gy_ + 1) sigma2 = tf.exp(log_sigma2) # distance between patches delta = (self.img_size - 1) / ((self.attention_n-1) * tf.exp(log_delta)) # returns [Fx, Fy, gamma] return self.filterbank(gx,gy,sigma2,delta) + (tf.exp(log_gamma),) # Construct patches of gaussian filters def filterbank(self, gx, gy, sigma2, delta): # 1 x N, look like [[0,1,2,3,4]] grid_i = tf.reshape(tf.cast(tf.range(self.attention_n), tf.float32),[1, -1]) # individual patches centers mu_x = gx + (grid_i - self.attention_n/2 - 0.5) * delta mu_y = gy + (grid_i - self.attention_n/2 - 0.5) * delta mu_x = tf.reshape(mu_x, [-1, self.attention_n, 1]) mu_y = tf.reshape(mu_y, [-1, self.attention_n, 1]) # 1 x 1 x imgsize, looks like [[[0,1,2,3,4,...,27]]] im = tf.reshape(tf.cast(tf.range(self.img_size), tf.float32), [1, 1, -1]) # list of gaussian curves for x and y sigma2 = tf.reshape(sigma2, [-1, 1, 1]) Fx = tf.exp(-tf.square((im - mu_x) / (2*sigma2))) Fy = tf.exp(-tf.square((im - mu_x) / (2*sigma2))) # normalize area-under-curve Fx = Fx / tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),1e-8) Fy = Fy / tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),1e-8) return Fx, Fy # read operation without attention def read_basic(self, x, x_hat, h_dec_prev): return tf.concat(1,[x,x_hat]) # read operation with attention def read_attention(self, x, x_hat, h_dec_prev): Fx, Fy, gamma = self.attn_window("read", h_dec_prev) # apply parameters for patch of gaussian filters def filter_img(img, Fx, Fy, gamma): Fxt = tf.transpose(Fx, perm=[0,2,1]) img = tf.reshape(img, [-1, self.img_size, self.img_size]) # apply the gaussian patches glimpse = tf.batch_matmul(Fy, tf.batch_matmul(img, Fxt)) glimpse = tf.reshape(glimpse, [-1, self.attention_n**2]) # scale using the gamma parameter return glimpse * tf.reshape(gamma, [-1, 1]) x = filter_img(x, Fx, Fy, gamma) x_hat = filter_img(x_hat, Fx, Fy, gamma) return tf.concat(1, [x, x_hat]) # encoder function for attention patch def encode(self, prev_state, image): # update the RNN with our image with tf.variable_scope("encoder",reuse=self.share_parameters): hidden_layer, next_state = self.lstm_enc(image, prev_state) # map the RNN hidden state to latent variables with tf.variable_scope("mu", reuse=self.share_parameters): mu = dense(hidden_layer, self.n_hidden, self.n_z) with tf.variable_scope("sigma", reuse=self.share_parameters): logsigma = dense(hidden_layer, self.n_hidden, self.n_z) sigma = tf.exp(logsigma) return mu, logsigma, sigma, next_state def sampleQ(self, mu, sigma): return mu + sigma*self.e # decoder function def decode_layer(self, prev_state, latent): # update decoder RNN using our latent variable with tf.variable_scope("decoder", reuse=self.share_parameters): hidden_layer, next_state = self.lstm_dec(latent, prev_state) return hidden_layer, next_state # write operation without attention def write_basic(self, hidden_layer): # map RNN hidden state to image with tf.variable_scope("write", reuse=self.share_parameters): decoded_image_portion = dense(hidden_layer, self.n_hidden, self.img_size**2) return decoded_image_portion # write operation with attention def write_attention(self, hidden_layer): with tf.variable_scope("writeW", reuse=self.share_parameters): w = dense(hidden_layer, self.n_hidden, self.attention_n**2) w = tf.reshape(w, [self.batch_size, self.attention_n, self.attention_n]) Fx, Fy, gamma = self.attn_window("write", hidden_layer) Fyt = tf.transpose(Fy, perm=[0,2,1]) wr = tf.batch_matmul(Fyt, tf.batch_matmul(w, Fx)) wr = tf.reshape(wr, [self.batch_size, self.img_size**2]) return wr * tf.reshape(1.0/gamma, [-1, 1]) model = draw_mod

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 解码器
    +关注

    关注

    9

    文章

    1225

    浏览量

    43780
  • 编码器
    +关注

    关注

    45

    文章

    4013

    浏览量

    143415
  • 神经网络
    +关注

    关注

    42

    文章

    4844

    浏览量

    108201

原文标题:158行代码!程序员复现DeepMind图像生成神器

文章出处:【微信号:AI_era,微信公众号:新智元】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    HJ158 单电源低功耗双运算放大器

    相同,可 以互相代换。HJ158 适用于需要多个运算放大器且性能一致的电子系统中,特别是航天航空电子仪器和 野外、井下等电池供电的电子设备中。 主要特点有: 输入失调电压低2mV 输入失调电流低3nA
    发表于 04-20 09:31

    [VirtualLab] 使用Python运行VirtualLab Fusion光学仿真

    摘要 VirtualLab Fusion允许Python外部访问其建模技术、求解器和结果。这个例介绍了一种使用路径变量和Visual Studio代码Python连接到Virtu
    发表于 03-31 09:39

    无法去除 Python VisionFive.i2c 库的终端输出?

    烧的官方最新八月份的 debian 12 的系统。 根据这个案例安装好了 python 环境和 VisionFive 库。 执行下面这条代码: import VisionFive.i2c
    发表于 02-25 06:13

    Python运行本地Web服务并实现远程访问

    本文介绍使用Python搭建本地Web服务并结合 ZeroNews 实现公网访问。
    的头像 发表于 02-06 11:39 365次阅读
    <b class='flag-5'>Python</b>运行本地Web服务并<b class='flag-5'>实现</b>远程访问

    Python中借助NVIDIA CUDA Tile简化GPU编程

    模型更高的层级来实现算法。至于如何将计算任务拆分到各个线程,完全由编译器和运行时在底层自动处理。不仅如此,tile kernels 还能够屏蔽 Tensor Core 等专用硬件的细节,写出的代码还能
    的头像 发表于 12-13 10:12 1456次阅读
    在<b class='flag-5'>Python</b>中借助NVIDIA CUDA Tile简化GPU编程

    Termux中调试圣诞树Python代码

    python --version 如果输出Python 3.x.x(比如3.11.4),说明安装成功。 二、代码编写(两种方式可选) 方式1:Termux自带编辑器(nano
    发表于 12-09 09:02

    labview如何实现数据的采集与实时预测

    现有以下问题:labview可以实现数据的采集以及调用python代码,但如何将这两项功能集成在一个VI文件里,从而实现数据的采集与实时预测。现有条件如下:已完成数据的采集
    发表于 12-03 21:13

    JY-ULP-158+低通滤波器:DC-158MHz频段的“低频信号守门人”——技术与应用深度解析

    在射频电子系统中,低通滤波器是实现低频信号提纯、抑制高频杂波的核心器件。杰盈通讯推出的JY-ULP-158+低通滤波器,采用LC电路拓扑设计,凭借“高抑制比、超小型封装、宽温稳定性”的技术优势,在
    的头像 发表于 11-13 14:15 526次阅读
    JY-ULP-<b class='flag-5'>158</b>+低通滤波器:DC-<b class='flag-5'>158</b>MHz频段的“低频信号守门人”——技术与应用深度解析

    Python调用API教程

    两个不同系统之间的信息交互。在这篇文章中,我们将详细介绍Python调用API的方法和技巧。 一、Requests库发送HTTP请求 使用Python调用API的第一步是发送HTTP
    的头像 发表于 11-03 09:15 1202次阅读

    Python 给 Amazon 做“全身 CT”——可量产、可扩展的商品详情爬虫实战

    一、技术选型:为什么选 Python 而不是 Java? 结论: “调研阶段 Python,上线后如果 QPS 爆表再考虑 Java 重构。” 二、整体架构速览(3 分钟看懂) 三、开发前准备(5
    的头像 发表于 10-21 16:59 631次阅读
    <b class='flag-5'>用</b> <b class='flag-5'>Python</b> 给 Amazon 做“全身 CT”——可量产、可扩展的商品详情爬虫实战

    一文解读华为自动驾驶布局之鸿蒙智

    鸿蒙智代表了一种新的产业协作思路软件和算力把传统整车与新兴数字能力结合,通过平台化、模块化的技术路线来支持快速迭代和规模化交付。
    的头像 发表于 10-19 10:38 2580次阅读
    一文解读华为自动驾驶布局之鸿蒙智<b class='flag-5'>行</b>

    淘宝商品详情接口(item_get)企业级全解析:参数配置、签名机制与 Python 代码实战

    本文详解淘宝开放平台taobao.item_get接口对接全流程,涵盖参数配置、MD5签名生成、Python企业级代码实现及高频问题排查,提供可落地的实战方案,助你高效稳定获取商品数据。
    的头像 发表于 09-26 09:13 1126次阅读
    淘宝商品详情接口(item_get)企业级全解析:参数配置、签名机制与 <b class='flag-5'>Python</b> <b class='flag-5'>代码</b>实战

    FS158轻松升级DCDC协议芯片:外围精简+耐压提升一步到位

    在硬件研发与产品迭代中,协议芯片的升级往往面临“改板麻烦、外围元件多、耐压不足”等痛点。今天就为大家分享一个高效方案:FS158直接升级替代传统FS112脚位协议芯片,无需改动PCB板,还能实现
    的头像 发表于 09-22 16:51 1502次阅读
    <b class='flag-5'>用</b>FS<b class='flag-5'>158</b>轻松升级DCDC协议芯片:外围精简+耐压提升一步到位

    termux调试python猜数字游戏

    termux做一个猜数字游戏 下面是在Termux中创建猜数字游戏的步骤及完整实现方案,结合Python实现(最适配Termux环境): ? 一、环境准备(Termux基础
    发表于 08-29 17:15

    termux如何搭建python游戏

    VS Code编辑 - 版本控制:`git`管理代码,同步至GitHub/Gitee - 任务调度:通过`crontab`设置定时测试(如每分钟运行游戏脚本:`*/1 * * * * python
    发表于 08-29 07:06