首页 > 其他分享 >图卷积网络(GCN)与图注意力网络(GAT)基础实现及其应用

图卷积网络(GCN)与图注意力网络(GAT)基础实现及其应用

时间:2024-09-22 17:49:18浏览次数:3  
标签:torch features nn self GAT 网络 GCN adj out

图卷积

创作不易,您的打赏、关注、点赞、收藏和转发是我坚持下去的动力!

图卷积网络(Graph Convolutional Networks, GCN)是一种能够直接在图结构数据上进行操作的神经网络模型。它能够处理不规则的数据结构,捕获节点之间的依赖关系,广泛应用于社交网络分析、推荐系统、图像识别、化学分子分析等领域。

主流的图卷积网络包括以下几种:

1. 经典图卷积网络(GCN)

经典GCN使用图拉普拉斯算子将卷积操作推广到图数据中,具体而言,它通过对图的邻接矩阵进行归一化操作来进行信息传播。GCN的核心思想是通过卷积操作在每一层中聚合节点邻居的信息,最终对节点进行表示。

在这里插入图片描述

示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import numpy as np

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GCNLayer, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, X, adj):
        support = torch.mm(X, self.weight)
        output = torch.mm(adj, support)
        return output

class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(GCN, self).__init__()
        self.layer1 = GCNLayer(in_features, hidden_features)
        self.layer2 = GCNLayer(hidden_features, out_features)

    def forward(self, X, adj):
        X = self.layer1(X, adj)
        X = F.relu(X)
        X = self.layer2(X, adj)
        return F.log_softmax(X, dim=1)

# 创建一个简单的图
G = nx.karate_club_graph()
adj = nx.adjacency_matrix(G).todense()
adj = torch.FloatTensor(adj + np.eye(adj.shape[0]))  # 添加自环
degree = np.diag(np.power(np.array(adj.sum(1)), -0.5).flatten())
adj_normalized = torch.FloatTensor(degree @ adj @ degree)

# 节点特征
features = torch.eye(adj.shape[0])

# 创建模型
model = GCN(in_features=adj.shape[0], hidden_features=16, out_features=2)
output = model(features, adj_normalized)
print(output)

2. 图注意力网络(Graph Attention Network, GAT)

GAT是另一种流行的图卷积网络,它通过注意力机制对邻居节点赋予不同的权重,从而实现更灵活的信息聚合。GAT的核心思想是计算目标节点与其邻居节点之间的注意力系数,将邻居节点信息加权求和。

公式:
[ H^{(l+1)}i = \sigma\left( \sum{j \in \mathcal{N}(i)} \alpha_{ij} W H^{(l)}j \right) ]
其中,(\alpha
{ij}) 是节点 (i) 和 (j) 之间的注意力系数,(W) 是可训练的权重矩阵。

示例代码:

class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, alpha, concat=True):
        super(GATLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W)  # [N, out_features]
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[0]
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        Wh_repeated_alternating = Wh.repeat(N, 1)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
        return all_combinations_matrix.view(N, N, 2 * self.out_features)

