首页 > 其他分享 >【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用

时间:2022-11-16 19:34:04浏览次数:35  
标签:embedding task2 王喆 Pinterest GraphSAGE 邻接 Embedding GNN 节点


学习总结

(1)GraphSAGE 的主要步骤是三步“采样 - 聚合 - 预测”:

  • 采样是指在整体图数据上随机确定中心节点,采样 k 阶子图样本。
  • 聚合是指利用 GNN 把 k 阶子图样本聚合成中心节点 Embedding。
  • 预测是指利用 GNN 做有监督的标签预测或者直接生成节点 Embedding。

(2)3步中的重点在于聚合的 GNN 结构,它使用 CONVOLVE 操作把邻接点 Embedding 聚合起来,跟中心节点上一轮的 Embedding 连接后,利用全连接层生成新的 Embedding。

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_推荐算法


(3)决定什么任务的是GraphSAGE的输出层:

  • 有监督学习:输出层是LR这种二分类模型,input是之前通过GNN学习到的中心节点embedding,output是预测概率的label。
  • 无监督学习:可以用像word2vec一样用softmax作为输出层,预测每个点的ID,这样每个点ID的softmax的输出层向量就是该点的embedding,原理和word2vec一致。和上一节YouTube架构中的召回层的视频向量的生成也是一致的:

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_数据_02

文章目录

  • ​​学习总结​​
  • ​​一、以往方法​​
  • ​​二、GraphSAGE 的主要步骤​​
  • ​​三、GraphSAGE 的模型结构​​
  • ​​3.1 GraphSAGE 的训练过程:​​
  • ​​3.2 CONVOLVE​​
  • ​​3.3 第一步的Aggregate操作​​
  • ​​四、GraphSAGE 的预测目标​​
  • ​​五、GraphSAGE 在 Pinterest 推荐系统中的应用​​
  • ​​六、作业​​
  • ​​七、课后答疑​​
  • ​​Reference​​

一、以往方法

deepwalk、Node2vec等Graph embedding方法并没直接处理图结构数据,而是先把图结构数据通过随机游走采样,转为序列数据,然后再用word2vec这类序列数据embedding方法生成Graph embedding。

但是这种“搭桥”的方法对图数据进行采样的时候,破坏了信息原始的结构。

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_子图_03

图1 基于随机游走的Graph Embedding算法


二、GraphSAGE 的主要步骤

Graph Sample and Aggregate,翻译过来叫“图采样和聚集方法”。

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_GNN_04

图2 GraphSAGE的主要过程 (出自论文 Inductive Representation Learning on Large Graphs)


GraphSAGE 的过程如上图所示,主要可以分为 3 步:

  • 在整体的图数据上,从某一个中心节点开始采样,得到一个 k 阶的子图,示意图中给出的示例是一个二阶子图;
  • 有了这个二阶子图,我们可以先利用 GNN 把二阶的邻接点聚合成一阶的邻接点(图 1-2 中绿色的部分),再把一阶的邻接点聚合成这个中心节点(图 1-2 中蓝色的部分);
  • 有了聚合好的这个中心节点的 Embedding,我们就可以去完成一个预测任务,比如这个中心节点的标签是被点击的电影,那我们就可以让这个 GNN 完成一个点击率预估任务。

GNN 既可以预测中心节点的标签,比如点击或未点击,也可以单纯训练中心节点的 Embedding。主要步骤就是三个“抽样 - 聚合 - 预测”。

三、GraphSAGE 的模型结构

GraphSAGE 的模型结构到底怎么样?它到底是怎么把一个 k 阶的子图放到 GNN 中去训练,然后生成中心节点的 Embedding 的呢?

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_GNN_05


上图中处理的样本是一个以点 A 为中心节点的二阶子图,从左到右我们可以看到,点 A 的一阶邻接点包括点 B、点 C 和点 D,从点 B、C、D 再扩散一阶,可以看到点 B 的邻接点是点 A 和点 C,点 C 的邻接点是 A、B、E、F,而点 D 的邻接点是点 A。

3.1 GraphSAGE 的训练过程:

这个 GNN 的输入是二阶邻接点的 Embedding,二阶邻接点的 Embedding 通过一个叫 CONVOLVE 的操作生成了一阶邻接点的 Embedding,然后一阶邻接点的 Embedding 再通过这个 CONVOLVE 的操作生成了目标中心节点的 Embedding,至此完成了整个训练。

