首页 > 其他分享 >图注意网络(GAT)的可视化实现详解

图注意网络(GAT)的可视化实现详解

时间:2023-09-20 12:55:07浏览次数:43  
标签:特征 GAT CV 详解 可视化 维度 hidden 节点 size

前言 能够可视化的查看对于理解图神经网络(gnn)越来越重要,所以这篇文章将介绍传统GNN层的实现,然后展示ICLR论文“图注意力网络”中对传统GNN层的改进。

本文转载自DeepHub IMBA

作者:David Winer

仅用于学术分享,若侵权请联系删除

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

CV各大方向专栏与各个部署框架最全教程整理

【CV技术指南】CV全栈指导班、基础入门班、论文指导班 全面上线!!

假设我们有一个表示为有向无环图(DAG)的文本文档图。文档0与文档1、2和3有一条边,为了实现可视化,这里将使用Graphbook,一个可视化的人工智能建模工具。

我们还为每个文档提供了一些节点特征。将每个文档作为单个[5] 1D文本数组放入BERT中,这样就得到了一个[5,768]形状的嵌入。

为了方便演示,我们只采用BERT输出的前8个维度作为节点特征,这样可以更容易地跟踪数据形状。这样我们就有了邻接矩阵和节点特征。

GNN层

GNN层的一般公式是,对于每个节点,我们取每个节点的所有邻居对特征求和,乘以一个权重矩阵,最后通过一个激活函数得到输出结果。所以这里创建一个以这个公式为标题的空白块,并将其传递给Adj矩阵和节点特征,我将在块中实现上面说的公式。

我们将节点特征平铺(即广播)为3D形状,也就初始的[5,8]形状的节点特征,扩展成有[5,5,8]形状,其中第0维的每个单元格都是节点特征的重复。所以现在可以把最后一个维度看作是“邻居”特征。每个节点有5个可能的邻居。

因为不能直接将节点特征从[5,8]广播到[5,5,8],我们必须首先广播到[25,8],因为在广播时,形状中的每个维度都必须大于或等于原始维度。所以得到形状的5和8部分(get_sub_arrays),然后乘以第一部分得到25,然后将它们全部连接在一起。将结果[25,8]重塑回[5,5,8],结果可以在Graphbook中验证最终2维中的每个节点特征集是相同的。

下一步就是广播邻接矩阵到相同的形状。对于第i行和col j的邻接矩阵中的每一个1,在维数[i, j]上有一行1.0的num_feat。所以在这个邻接关系中,在第0个单元格中第1、2和3行有一行num_feat 1.0(即[0,1:3,:])。

这里的实现非常简单,只需将邻接矩阵解析为十进制并从[5,5]形状广播到[5,5,8]。将这个邻接掩码与平铺节点邻居特征相乘。

我们还想在邻接矩阵中包含一个自循环,这样当对邻居特征求和时,也包括了该节点自己的节点特征。

这样就得到了每个节点的邻居特征,其中没有被一条边连接的节点(不是邻居)的特征为零。对于第0个节点,它包括节点0到3的特征。对于第三个节点,它包括第三和第四个节点。

下一步就是重塑为[25,8],使每个相邻特征都是它自己的行,并将其传递给具有所需隐藏大小的参数化线性层。这里隐藏层大小是32并保存为全局常量,以便可以重用。线性层的输出将是[25,hidden_size]。所以经过重塑就可以得到[5,5,hidden_size]。

最后对中间维度(维度索引为1)求和,对每个节点的相邻特征求和。结果是经过1层的节点嵌入集[5,hidden_size],得到了一个GNN网络。

图注意力层

图注意层关键是注意力系数,如上式所示。从本质上讲,在应用softmax之前,我们将边缘中的节点嵌入连接起来,并通过另一个线性层。

然后使用这些注意系数来计算与原始节点特征对应的特征的线性组合。

我们要做的是为每个邻居平铺每个节点的特征,然后将其与节点的邻居特征连接起来。

