0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

全部代码开源:StarGAN 在 TensorFlow上的简单实现

DPVg_AI_era 来源:未知 作者:李倩 2018-07-03 10:40 次阅读

StarGAN 是去年 11 月由香港科技大学、新泽西大学和韩国大学等机构的研究人员提出的一个图像风格迁移模型,是一种可以在同一个模型中进行多个图像领域之间的风格转换的对抗生成方法。近日,有研究人员将 StarGAN 在 TensorFlow 上实现的全部代码开源,相关论文获 CVPR 2018 Oral。

看代码之前,我们先来回顾一下 StarGAN 的原始论文。

StarGAN 对抗生成网络实现多领域图像变换

图像到图像转换(image-to-image translation)这个任务是指改变给定图像的某一方面,例如,将人的面部表情从微笑改变为皱眉。在引入生成对抗网络(GAN)之后,这项任务有了显着的改进,包括可以改变头发颜色,改变风景图像的季节等等。

给定来自两个不同领域的训练数据,这些模型将学习如何将图像从一个域转换到另一个域。我们将属性(attribute)定义为图像中固有的有意义的特征,例如头发颜色,性别或年龄等,并且将属性值(attribute value)表示为属性的一个特定值,例如头发颜色的属性值可以是黑色 / 金色 / 棕色,性别的属性值是男性 / 女性。我们进一步将域(domain)表示为共享相同属性值的一组图像。例如,女性的图像可以代表一个 domain,男性的图像代表另一个 domain。

一些图像数据集带有多个标签属性。例如,CelebA 数据集包含 40 个与头发颜色、性别和年龄等面部特征相关的标签,RaFD 数据集有 8 个面部表情标签,如 “高兴”、“愤怒”、“悲伤” 等。这些设置使我们能够执行更有趣的任务,即多域图像到图像转换(multi-domain image-to-image translation),即根据来自多个域的属性改变图像。

图 1:通过从 RaFD 数据集学习迁移知识,应用到 CelebA 的多域图像到图像转换结果。第一列和第六列显示输入图像,其余列是产生的 StarGAN 图像。注意,图像是由一个单一模型网络生成的,面部表情标签如生气、高兴、恐惧是从 RaFD 学习的,而不是来自 CelebA。

在图 1 中,前 5 列显示了一个 CelebA 的图像是如何根据 4 个域(“金发”、“性别”、“年龄” 和 “白皮肤”)进行转换。我们可以进一步扩展到训练来自不同数据集的多个域,例如联合训练 CelebA 和 RaFD 图像,使用在 RaFD 上训练的特征来改变 CelebA 图像的面部表情,如图 1 最右边的列所示。

然而,现有模型在这种多域图像转换任务中既效率低,效果也不好。它们的低效性是因为在学习 k 个域之间的所有映射时,必须训练 k(k-1)个生成器。图 2 说明了如何训练 12 个不同的生成器网络以在 4 个不同的域中转换图像。

图 2: StarGAN 模型与其他跨域模型的比较。(a)为处理多个域,应该在每两个域之间都建立跨域模型。(b)StarGAN 用单个生成器学习多域之间的映射。该图表示连接多个域的拓扑图。

为了解决这类问题,我们提出了StarGAN,这是一个能够学习多个域之间映射的生成对抗网络。如图 2(b) 所示,我们的模型接受多个域的训练数据,仅使用一个生成器就可以学习所有可用域之间的映射。

这个想法很简单。我们的模型不是学习固定的转换(例如,将黑头发变成金色头发),而是将图像和域信息作为输入,学习将输入的图像灵活地转换为相应的域。我们使用一个标签来表示域信息。在训练过程中,我们随机生成一个目标域标签,并训练模型将输入图像转换为目标域。这样,我们可以控制域标签并在测试阶段将图像转换为任何想要的域。

我们还介绍了一种简单但有效的方法,通过在域标签中添加一个掩码向量(mask vector)来实现不同数据集域之间的联合训练。我们提出的方法可以确保模型忽略未知的标签,并关注特定数据集提供的标签。这样,我模型就可以很好地完成任务,比如利用从 RaFD 中学到的特征合成 CelebA 图像的面部表情,如图 1 最右边的列所示。据我们所知,这是第一个在不同的数据集上成功地完成多域图像转换的工作。

总结而言,这个研究的贡献如下:

提出 StarGAN,这是一个新的生成对抗网络,只使用一个生成器和一个鉴别器来学习多个域之间的映射,能有效地利用所有域的图像进行训练。

演示了如何通过使用 mask vector 来学习多个数据集之间的多域图像转换,使 StarGAN 能够控制所有可用的域标签。

使用 StarGAN 在面部属性转换和面部表情合成任务提供了定性和定量的结果,优于 baseline 模型

图 3:StarGAN 的概观,包含两个模块:一个鉴别器 D 和一个生成器 G。(a)D 学习区分真实图像和假图像,并将真实图像分类到相应的域。(b)G 接受图像和目标域标签作为输入并生成假图像。 (c)G 尝试在给定原始域标签的情况下,从假图像中重建原始图像。(d)G 尝试生成与真实图像非常像的假图像,并通过 D 将其分类为目标域。

