0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

GNN教程:GraghSAGE算法细节详解!

深度学习自然语言处理 来源:深度学习自然语言处理 作者:深度学习自然语言 2020-11-24 09:32 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

引言

本文为GNN教程的第三篇文章 【GraghSAGE算法】,在GCN的博文中我们重点讨论了图神经网络的逐层传播公式是如何推导的,然而,GCN的训练方式需要将邻接矩阵和特征矩阵一起放到内存或者显存里,在大规模图数据上是不可取的。 其次,GCN在训练时需要知道整个图的结构信息(包括待预测的节点), 这在现实某些任务中也不能实现(比如用今天训练的图模型预测明天的数据,那么明天的节点是拿不到的)。GraphSAGE的出现就是为了解决这样的问题,这篇博文中我们将会详细得讨论它。

一、Inductive learning v.s. Transductive learning

首先我们介绍一下什么是inductive learning。与其他类型的数据不同,图数据中的每一个节点可以通过边的关系利用其他节点的信息,这样就产生了一个问题,如果训练集上的节点通过边关联到了预测集或者验证集的节点,那么在训练的时候能否用它们的信息呢? 如果训练时用到了测试集或验证集样本的信息(或者说,测试集和验证集在训练的时候是可见的), 我们把这种学习方式叫做transductive learning, 反之,称为inductive learning。

显然,我们所处理的大多数机器学习问题都是inductive learning, 因为我们刻意的将样本集分为训练/验证/测试,并且训练的时候只用训练样本。然而,在GCN中,训练节点收集邻居信息的时候,用到了测试或者验证样本,所以它是transductive的。

二、概述

GraphSAGE是一个inductive框架,在具体实现中,训练时它仅仅保留训练样本到训练样本的边。inductive learning 的优点是可以利用已知节点的信息为未知节点生成Embedding. GraphSAGE 取自 Graph SAmple and aggreGatE, SAmple指如何对邻居个数进行采样。aggreGatE指拿到邻居的embedding之后如何汇聚这些embedding以更新自己的embedding信息。下图展示了GraphSAGE学习的一个过程:

对邻居采样

采样后的邻居embedding传到节点上来,并使用一个聚合函数聚合这些邻居信息以更新节点的embedding

根据更新后的embedding预测节点的标签

三、算法细节

3.1 节点 Embedding 生成(即:前向传播)算法

这一节讨论的是如何给图中的节点生成(或者说更新)embedding, 假设我们已经完成了GraphSAGE的训练,因此模型所有的参数(parameters)都已知了。具体来说,这些参数包括个聚合器(见下图算法第4行)中的参数, 这些聚合器被用来将邻居embedding信息聚合到节点上,以及一系列的权重矩阵(下图算法第5行), 这些权值矩阵被用作在模型层与层之间传播embedding的时候做非线性变换。

下面的算法描述了我们是怎么做前向传播的:

算法的主要部分为:

(line 1)初始化每个节点embedding为节点的特征向量

(line 3)对于每一个节点

(line 4)拿到它采样后的邻居的embedding并将其聚合,这里表示对邻居采样

(line 5)根据聚合后的邻居embedding()和自身embedding()通过一个非线性变换()更新自身embedding.

算法里的这个比较难理解,下面单独来说他,之前提到过,它既是聚合器的数量,也是权重矩阵的数量,还是网络的层数,这是因为每一层网络中聚合器和权重矩阵是共享的。

网络的层数可以理解为需要最大访问到的邻居的跳数(hops),比如在figure 1中,红色节点的更新拿到了它一、二跳邻居的信息,那么网络层数就是2。

为了更新红色节点,首先在第一层()我们会将蓝色节点的信息聚合到红色节点上,将绿色节点的信息聚合到蓝色节点上。在第二层()红色节点的embedding被再次更新,不过这次用的是更新后的蓝色节点embedding,这样就保证了红色节点更新后的embedding包括蓝色和绿色节点的信息。

3.2 采样 (SAmple) 算法