这里需要注意的是mask掩码需要在平铺节点特征之前交换0和1维。

这用结果仍然是一个[5,5,8]形数组,但现在[i,:,:]中的每一行都是相同的,并且对应于节点i的特征。然后我们就可以使用乘法来创建只在包含邻居时才重复的节点特征。最后就是将其与上面的GNN创建的相邻特征连接起来,生成连接的特征。

现在我们有了连接的特征,需要把它们输入到一个线性层中,所以还需要重塑回到[5,5,hidden_size],这样我们就可以在中间维度上进行softmax产生我们的注意力系数。

得到了形状为[5,5,hidden_size]的注意力系数,这实际上是在n个节点的图中每个图边嵌入一次。论文说这些应该被转置(维度交换),我们在ReLU之前已经做过了,现在我对最后一个维度进行softmax,这样它们就可以沿着隐藏的尺寸维度进行每个维度索引的标准化。

将[5,hidden_size, 5]形状乘以[5,5,8]形状得到[5,hidden_size, 8]形状。然后我们对hidden_size维度求和,最终输出[5,8],匹配我们的输入形状。这样就可以把这个层串起来多次使用。

总结

本文介绍二零单个GNN层和GAT层的可视化实现。在论文中,他们还解释了是如何扩展多头注意方法的,我们这里没有进行演示。

Graphbook是用于AI和深度学习模型开发的可视化IDE,Graphbook仍处于测试阶段,但是他却是一个很有意思的工具,通过可视化的实现,我们可以了解更多的细节。

本文的项目地址:https://github.com/drwiner/Graphbook-GNN-GAT

Graphbook地址:https://github.com/cerbrec/graphbook

 

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

计算机视觉入门1v3辅导班

【技术文档】《从零搭建pytorch模型教程》122页PDF下载

QQ交流群:470899183。群内有大佬负责解答大家的日常学习、科研、代码问题。

其它文章

分享一个CV知识库,上千篇文章、专栏,CV所有资料都在这了

明年毕业,还不知道怎么做毕设的请抓紧机会了

LSKA注意力 | 重新思考和设计大卷积核注意力,性能优于ConvNeXt、SWin、RepLKNet以及VAN

CVPR 2023 | TinyMIM:微软亚洲研究院用知识蒸馏改进小型ViT

ICCV2023|涨点神器!目标检测蒸馏学习新方法,浙大、海康威视等提出

ICCV 2023 Oral | 突破性图像融合与分割研究:全时多模态基准与多交互特征学习

听我说,Transformer它就是个支持向量机

HDRUNet | 深圳先进院董超团队提出带降噪与反量化功能的单帧HDR重建算法

南科大提出ORCTrack | 解决DeepSORT等跟踪方法的遮挡问题,即插即用真的很香

1800亿参数,世界顶级开源大模型Falcon官宣!碾压LLaMA 2,性能直逼GPT-4

SAM-Med2D:打破自然图像与医学图像的领域鸿沟,医疗版 SAM 开源了!

GhostSR|针对图像超分的特征冗余,华为诺亚&北大联合提出GhostSR

Meta推出像素级动作追踪模型,简易版在线可玩 | GitHub 1.4K星

CSUNet | 完美缝合Transformer和CNN,性能达到UNet家族的巅峰!

AI最全资料汇总 | 基础入门、技术前沿、工业应用、部署框架、实战教程学习

计算机视觉入门1v3辅导班

计算机视觉交流群

聊聊计算机视觉入门

标签:特征,GAT,CV,详解,可视化,维度,hidden,节点,size
From: https://www.cnblogs.com/wxkang/p/17717048.html