实验结果

图4:CelebA 数据集上面部属性转换的结果对凯勒巴数据集。第1列显示输入图像,后4列显示单个属性转换的结果,最右边的列显示多个属性的转换结果。H:头发的颜色;G:性别;A:年龄

图5:RaFD 数据集上面部表情合成的结果

图6:StarGAN-SNG 和 StarGAN-JNT 在 CelebA 数据集上的面部表情合成结果。

TensorFlow模型的实现

要求:

Tensorflow 1.8

Python 3.6

> python download.py celebA

下载数据集

> python download.py celebA

训练

python main.py --phase train

测试

python main.py --phase test

celebA 测试图像和你想要的图像同时运行

预训练模型

下载celebA_checkpoint

结果 (128x128, wgan-gp)

女性

男性

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据集
    +关注

    关注

    4

    文章

    1178

    浏览量

    24349
  • 图像转换
    +关注

    关注

    0

    文章

    6

    浏览量

    6134
  • tensorflow
    +关注

    关注

    13

    文章

    313

    浏览量

    60242

原文标题:【CVPR Oral】TensorFlow实现StarGAN代码全部开源,1天训练完

文章出处:【微信号:AI_era,微信公众号:新智元】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    关于 TensorFlow

    的底层数据操作,你也可以自己写一点c++代码来丰富底层的操作。真正的可移植性(Portability)Tensorflow CPU和GPU运行,比如说可以运行在台式机、服务器、手机
    发表于 03-30 19:57

    阿里云Kubernetes容器服务打造TensorFlow实验室

    GPU的使用,同时支持最新的TensorFLow版本, 对于数据科学家来说既是复杂的,同时也是浪费精力的。阿里云的Kubernetes集群,您可以通过简单的按钮提交创建一套完整的
    发表于 05-10 10:24

    概述tensorflow代码

    tensorflow代码分析
    发表于 07-24 14:27

    情地使用Tensorflow吧!

    Google Cloud ML服务,我们可以把TensorFlow应用代码直接提交到云端运行,甚至可以把训练好的模型直接部署,通过API就可以直接访问,也得益于
    发表于 07-22 10:13

    TensorFlow是什么

    更长。TensorFlow 使这一切变得更加简单快捷,从而缩短了想法到部署之间的实现时间。本教程中,你将学习如何利用 TensorFlow
    发表于 07-22 10:14

    TensorFlow常用Python扩展包

    Matplotlib 的专门的统计数据可视化工具。H5fs:H5fs 是能够 HDFS(分层数据格式文件系统)运行的 Linux 文件系统(也包括其他带有 FUSE 实现的操作
    发表于 07-28 14:35

    TensorFlow实现简单线性回归

    本小节直接从 TensorFlow contrib 数据集加载数据。使用随机梯度下降优化器优化单个训练样本的系数。实现简单线性回归的具体做法导入需要的所有软件包: 神经网络中,所有的
    发表于 08-11 19:34

    TensorFlow实现多元线性回归(超详细)

    TensorFlow 实现简单线性回归的基础,可通过权重和占位符的声明中稍作修改来对相同
    发表于 08-11 19:35

    TensorFlow的特点和基本的操作方式

    Tensorflow是Google开源的深度学习框架,来自于Google Brain研究项目,Google第一代分布式机器学习框架DistBelief的基础发展起来。
    发表于 11-23 09:56

    基于TensorFlow Micro代码为何要这么实现

    Hello World是什么?基于TensorFlow Micro代码为何要这么实现
    发表于 11-10 07:48

    Ubuntu 18.04 for Arm运行的TensorFlow和PyTorch的Docker映像

    ,并做出优化以实现尽可能高的性能。我们希望这些 Docker 镜像和创建它们的方法对希望 AArch64 使用 TensorFlow 和 PyTorch 的人有所帮助。包括什么?构
    发表于 10-14 14:25

    TensorFlow是什么?如何启动并运行TensorFlow

    TensorFlow 是一款用于数值计算的强大的开源软件库,特别适用于大规模机器学习的微调。 它的基本原理很简单:首先在 Python 中定义要执行的计算图(例如图 9-1),然后 Tenso
    的头像 发表于 07-29 11:16 1.6w次阅读

    TensorFlow都有哪些功能,大家是否都全部了解呢?

    ,AlphaGo 和 Google Cloud Vision 也是基于 TensorFlow 开发的。而且 TensorFlow开源的,你可以免费下载并立刻上手操作。
    的头像 发表于 09-02 10:20 2.1w次阅读

    开源机器学习平台TensorFlow的更新内容

    TensorFlow 2.2.0-rc0已发布,据官方介绍,TensorFlow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。
    的头像 发表于 03-15 14:53 1755次阅读

    TensorFlow手势识别树莓派开源

    电子发烧友网站提供《TensorFlow手势识别树莓派开源.zip》资料免费下载
    发表于 11-09 09:27 1次下载
    <b class='flag-5'>TensorFlow</b>手势识别树莓派<b class='flag-5'>开源</b>