0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

GNN教程:GraghSAGE算法细节详解!

深度学习自然语言处理 来源:深度学习自然语言处理 作者:深度学习自然语言 2020-11-24 09:32 次阅读

引言

本文为GNN教程的第三篇文章 【GraghSAGE算法】,在GCN的博文中我们重点讨论了图神经网络的逐层传播公式是如何推导的,然而,GCN的训练方式需要将邻接矩阵和特征矩阵一起放到内存或者显存里,在大规模图数据上是不可取的。 其次,GCN在训练时需要知道整个图的结构信息(包括待预测的节点), 这在现实某些任务中也不能实现(比如用今天训练的图模型预测明天的数据,那么明天的节点是拿不到的)。GraphSAGE的出现就是为了解决这样的问题,这篇博文中我们将会详细得讨论它。

一、Inductive learning v.s. Transductive learning

首先我们介绍一下什么是inductive learning。与其他类型的数据不同,图数据中的每一个节点可以通过边的关系利用其他节点的信息,这样就产生了一个问题,如果训练集上的节点通过边关联到了预测集或者验证集的节点,那么在训练的时候能否用它们的信息呢? 如果训练时用到了测试集或验证集样本的信息(或者说,测试集和验证集在训练的时候是可见的), 我们把这种学习方式叫做transductive learning, 反之,称为inductive learning。

显然,我们所处理的大多数机器学习问题都是inductive learning, 因为我们刻意的将样本集分为训练/验证/测试,并且训练的时候只用训练样本。然而,在GCN中,训练节点收集邻居信息的时候,用到了测试或者验证样本,所以它是transductive的。

二、概述

GraphSAGE是一个inductive框架,在具体实现中,训练时它仅仅保留训练样本到训练样本的边。inductive learning 的优点是可以利用已知节点的信息为未知节点生成Embedding. GraphSAGE 取自 Graph SAmple and aggreGatE, SAmple指如何对邻居个数进行采样。aggreGatE指拿到邻居的embedding之后如何汇聚这些embedding以更新自己的embedding信息。下图展示了GraphSAGE学习的一个过程:

对邻居采样

采样后的邻居embedding传到节点上来,并使用一个聚合函数聚合这些邻居信息以更新节点的embedding

根据更新后的embedding预测节点的标签

三、算法细节

3.1 节点 Embedding 生成(即:前向传播)算法

这一节讨论的是如何给图中的节点生成(或者说更新)embedding, 假设我们已经完成了GraphSAGE的训练,因此模型所有的参数(parameters)都已知了。具体来说,这些参数包括个聚合器(见下图算法第4行)中的参数, 这些聚合器被用来将邻居embedding信息聚合到节点上,以及一系列的权重矩阵(下图算法第5行), 这些权值矩阵被用作在模型层与层之间传播embedding的时候做非线性变换。

下面的算法描述了我们是怎么做前向传播的:

算法的主要部分为:

(line 1)初始化每个节点embedding为节点的特征向量

(line 3)对于每一个节点

(line 4)拿到它采样后的邻居的embedding并将其聚合,这里表示对邻居采样

(line 5)根据聚合后的邻居embedding()和自身embedding()通过一个非线性变换()更新自身embedding.

算法里的这个比较难理解,下面单独来说他,之前提到过,它既是聚合器的数量,也是权重矩阵的数量,还是网络的层数,这是因为每一层网络中聚合器和权重矩阵是共享的。

网络的层数可以理解为需要最大访问到的邻居的跳数(hops),比如在figure 1中,红色节点的更新拿到了它一、二跳邻居的信息,那么网络层数就是2。

为了更新红色节点,首先在第一层()我们会将蓝色节点的信息聚合到红色节点上,将绿色节点的信息聚合到蓝色节点上。在第二层()红色节点的embedding被再次更新,不过这次用的是更新后的蓝色节点embedding,这样就保证了红色节点更新后的embedding包括蓝色和绿色节点的信息。

3.2 采样 (SAmple) 算法

GraphSAGE采用了定长抽样的方法,具体来说,定义需要的邻居个数, 然后采用有放回的重采样/负采样方法达到,。保证每个节点(采样后的)邻居个数一致是为了把多个节点以及他们的邻居拼成Tensor送到GPU中进行批训练。

3.3 聚合器 (Aggregator) 架构

GraphSAGE 提供了多种聚合器,实验中效果最好的平均聚合器(mean aggregator),平均聚合器的思虑很简单,每个维度取对邻居embedding相应维度的均值,这个和GCN的做法基本一致(GCN实际上用的是求和):

举个简单例子,比如一个节点的3个邻居的embedding分别为 ,按照每一维分别求均值就得到了聚合后的邻居embedding为.

论文中还阐述了另外两种aggregator:LSTM aggregator和Pooling aggregator, 有兴趣的可以去论文中看下。

3.4 参数学习