class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GATLayer(nfeat, nhid, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GATLayer(nhid * nheads, nclass, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.out_att(x, adj)
        return F.log_softmax(x, dim=1)

# 实例化并使用GAT模型
model_gat = GAT(nfeat=features.shape[1], nhid=8, nclass=2, dropout=0.6, alpha=0.2, nheads=8)
output_gat = model_gat(features, adj_normalized)
print(output_gat)

以上两个示例展示了GCN和GAT的基础实现。GCN适合对图结构信息进行卷积聚合,而GAT则通过引入注意力机制使得信息的聚合更为灵活。这些方法都能很好地应用于节点分类、链接预测等任务。

大家有技术交流指导、论文及技术文档写作指导、课程知识点讲解、项目开发合作的需求可以搜索关注我私信我

在这里插入图片描述

标签:torch,features,nn,self,GAT,网络,GCN,adj,out
From: https://blog.csdn.net/weixin_40841269/article/details/142439137

相关文章

  • 药物分子生成算法综述:从生成对抗网络到变换器模型的多样化选择
    创作不易,您的打赏、关注、点赞、收藏和转发是我坚持下去的动力!基于已有的药物数据生成新的药物分子是一项复杂的任务,通常涉及到生成模型和机器学习算法。以下是一些常用的算法和方法:1.生成对抗网络(GANs)特点:由生成器和判别器两个神经网络组成,生成器生成新分子,判别......
  • 计算机网络(月考一知识点)
    文章目录计算机网络背诵默写版计算机网络知识点(月考1版)计算机网络背诵默写版为我自己留个印记,本来荧光笔画的是没记住的,但是后面用紫色的,结果扫描的时候就看不见了。计算机网络知识点(月考1版)......
  • 神经网络:激活函数选择
        结论直接看——激活函数的选择方式        神经网络主体分为输入层、隐藏层和输出层三个模块。一般情况下,输入层只负责对数据的输入,并不做任何的变换。故而,激活函数的选择只涉及隐藏层和输出层两个模块。     神经网络主体图激......
  • 网络安全在2024好入行吗?
      前言024年的今天,慎重进入网安行业吧,目前来说信息安全方向的就业对于学历的容忍度比软件开发要大得多,还有很多高中被挖过来的大佬。理由很简单,目前来说,信息安全的圈子人少,985、211院校很多都才建立这个专业,加上信息安全法的存在,形成了小圈子的排他效应,大佬们的技术交流都......
  • 网络安全在2024好入行吗?
      前言024年的今天,慎重进入网安行业吧,目前来说信息安全方向的就业对于学历的容忍度比软件开发要大得多,还有很多高中被挖过来的大佬。理由很简单,目前来说,信息安全的圈子人少,985、211院校很多都才建立这个专业,加上信息安全法的存在,形成了小圈子的排他效应,大佬们的技术交流都......
  • 这才是CSDN最系统的网络安全学习路线(建议收藏)
      01什么是网络安全网络安全可以基于攻击和防御视角来分类,我们经常听到的“红队”、“渗透测试”等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。无论网络、Web、移动、桌面、云等哪个领域,都有攻与防两面性,例如Web安全技术,既有Web渗透,也......
  • 【网络安全】学过编程就是黑客?
      前言黑客,相信经常接触电脑的朋友们对这个词都不陌生,各类影视视频中黑客总是身处暗处,运筹帷幄,正是这种神秘感让我走向学习编程的道路,也正是如此让我明白黑客远没有我想象中那么“帅气”。黑客......
  • 【网络安全】学过编程就是黑客?
      前言黑客,相信经常接触电脑的朋友们对这个词都不陌生,各类影视视频中黑客总是身处暗处,运筹帷幄,正是这种神秘感让我走向学习编程的道路,也正是如此让我明白黑客远没有我想象中那么“帅气”。黑客......
  • 如何用3个月零基础入门网络安全?_网络安全零基础怎么学习
      前言写这篇教程的初衷是很多朋友都想了解如何入门/转行网络安全,实现自己的“黑客梦”。文章的宗旨是:1.指出一些自学的误区2.提供客观可行的学习表3.推荐我认为适合小白学习的资源.大佬绕道哈!→点击获取网络安全资料·攻略←一、自学网络安全学习的误区和陷阱1.......
  • 如何用3个月零基础入门网络安全?_网络安全零基础怎么学习
      前言写这篇教程的初衷是很多朋友都想了解如何入门/转行网络安全,实现自己的“黑客梦”。文章的宗旨是:1.指出一些自学的误区2.提供客观可行的学习表3.推荐我认为适合小白学习的资源.大佬绕道哈!→点击获取网络安全资料·攻略←一、自学网络安全学习的误区和陷阱1.......