GraphSAGE采用了定长抽样的方法,具体来说,定义需要的邻居个数, 然后采用有放回的重采样/负采样方法达到,。保证每个节点(采样后的)邻居个数一致是为了把多个节点以及他们的邻居拼成Tensor送到GPU中进行批训练。

3.3 聚合器 (Aggregator) 架构

GraphSAGE 提供了多种聚合器,实验中效果最好的平均聚合器(mean aggregator),平均聚合器的思虑很简单,每个维度取对邻居embedding相应维度的均值,这个和GCN的做法基本一致(GCN实际上用的是求和):

举个简单例子,比如一个节点的3个邻居的embedding分别为 ,按照每一维分别求均值就得到了聚合后的邻居embedding为.

论文中还阐述了另外两种aggregator:LSTM aggregator和Pooling aggregator, 有兴趣的可以去论文中看下。

3.4 参数学习

到此为止,整个模型的架构就讲完了,那么GraphSAGE是如何学习聚合器的参数以及权重变量的呢? 在有监督的情况下,可以使用每个节点的预测label和真实label的交叉熵作为损失函数。在无监督的情况下,可以假设相邻的节点的输出embeding应当尽可能相近,因此可以设计出如下的损失函数:

其中是节点的输出embedding,是节点的邻居(这里邻居是广义的,比如说如果和在一个定长的随机游走中可达,那么我们也认为他们相邻),是负采样分布,是负采样的样本数量,所谓负采样指我们还需要一批不是邻居的节点作为负样本,那么上面这个式子的意思是相邻节点的embedding的相似度尽量大的情况下保证不相邻节点的embedding的期望相似度尽可能小。

四、后话

GraphSAGE采用了采样的机制,克服了GCN训练时内存和显存上的限制,使得图模型可以应用到大规模的图结构数据中,是目前几乎所有工业上图模型的雏形。然而,每个节点这么多邻居,采样能否考虑到邻居的相对重要性呢,或者我们在聚合计算中能否考虑到邻居的相对重要性? 这个问题在我们的下一篇博文Graph Attentioin Networks中做了详细的讨论。

责任编辑:lq

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4842

    浏览量

    108189
  • 算法
    +关注

    关注

    23

    文章

    4807

    浏览量

    98573
  • 模型
    +关注

    关注

    1

    文章

    3826

    浏览量

    52276