相关文章

  • iOS app上架app store流程详解
    前提条件在有效期内的苹果开发者账号(类型为个人或者公司账号)。还有一种情况,就是你的AppleID被添加到公司开发者账号团队里面,这样也是可以的,但是需要叫管理员给你开通相应的账号权限,如下截图:这里可能有些同学会问,苹果开发者账号是什么?如何申请?那么可以看看我的上一篇文章:iOS苹果开......
  • HFile详解-基于HBase0.90.5
    1.HFile详解HFile文件分为以下六大部分 序号名称描述1数据块由多个block(块)组成,每个块的格式为:[块头]+[key长]+[value长]+[key]+[value]。2元数据块元数据是key-value类型的值,但元数据快只保存元数据的value值,元数据的key值保存在第五项(元数据索引块)中。该块由多个元数......
  • anaconda navigator,启动!
    今天重新安装了一下anaconda,本想着应该不会再出什么问题,先打开anacondanavigator试试水,没想到还是一直卡在loadingapplications的地方,过了半天好不容易消失了,但是什么都没有显示。再次尝试打开navigator,就提示“Thereisaninstancealreadyrunning”。按照网上的教程,我用......
  • 模拟退火详解
    模拟退火学习(030920一上午成果)目录模拟退火学习(030920一上午成果)前言:1.爬山算法由来:2.模拟退火:算法流程:初学(我)的问题用题来进行理解:BZOJ1844RunAway(cqbz的oj上有)回顾上面的问题:对于Q1对于Q2:Q2的补充:对于Q3前言:emmmm。你还在考虑dp死活写不出来吗?你还在担忧贪心算法的正确......
  • MySQL篇:第九章_详解流程控制结构
    流程控制结构系统变量一、全局变量作用域:针对于所有会话(连接)有效,但不能跨重启查看所有全局变量SHOWGLOBALVARIABLES;查看满足条件的部分系统变量SHOWGLOBALVARIABLESLIKE'%char%';查看指定的系统变量的值SELECT@@global.autocommit;为某个系统变量赋值SET@@glo......
  • 提升工作效率的秘密武器:常用Shell脚本详解
    检测网卡流量,并按规定格式记录在日志中#!/bin/bash########################################################检测网卡流量,并按规定格式记录在日志中#规定一分钟记录一次#日志格式如下所示:#2019-08-1220:40#ens33input:1234bps#ens33output:1235bps###......
  • InnoDB锁详解(共享/排他锁、意向锁、记录锁、间隙锁、临键锁、插入意向锁、自增锁)
    原文地址:两万字详解InnoDB的锁-掘金(juejin.cn)1.为什么需要加锁?为什么需要加锁呢?在日常生活中,如果你心情不好想静静,不想被比别人打扰,你就可以把自己关进房间里,并且反锁。同理,对于MySQL数据库来说的话,一般的对象都是一个事务一个事务来说的。所以,如果一个事务内,正在写某......
  • RCC时钟详解
    目录一.STM32时钟树1.这里主要熟悉SYSCLK的时钟流程即可.(主要是配置好锁相环)二.核心寄存器分析1.RCC_CR时钟控制寄存器2.RCC_CFGR时钟配置寄存器3.其他寄存器三.RCC外设驱动1.操作寄存器方式驱动1.固件库方式驱动四.RCC外设驱动总结一.STM32时钟树1.这里主要熟悉SYSCL......
  • Eclipse Java注释模板设置详解
    设置注释模板的入口:Window->Preference->Java->CodeStyle->CodeTemplate然后展开Comments节点就是所有需设置注释的元素啦。现就每一个元素逐一介绍:文件(Files)注释标签:/***@Title:${file_name}*@Package${package_name}*@Description:${todo}(用一句话描......
  • 详解Spring缓存注解@Cacheable、@CachePut和@CacheEvict
    详解Spring缓存注解@Cacheable、@CachePut和@CacheEvict的使用简介在大型的应用程序中,缓存是一项关键技术,用于提高系统的性能和响应速度。Spring框架提供了强大的缓存功能,通过使用缓存注解可以轻松地集成缓存机制到应用程序中。本文将详细介绍Spring框架中的@Cacheable、@CachePu......