0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

蒸馏无分类器指导扩散模型的方法

OpenCV学堂 来源:OpenCV学堂 作者:OpenCV学堂 2022-10-13 10:35 次阅读

斯坦福大学联合谷歌大脑使用「两步蒸馏方法」提升无分类器指导的采样效率,在生成样本质量和采样速度上都有非常亮眼的表现。

去噪扩散概率模型(DDPM)在图像生成、音频合成、分子生成和似然估计领域都已经实现了 SOTA 性能。同时无分类器(classifier-free)指导进一步提升了扩散模型的样本质量,并已被广泛应用在包括 GLIDE、DALL·E 2 和 Imagen 在内的大规模扩散模型框架中。

然而,无分类器指导的一大关键局限是它的采样效率低下,需要对两个扩散模型评估数百次才能生成一个样本。这一局限阻碍了无分类指导模型在真实世界设置中的应用。尽管已经针对扩散模型提出了蒸馏方法,但目前这些方法不适用无分类器指导扩散模型。

为了解决这一问题,近日斯坦福大学和谷歌大脑的研究者在论文《On Distillation of Guided Diffusion Models》中提出使用两步蒸馏(two-step distillation)方法来提升无分类器指导的采样效率。

在第一步中,他们引入单一学生模型来匹配两个教师扩散模型的组合输出;在第二步中,他们利用提出的方法逐渐地将从第一步学得的模型蒸馏为更少步骤的模型。

利用提出的方法,单个蒸馏模型能够处理各种不同的指导强度,从而高效地对样本质量和多样性进行权衡。此外为了从他们的模型中采样,研究者考虑了文献中已有的确定性采样器,并进一步提出了随机采样过程。

研究者在 ImageNet 64x64 和 CIFAR-10 上进行了实验,结果表明提出的蒸馏模型只需 4 步就能生成在视觉上与教师模型媲美的样本,并且在更广泛的指导强度上只需 8 到 16 步就能实现与教师模型媲美的 FID/IS 分数,具体如下图 1 所示。

此外,在 ImageNet 64x64 上的其他实验结果也表明了,研究者提出的框架在风格迁移应用中也表现良好。

方法介绍

接下来本文讨论了蒸馏无分类器指导扩散模型的方法( distilling a classifier-free guided diffusion model)。给定一个训练好的指导模型,即教师模型970d4384-4a3c-11ed-a3b6-dac502259ad0.png之后本文分两步完成。

第一步引入一个连续时间学生模型9720de08-4a3c-11ed-a3b6-dac502259ad0.png,该模型具有可学习参数η_1,以匹配教师模型在任意时间步 t∈[0,1] 处的输出。给定一个优化范围 [w_min, w_max],对学生模型进行优化:

973061e8-4a3c-11ed-a3b6-dac502259ad0.png

其中,97460e1c-4a3c-11ed-a3b6-dac502259ad0.png。为了合并指导权重 w,本文引入了一个 w - 条件模型,其中 w 作为学生模型的输入。为了更好地捕捉特征,本文还对 w 应用傅里叶嵌入。此外,由于初始化在模型性能中起着关键作用,因此本文初始化学生模型的参数与教师模型相同。

在第二步中,本文将离散时间步(discrete time-step)考虑在内,并逐步将第一步中的蒸馏模型976a42dc-4a3c-11ed-a3b6-dac502259ad0.png转化为步数较短的学生模型977da4a8-4a3c-11ed-a3b6-dac502259ad0.png,其可学习参数为η_2,每次采样步数减半。设 N 为采样步数,给定 w ~ U[w_min, w_max] 和 t∈{1,…, N},然后根据 Salimans & Ho 等人提出的方法训练学生模型。在将教师模型中的 2N 步蒸馏为学生模型中的 N 步之后,之后使用 N 步学生模型作为新的教师模型,这个过程不断重复,直到将教师模型蒸馏为 N/2 步学生模型。

N 步可确定性和随机采样:一旦模型979522b8-4a3c-11ed-a3b6-dac502259ad0.png训练完成,给定一个指定的 w ∈ [w_min, w_max],然后使用 DDIM 更新规则执行采样。

实际上,本文也可以执行 N 步随机采样,使用两倍于原始步长的确定性采样步骤,然后使用原始步长向后执行一个随机步骤 。对于97a8d538-4a3c-11ed-a3b6-dac502259ad0.png,当 t > 1/N 时,本文使用以下更新规则

97b46010-4a3c-11ed-a3b6-dac502259ad0.png

实验

实验评估了蒸馏方法的性能,本文主要关注模型在 ImageNet 64x64 和 CIFAR-10 上的结果。他们探索了指导权重的不同范围,并观察到所有范围都具有可比性,因此实验采用 [w_min, w_max] = [0, 4]。图 2 和表 1 报告了在 ImageNet 64x64 上所有方法的性能。

97dbcd80-4a3c-11ed-a3b6-dac502259ad0.png

984c2594-4a3c-11ed-a3b6-dac502259ad0.png

本文还进行了如下实验。具体来说,为了在两个域 A 和 B 之间执行风格迁移,本文使用在域 A 上训练的扩散模型对来自域 A 的图像进行编码,然后使用在域 B 上训练的扩散模型进行解码。由于编码过程可以理解为反向 DDIM 采样过程,本文在无分类器指导下对编码器和解码器进行蒸馏,并与下图 3 中的 DDIM 编码器和解码器进行比较。

审核编辑:彭静
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 编码器
    +关注

    关注

    41

    文章

    3360

    浏览量

    131537
  • 模型
    +关注

    关注

    1

    文章

    2704

    浏览量

    47686
  • 分类器
    +关注

    关注

    0

    文章

    152

    浏览量

    13113