3.2 CONVOLVE

CONVOLVE 的中文名是卷积,但这里的卷积并不是严格意义上的数学卷积运算,而是一个由 Aggregate 操作和 Concat 操作组成的复杂操作。

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_GNN_05


CONVOLVE 操作是由两个步骤组成的:

  • 第一步叫 Aggregate 操作,就是图 4 中 gamma 符号代表的操作,它把点 A 的三个邻接点 Embedding 进行了聚合,生成了一个 Embedding hN(A);
  • 第二步,我们再把 hN(A) 与点 A 上一轮训练中的 Embedding hA 连接起来,然后通过一个全联接层生成点 A 新的 Embedding。

3.3 第一步的Aggregate操作

就是把多个 Embedding 聚合成一个 Embedding 的操作。比如,我们最开始使用的 Average Pooling,在 DIN 中使用过的 Attention 机制,在序列模型中讲过的基于 GRU 的方法,以及可以把这些 Embedding 聚合起来的 MLP 等等。

四、GraphSAGE 的预测目标

预测节点的标签(如点击or未点击)是一个有监督学习任务;
生成节点的embedding是一个无监督学习任务。

决定什么任务的是GraphSAGE的输出层:

(1)有监督学习:输出层是LR这种二分类模型,input是之前通过GNN学习到的中心节点embedding,output是预测概率的label。

(2)无监督学习:可以用像word2vec一样用softmax作为输出层,预测每个点的ID,这样每个点ID的softmax的输出层向量就是该点的embedding,原理和word2vec一致。和上一节YouTube架构中的召回层的视频向量的生成也是一致的:

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_数据_02

五、GraphSAGE 在 Pinterest 推荐系统中的应用

【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用_推荐算法_08


在 PinSAGE 应用的构成中,它没有直接分析图片内容,而只是把图片当作一个节点,利用节点和周围节点的关系生成的图片 Embedding。因此,这个例子可以说明,PinSAGE 某种程度上理解了图片的语义信息,而这些语义信息正是埋藏在 Pinterest 的商品关系图中。

六、作业

使用 GraphSAGE 是为了生成每个节点的 Embedding,那我们有没有办法在 GraphSAGE 中加入物品的其他特征,如物品的价格、种类等等特征,让最终生成的物品 Embedding 中包含这些物品特征的信息呢?

【答】可以在k阶聚合完成后,像wide&deep钟一样,将节点的embedding和物品其他特征拼接后接入全连接层和softmax层得到embedding。

七、课后答疑

(1)在实际公司推荐场景中如果要应用这个算法,数据是通过图数据库来存储吗?能否推荐一个生产环境适合的图数据库?

【答】最近大家提neo4j比较多,https://neo4j.com/,可以研究一下。也可以用spark xgraph直接处理原始数据。

Reference

(1)https://github.com/wzhe06/Reco-papers
(2)《深度学习推荐系统实战》,王喆


标签:embedding,task2,王喆,Pinterest,GraphSAGE,邻接,Embedding,GNN,节点
From: https://blog.51cto.com/u_15717393/5857169

相关文章

  • task2-3
    #include<stdio.h>#include<math.h>intmain(){floatF,C;while(scanf_s("%f",&C)!=EOF){F=9*C/5+32;printf("摄......
  • 如何不影响搜索排名情况下降低搜索爬虫对网站的影响(google ,bing, yandex,pinterest
    引言:网站资源(CPU,内存等)使用率在某个时段很高,甚至出现502状态,查日志(看请求头user-agent)发现爬虫在某个时段大量爬取,超出网站正常的负荷,如何在不影响搜索排名情况下合理降低......
  • task2
    #include<stdio.h>#include<math.h>intmain(){doublex,ans;while(scanf("%lf",&x)!=EOF){ans=pow(x,365);printf("%.2f的365......
  • task2.3
    #include<stdio.h>#include<math.h>intmain(){doublex,ans;while(scanf("%lf",&x)!=EOF){ans=9.0/5.0*x+32.0;printf("摄氏度c=%.2lf;华氏度f=%.2lf\n",x,ans)......