原文标题:GNN教程:GraghSAGE算法细节详解!

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    三电平中性点电压平衡算法详解:从传统空间矢量调制(SVPWM)到模型预测控制(MPC)

    三电平中性点电压平衡算法详解:从传统空间矢量调制(SVPWM)到模型预测控制(MPC) 多电平逆变器(Multilevel Inverters, MLI)在现代电力电子技术中扮演着至关重要的角色
    的头像 发表于 04-04 08:47 429次阅读
    三电平中性点电压平衡<b class='flag-5'>算法</b><b class='flag-5'>详解</b>:从传统空间矢量调制(SVPWM)到模型预测控制(MPC)

    电机转子外表贴磁钢有什么细节要求?

    在电机转子上贴磁钢是为了产生磁场,从而实现电机的正常运转,以下是在贴磁钢时需要考虑的一些细节要求。
    的头像 发表于 03-12 15:56 163次阅读
    电机转子外表贴磁钢有什么<b class='flag-5'>细节</b>要求?

    FFT算法原理详解

    /* 功能:将input里的数据进行快速傅里叶变换 并且输出 */ #include #include #define FFT_LENGTH 8 double input[FFT_LENGTH]={1,1,1,1,1,1,1,1}; struct complex1{ //定义一个复数结构体 double real; //实部 double image; //虚部 }; //将input的实数结果存放为复数 struct complex1 result_dat[8]; /* 虚数的乘法 */ struct complex1 con_complex(struct complex1 a,struct complex1 b){ struct complex1 temp; temp.real=(a.real*b.real)-(a.image*b.image); temp.image=(a.image*b.real)+(a.real*b.image); return temp; } /* 简单的a的b次方 */ int mypow(int a,int b){ int i,sum=a; if(b==0)return 1; for(i=1;i sum*=a; } return sum; } /* 简单的求以2为底的正整数 */ int log2(int n){ unsigned i=1; int sum=1; for(i;;i++){ sum*=2; if(sum>=n)break; } return i; } /* 简单的交换数据的函数 */ void swap(struct complex1 *a,struct complex1 *b){ struct complex1 temp; temp=*a; *a=*b; *b=temp; } /* dat为输入数据的数组 N为抽样次数 也代表周期 必须是2^N次方 */ void fft(struct complex1 dat[],unsigned char N){ /*最终 dat_buf计算出 当前蝶形运算奇数项与W 乘积 dat_org存放上一个偶数项的值 */ struct complex1 dat_buf,dat_org; /* L为几级蝶形运算 也代表了2进制的位数 n为当前级蝶形的需要次数 n最初为N/2 每级蝶形运算后都要/2 i j为倒位时要用到的自增符号 同时 i也用到了L碟级数 j是计算当前碟级的计算次数 re_i i_copy均是倒位时用到的变量 k为当前碟级 cos(2*pi/N*k)的 k 也是e^(-j2*pi/N)*k 的 k */ unsigned char L,i,j,re_i=0,i_copy=0,k=0,fft_flag=1; //经过观察,发现每级蝶形运算需要N/2次运算,共运算N/2*log2N 次 unsigned char fft_counter=0; //在此要进行补2 N必须是2^n 在此略 //蝶形级数 (L级) L=log2(N); //计算每级蝶形计算的次数(这里只是一个初始值) 之后每次要/2 //n=N/2; //对dat的顺序进行倒位 for(i=1;i i_copy=i; re_i=0; for(j=L-1;j>0;j--){ //判断i的副本最低位的数字 并且移动到最高位 次高位 .. //re_i为交换的数 每次它的数字是不能移动的 并且循环之后要清0 re_i|=((i_copy 0x01)< i_copy>>=1; } swap( dat, dat[re_i]); } //进行fft计算 for(i=0;i fft_flag=1; fft_counter=0; for(j=0;j if(fft_counter==mypow(2,i)){ //控制隔几次,运算几次 fft_flag=0; }else if(fft_counter==0){ //休止结束,继续运算 fft_flag=1; } //当不判断这个语句的时候 fft_flag保持 这样就可以持续运算了 if(fft_flag){ dat_buf.real=cos((2*PI*k)/(N/mypow(2,L-i-1))); dat_buf.image=-sin((2*PI*k)/(N/mypow(2,L-i-1))); dat_buf=con_complex(dat[j+mypow(2,i)],dat_buf); //计算 当前蝶形运算奇数项与W 乘积 dat_org.real=dat[j].real; dat_org.image=dat[j].image; //暂存 dat[j].real=dat_org.real+dat_buf.real; dat[j].image=dat_org.image+dat_buf.image; //实部加实部 虚部加虚部 dat[j+mypow(2,i)].real=dat_org.real-dat_buf.real; dat[j+mypow(2,i)].image=dat_org.image-dat_buf.image; //实部减实部 虚部减虚部 k++; fft_counter++; }else{ fft_counter--; //运算几次,就休止几次 k=0; } } } } void main{ int i; //先将输入信号转换成复数 for(i=0;i result_dat.image=0; //输入信号是二维的,暂时不存在复数 result_dat.real=input; //result_dat.real=10; //输入信号都为实数 } fft(result_dat,FFT_LENGTH); for(i=0;i input=sqrt(result_dat.real*result_dat.real+result_dat.image*result_dat.image); //取模 printf(\"%lfn\",input); } while(1); } 这个程序中input这个数组是输入信号,在这里只模拟抽样了8次,输出的数据也是input,如果想看其它序列的话,可以改变FFT_LENGTH 的值以及 input里的内容,程序输出的是实部和虚部的模
    发表于 01-22 06:36

    网络跳线:细节决定成败的网络构建者

    在构建一个高效、稳定的网络环境时,我们往往会关注到大型的网络设备、复杂的网络架构或是先进的网络技术,而往往忽略了那些看似微不足道却至关重要的细节——网络跳线。然而,正是这些小小的跳线,在网络的构建
    的头像 发表于 01-09 10:10 378次阅读

    高速PCB打样必知:细节决定成败,这些点你不能忽视!

    23年PCBA一站式行业经验PCBA加工厂家今天为大家讲讲高速pcb打样需要注意什么细节?高速pcb打样需要注意的细节。在高速PCB(印刷电路板)打样阶段,为确保最终产品的性能和可靠性,需要注意以下
    的头像 发表于 12-16 09:19 447次阅读
    高速PCB打样必知:<b class='flag-5'>细节</b>决定成败,这些点你不能忽视!

    SM4算法实现分享(一)算法原理

    SM4分组加密算法采用的是非线性迭代结构,以字为单位进行加密、解密运算,每次迭代称为一轮变换,每轮变换包括S盒变换、非线性变换、线性变换、合成变换。加解密算法与密钥扩展都是采用32轮非线性迭代结构
    发表于 10-30 08:10

    复杂的软件算法硬件IP核的实现

    的实现的技术细节,知道这些技术细节将有利于在使用 C 语言编写算法时实现一些有针对性的优化。 2.1 C to HASM HASM 是一种在 C 语言编译到HDL 时、经过严格定义的专用的语言
    发表于 10-30 07:02

    国密系列算法简介及SM4算法原理介绍

    一、 国密系列算法简介 国家商用密码算法(简称国密/商密算法),是由我国国家密码管理局制定并公布的密码算法标准。其分类1所示: 图1 国家商用密码
    发表于 10-24 08:25

    加密算法的应用

    加密是一种保护信息安全的重要手段,近年来随着信息技术的发展,加密技术的应用越来越广泛。本文将介绍加密算法的发展、含义、分类及应用场景。 1. 加密算法的发展 加密算法的历史可以追溯到古代。在
    发表于 10-24 08:03

    数据滤波算法的具体实现步骤是怎样的?

    (高频电磁、瞬时脉冲等),选择适配的滤波算法并落地。以下以电能质量监测中最常用的 IIR 低通滤波(抗高频干扰)、滑动平均滤波(抗瞬时脉冲)、卡尔曼滤波(抗动态波动) 为例,详解具体实现步骤: 一、前置准备:明确滤波目标与硬件基
    的头像 发表于 10-10 16:45 1013次阅读

    DFT算法与FFT算法的优劣分析

    一概述 在谐波分析仪中,我们常常提到的两个词语,就是DFT算法与FFT算法,那么一款功率分析仪/谐波分析仪采用DFT算法或者FFT算法,用户往往关注的是能否达到所要分析谐波次数的目的,
    的头像 发表于 08-04 09:30 1784次阅读

    达梦数据库常用管理SQL命令详解

    达梦数据库常用管理SQL命令详解
    的头像 发表于 06-17 15:12 7659次阅读
    达梦数据库常用管理SQL命令<b class='flag-5'>详解</b>

    SVPWM的原理及法则推导和控制算法详解

    小,使得电机转矩脉动降低,旋转磁场更逼近圆形,而且使直流母线电压的利用率有了很大提高,且更易于实现数字化。下面将对该算法进行详细分析阐述。 1.1 SVPWM 基本原理 SVPWM 的理论基础
    发表于 06-16 17:11

    安徽京准:北斗卫星同步时钟的安装与调试详解

    安徽京准:北斗卫星同步时钟的安装与调试详解
    的头像 发表于 06-05 10:08 1747次阅读
    安徽京准:北斗卫星同步时钟的安装与调试<b class='flag-5'>详解</b>

    SSH常用命令详解

    SSH常用命令详解
    的头像 发表于 06-04 11:30 2191次阅读