到此为止,整个模型的架构就讲完了,那么GraphSAGE是如何学习聚合器的参数以及权重变量的呢? 在有监督的情况下,可以使用每个节点的预测label和真实label的交叉熵作为损失函数。在无监督的情况下,可以假设相邻的节点的输出embeding应当尽可能相近,因此可以设计出如下的损失函数:

其中是节点的输出embedding,是节点的邻居(这里邻居是广义的,比如说如果和在一个定长的随机游走中可达,那么我们也认为他们相邻),是负采样分布,是负采样的样本数量,所谓负采样指我们还需要一批不是邻居的节点作为负样本,那么上面这个式子的意思是相邻节点的embedding的相似度尽量大的情况下保证不相邻节点的embedding的期望相似度尽可能小。

四、后话

GraphSAGE采用了采样的机制,克服了GCN训练时内存和显存上的限制,使得图模型可以应用到大规模的图结构数据中,是目前几乎所有工业上图模型的雏形。然而,每个节点这么多邻居,采样能否考虑到邻居的相对重要性呢,或者我们在聚合计算中能否考虑到邻居的相对重要性? 这个问题在我们的下一篇博文Graph Attentioin Networks中做了详细的讨论。

责任编辑:lq

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4572

    浏览量

    98735
  • 算法
    +关注

    关注

    23

    文章

    4454

    浏览量

    90747
  • 模型
    +关注

    关注

    1

    文章

    2704

    浏览量

    47681

原文标题:GNN教程:GraghSAGE算法细节详解!

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    详解从均值滤波到非局部均值滤波算法的原理及实现方式

    将再啰嗦一次,详解从均值滤波到非局部均值滤波算法的原理及实现方式。 细数主要的2D降噪算法,如下图所示,从最基本的均值滤波到相对最好的BM3D降噪,本文将尽量用最同属的语言,详解这些
    的头像 发表于 12-19 16:30 347次阅读

    一文带你详解门电路

    【科普】详解门电路
    的头像 发表于 12-15 10:41 690次阅读
    一文带你<b class='flag-5'>详解</b>门电路

    一文详解pcb的msl等级

    一文详解pcb的msl等级
    的头像 发表于 12-13 16:52 2711次阅读

    一文详解TVS二极管

    一文详解TVS二极管
    的头像 发表于 11-29 15:10 714次阅读
    一文<b class='flag-5'>详解</b>TVS二极管

    光伏逆变系统细节知多少

    电子发烧友网站提供《光伏逆变系统细节知多少.doc》资料免费下载
    发表于 11-15 11:13 2次下载
    光伏逆变系统<b class='flag-5'>细节</b>知多少

    PID算法详解及实例分析

    PID算法详解及实例分析#include using namespace std;struct _pid{   float SetSpeed; //定义设定值   float ActualSpeed
    发表于 11-09 16:33 0次下载

    一文详解pcb和smt的区别

    一文详解pcb和smt的区别
    的头像 发表于 10-08 09:31 1556次阅读

    全文详解A*算法及其变种

    相比于 BFS,Dijkstra 算法新增了cost_so_far用于记录从当前点current到起点的路径所需要的代价,并将搜索规则改为优先搜索cost最小的点.如下图所示,,Dijkstra 算法会绕过中央难走的草地.
    发表于 09-14 09:25 935次阅读
    全文<b class='flag-5'>详解</b>A*<b class='flag-5'>算法</b>及其变种

    基于Transformer的目标检测算法

    掌握基于Transformer的目标检测算法的思路和创新点,一些Transformer论文涉及的新概念比较多,话术没有那么通俗易懂,读完论文仍然不理解算法细节部分。
    发表于 08-16 10:51 426次阅读
    基于Transformer的目标检测<b class='flag-5'>算法</b>

    常用的电机控制算法详解

    最近看到一些朋友都在玩各种电机,对于电机重要的就是控制了,控制得稳、准、快是一名控制算法软件工程师的终极目标,首先可以玩一些比较成熟的控制算法来体验一下,所以这里收集这块内容分享给大家。
    发表于 07-18 10:43 1020次阅读
    常用的电机控制<b class='flag-5'>算法</b><b class='flag-5'>详解</b>

    物理设计中的问题详解

    物理设计中的问题详解
    的头像 发表于 07-05 16:56 534次阅读
    物理设计中的问题<b class='flag-5'>详解</b>

    详解DeepMind排序算法

    DeepMind 的这一发现确实居功至伟,但不幸的是,他们未能解释清楚算法。下面,我们来详细看看他们发布的一段汇编代码,这是一个包含三个元素的数组的排序,我们将伪汇编转换为汇编:
    的头像 发表于 06-21 15:38 258次阅读

    KMP算法详解

    KMP 算法主要用于字符串匹配的,他的时间复杂度 O(m+n) 。
    的头像 发表于 06-07 16:23 404次阅读
    KMP<b class='flag-5'>算法</b><b class='flag-5'>详解</b>

    [源代码]Python算法详解

    [源代码]Python算法详解[源代码]Python算法详解
    发表于 06-06 17:50 0次下载

    PFC电路详解教程

    PFC电路详解教程
    发表于 05-31 18:12