原文标题:采样提速256倍,蒸馏扩散模型生成图像质量媲美教师模型,只需4步

文章出处:【微信号:CVSCHOOL,微信公众号:OpenCV学堂】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    【转帖】传感的故障分类与诊断方法

    最大值;漂移故障,信号以某一速率偏移原信号;周期性干扰故障,原信号上叠加某一频率的信号。传感故障的诊断方法从不同角度出发,故障诊断方法分类不完全相同。现简单地将故障诊断
    发表于 07-13 17:19

    传感的故障分类与诊断方法

    最大值;漂移故障,信号以某一速率偏移原信号;周期性干扰故障,原信号上叠加某一频率的信号。传感故障的诊断方法从不同角度出发,故障诊断方法分类不完全相同。现简单地将故障诊断
    发表于 10-30 15:57

    Edge Impulse的分类模型浅析

    就Edge Impulse的三大模型之一的分类模型进行浅析。针对于图像的分类识别模型,读者可参考OpenMv或树莓派等主流图像识别单片机系统
    发表于 12-20 06:51

    基于优化SVM模型的网络负面信息分类方法研究

    基于优化SVM模型的网络负面信息分类方法研究_郑金芳
    发表于 01-07 18:56 0次下载

    基于非参数方法分类模型检验

    本文主要研究了基于非参数方法分类模型交叉验证结果比较,主要是对实例通过非参数的方法进行模型比较的假设检验,检验两
    发表于 12-08 15:28 1次下载

    针对遥感图像场景分类的多粒度特征蒸馏方法

    了其在嵌入式设备上的应用。提出一种针对遥感图像场景分类的多粒度特征蒸馏方法,将深度网络不同阶段的特征与最终的类别概率同时作为浅层模型的监督信号,使得浅层
    发表于 03-11 17:18 20次下载
    针对遥感图像场景<b class='flag-5'>分类</b>的多粒度特征<b class='flag-5'>蒸馏</b><b class='flag-5'>方法</b>

    电池修复技术:做蒸馏水的方法是怎样的

    许多年前,该村经常停电,应急灯也很流行。 每个人都在玩电池逆变器。 电池和应急灯必须充满蒸馏水。 如果您不愿购买它们,请使用以下本机方法: 这个方法很好。 用这种蒸馏
    发表于 05-18 17:15 1978次阅读
    电池修复技术:做<b class='flag-5'>蒸馏</b>水的<b class='flag-5'>方法</b>是怎样的

    如何改进和加速扩散模型采样的方法1

      尽管扩散模型实现了较高的样本质量和多样性,但不幸的是,它们在采样速度方面存在不足。这限制了扩散模型在实际应用中的广泛采用,并导致了从这些模型
    的头像 发表于 05-07 14:25 1866次阅读
    如何改进和加速<b class='flag-5'>扩散</b><b class='flag-5'>模型</b>采样的<b class='flag-5'>方法</b>1

    如何改进和加速扩散模型采样的方法2

      事实上,扩散模型已经在深层生成性学习方面取得了重大进展。我们预计,它们可能会在图像和视频处理、 3D 内容生成和数字艺术以及语音和语言建模等领域得到实际应用。它们还将用于药物发现和材料设计等领域,以及其他各种重要应用。我们认为,基于
    的头像 发表于 05-07 14:38 2693次阅读
    如何改进和加速<b class='flag-5'>扩散</b><b class='flag-5'>模型</b>采样的<b class='flag-5'>方法</b>2

    若干蒸馏方法之间的细节以及差异

    以往的知识蒸馏虽然可以有效的压缩模型尺寸,但很难将teacher模型的能力蒸馏到一个更小词表的student模型中,而DualTrain+S
    的头像 发表于 05-12 11:39 1132次阅读

    关于快速知识蒸馏的视觉框架

    知识蒸馏框架包含了一个预训练好的 teacher 模型蒸馏过程权重固定),和一个待学习的 student 模型, teacher 用来产生 soft 的 label 用于监督 stu
    的头像 发表于 08-31 10:13 650次阅读

    如何度量知识蒸馏中不同数据增强方法的好坏?

    知识蒸馏(knowledge distillation,KD)是一种通用神经网络训练方法,它使用大的teacher模型来 “教” student模型,在各种AI任务上有着广泛应用。
    的头像 发表于 02-25 15:41 536次阅读

    蒸馏也能Step-by-Step:新方法让小模型也能媲美2000倍体量大模型

    为了解决大型模型的这个问题,部署者往往采用小一些的特定模型来替代。这些小一点的模型用常见范式 —— 微调或是蒸馏来进行训练。微调使用下游的人类注释数据升级一个预训练过的小
    的头像 发表于 05-15 09:35 421次阅读
    <b class='flag-5'>蒸馏</b>也能Step-by-Step:新<b class='flag-5'>方法</b>让小<b class='flag-5'>模型</b>也能媲美2000倍体量大<b class='flag-5'>模型</b>

    如何加速生成2 PyTorch扩散模型

    加速生成2 PyTorch扩散模型
    的头像 发表于 09-04 16:09 820次阅读
    如何加速生成2 PyTorch<b class='flag-5'>扩散</b><b class='flag-5'>模型</b>

    任意模型都能蒸馏!华为诺亚提出异构模型的知识蒸馏方法

    相比于仅使用logits的蒸馏方法,同步使用模型中间层特征进行蒸馏方法通常能取得更好的性能。然而在异构
    的头像 发表于 11-01 16:18 536次阅读
    任意<b class='flag-5'>模型</b>都能<b class='flag-5'>蒸馏</b>!华为诺亚提出异构<b class='flag-5'>模型</b>的知识<b class='flag-5'>蒸馏</b><b class='flag-5'>方法</b>