首页 > 其他分享 >学习笔记:DSTAGNN: Dynamic Spatial-Temporal Aware Graph Neural Network for Traffic Flow Forecasting

学习笔记:DSTAGNN: Dynamic Spatial-Temporal Aware Graph Neural Network for Traffic Flow Forecasting

时间:2023-08-22 22:23:57浏览次数:115  
标签:mathbb Forecasting Network Neural boldsymbol times GTU mathcal 注意力

DSTAGNN: Dynamic Spatial-Temporal Aware Graph Neural Network for Traffic Flow Forecasting ICML2022

论文地址:https://proceedings.mlr.press/v162/lan22a.html
代码地址:https://github.com/SYLan2019/DSTAGNN

一个用于时空序列预测的交通流量预测模型。

可学习的地方:

  1. 提出了一种衡量不同节点时空距离的方法:时空感知距离(STAD)。简单说,就是把n个节点的数据看成n个概率分布,然后用Wasserstein distance的公式计算出概率分布相互转换的代价,作为节点之间的距离。
  2. 提出了一种多头GTU的结构
  3. 使用了一种多个注意力结果串联,然后再将各个输出合并的结构
    对比试验结果似乎还不错。

1 复现情况

模型使用torch环境。

想换数据集结果报错。代码里给了PEMS04的部分数据。由于PEMS08更小,我想训练PEMS08,但是运行时报错,说在数据预处理时出现了none或无穷大值,只好作罢。我无法直接训练PEMS04,因为显存不够。于是我调小batch_size到8,终于可以训练了。

我用的GTX1650,训练一个epoch要花很长时间。一epoch要花800s, 1600s, 2400s, 3200s, 4000s...时间越来越长,我训练了7个就终止了,熬不住了,打个游戏都卡。

2 问题和符号定义

节点数量为N,历史时间戳数量为M,预测的时间戳数量为T,特征只有一个。
任务:

\[X^{(t+1):(t+T)}=\mathcal{F}\left[X^{(t-M+1):t};\mathcal G\right] \]

其中 \(X^{(t+1):(t+T)} \in \mathbb R^{N\times T}, X^{(t-M+1):t)} \in \mathbb R^{N\times M}\)

3 模型结构

图1:结构图

3.1 构建时空感知图

交通流数据表示为:\(X^{f}\in\mathbb{R}^{D\times d_{t}\times N}\),N为节点数量,D为天数,\(d_t\)为每天的时间戳数量,那么第n个节点的数据表示为\(\boldsymbol{X}_{n}^{f}=(\boldsymbol{w}_{n1},\boldsymbol{w}_{n2},...,\boldsymbol{w}_{nD}),\boldsymbol{w}_{nd}\in\mathbb{R}^{d_{t}}\)。然后:

\[m_{nd}=\frac{\|\boldsymbol{w}_{nd}\|_2}{Z_n},\quad Z_n=\sum_{d=1}^D\|\boldsymbol{w}_{nd}\|_2 \]

这样的话,就获得了n个概率分布\(P_n\{X_d=m_{nd}\}\) 。那么两个节点的时间序列的差异就可以表示为两个概率分布的转移代价:

\[\begin{array}{l} d_{STAD}(n_1,n_2)\triangleq STAD\left(\boldsymbol{X}_{n_1},\boldsymbol{X}_{n_2}\right)\\ = \displaystyle \inf_ {\gamma\in\Pi\left[P_{n_1},P_{n_2}\right]} \int_{x}\int_{y}\gamma(x,y) cost(w_{n_1x},w_{n_2y}) dxdy \end{array} \]

其中,

\[cost(w_{n_1x},w_{n_2y})=1-\frac{w_{n_1x}^\top\cdot w_{n_2y}}{\sqrt{w_{n_1x}^\mathrm{T}w_{n_1x}}\times\sqrt{w_{n_2y}^\mathrm{T}w_{n_2y}} )} \]

不解 这种微积分表示,是求期望的意思吗?答:不,这是Wasserstein distance的式子,表示两个概率分布之间的“距离”。

上面的通俗点说,就是N个时间序列,每个时间序列是D个向量,向量大小为\(d_t\) 。然后模型只关心向量的长度的比例,于是用向量模长代替向量后,将时间序列整体缩放,使得这个时间序列之和为1。这样就可以把每个节点的时间序列视为一个概率分布。然后定义了两个概率分布相互转化的代价,也就是两个时间序列的距离计算方法。

然后用这种方法计算出N×N的矩阵\(A_{STAD}\) ,其中\(A_{STAD}[i,j]=1-d_{STAD}(i,j)\in[0,1]\) (时空感知距离矩阵)。求出每个点最接近的\(N_r\)个点,其中\(N_r=N\times P_{sp}\) ,\(P_{sp}\)是超参数,表示稀疏程度,也就是在\(A_{STAD}\)的每一行找到最大的\(N_r\)个值。将\(A_{STAD}\)上的其余值设为0,得到了\(A_{STRG}\) (时空相关图)。可以用\(A_{STRG}\)作为注意力机制的先验知识,这个之后会用到。将\(A_{STRG}\)上的非0值设为1(二值化),得到\(A_{STAG}\)(时空感知图),可以用作图卷积中的聚合操作。

思考 这种定义两个点距离的方式,实际上是关心了整体时间上,两个点每天流量比例的相似性。也就是,以天为单位观察流量,如果两个节点的流量几乎都是同时增加和减少,不管他们本身的数值有多大,这两个点都是相似的。 由于只有1个特征,这个式子看起来很好计算。

3.2 时空注意力块(Spatial-Temporal Attention Block, ST Block)

一个ST块(ST Block)由三部分组成:时间注意力(Temporal Attention, TA),空间注意力(Spatial Attention, SA)和图卷积。
整个模型就是由若干个ST块组成的。

时间注意力

不解 “多头注意力提供了平行机制,能够有效关注序列数据的长范围的相关性”为什么说这样就利于长范围相关性的提取?

这里的多头注意力机制,实际上是若干个注意力机制的串联,最后将多个注意力机制的结果合并。

图2:时间注意力TA

第\(l\)个TA的QKV:

\[\mathcal{X'}^{(l)}\boldsymbol{W}_q^{(l)}\triangleq Q^{(l)},\quad\mathcal{X'}^{(l)}\boldsymbol{W}_k^{(l)}\triangleq K^{(l)},\quad\mathcal{X'}^{(l)}\boldsymbol{W}_\nu^{(l)}\triangleq V^{(l)} \]

\[\mathrm{Att}\left(Q^{(l)},K^{(l)},V^{(l)}\right)=\mathrm{Softmax}\left(A^{(l)}\right)V^{(l)},\quad A^{(l)}=\frac{Q^{(l)}K^{(l)^{\top}}}{\sqrt{d_{h}}}+A^{(l-1)} \]

其中\(d_h=d/H\) 。
之后,再对\(Q^{(l)},K^{(l)},V^{(l)}\)进行H次不同的线性变换,经过注意力操作后,缝在一起:

\[\begin{aligned} O^{(h)}&=\operatorname{Att}\left(Q\boldsymbol{W}_q^{(h)},K\boldsymbol{W}_k^{(h)},V\boldsymbol{W}_\nu^{(h)}\right)\\ \mathcal O&=\left[O^{(1)},O^{(2)},\ldots,O^{(H)}\right]\\ Y&=\operatorname{LayerNorm}(\operatorname{Linear}(\operatorname{Reshape}(\mathcal{O})+\mathcal{X}^{\prime})) \end{aligned} \]

最终输出Y到空间注意力模块中,它的形状是\(c^{(l-1)}\times M\times N\)
不解 我觉得模型图这里画的有问题,缺少了这H次注意力和合并操作

空间注意力

图3:空间注意力SA

\(Y\) 变形为 \(Y^\#\),即形状为 \(c^{(l-1)}\times N\times M\)
将维度\(M\)映射成更高维度\(d_E\)
通过一维卷积1x1conv聚合\(c^{(l-1)}\)那一维,得到\(Y'\),形状 \(N\times d_E\)
嵌入位置信息,得到\(Y_E\)
利用\(A_{STRG}\):

\[\boldsymbol{P}^{(h)}=\mathrm{Softmax}\left(\frac{\left(\boldsymbol{Y}_E\boldsymbol{W}_k^{\prime{(h)}}\right)^\top\left(\boldsymbol{Y}_E\boldsymbol{W}_q^{\prime{(h)}}\right)}{\sqrt{d_h}}+\boldsymbol{W}_m^{(h)}\odot\boldsymbol{A}_{STRG}\right) \]

\[\mathcal{P}=\begin{bmatrix}\boldsymbol P^{(1)},\boldsymbol P^{(2)},\ldots,\boldsymbol P^{(H)}\end{bmatrix} \]

最终\(\mathcal P\)的形状为\(H\times N\times N\)

时空卷积

空间图卷积 Spatial Graph-Conv
将所有计算中用到的\(A\)换成\(A_{STAG}\):

\[\tilde{\boldsymbol{L}}=\frac{2}{\lambda_{\max}}(\boldsymbol{D}-\boldsymbol{A})-\boldsymbol{I}_{N} \]

使用K阶切比雪夫多项式,也就是说每个节点聚合了k阶邻点。

\[g_\theta*Gx=g_\theta(\boldsymbol{L})x=\sum_{k=0}^{K-1}\boldsymbol{\theta}_k\left(T_k(\tilde{\boldsymbol{L}})\odot\boldsymbol{P}^{(k)}\right)x \]

其中\(\boldsymbol \theta_k\in\mathbb{R}^K\)是可学习参数,\(\boldsymbol P^{(k)}\in\mathbb R^{N\times N}\)是空间注意力中的第k个头。
输入为\(\mathcal X^{(l)}\in\mathbb R^{N\times c^{(l-1)}\times M}\),卷积核为\(g_\theta\in\mathbb R^{K\times c^{(l-1)}\times c^{(l)}}\),输出为\(\mathcal Z^{(l)}\in\mathbb R^{N\times M\times C^{(l)}}\)

图4:图卷积

时间门控卷积 Temporal Gated-Conv
由三个不同野的门控Tanh单元(GTU)组成。输入为\(\mathcal Z^{(l)}\) 。
不解 这个GTU有什么用

传统GTU通过卷积核\(\Gamma\in\mathbb{R}^{1\times S\times c^{(l)}\times2c^{(l)}}\)来使通道数加倍。\(\mathcal Z'^{(l)}=\Gamma*\mathcal Z^{(l)}\),那么\(\mathcal Z'^{(l)}\in\mathbb R^{N\times(M-(S-1))\times2C}\) 。GTU的过程为:

\[\Gamma*_\tau\mathcal{Z}^{(l)}=\phi(E)\odot\sigma(F)\in\mathbb{R}^{N\times(M-(S-1))\times2C^{(l)}} \]

其中\(E\)和\(F\)分别是\(\mathcal Z^{(l)}\)的前一半和后一半,\(\phi(\cdot)\)是tanh函数,\(\sigma(\cdot)\)是sigmoid函数。
通过叠加门控卷积,提高提取长范围时间依赖性的能力。

模型提出了改进后的M-GTU:使用三个具有不同S大小的GTU,

\[\begin{aligned}Z_{out}^{(l)}&=\mathrm{M\mbox{-}GTU}(\mathcal{Z}^{(l)})=\mathrm{ReLU}(\mathrm{Concat}(\mathrm{Pooling}(\Gamma_1*_\tau\mathcal{Z}^{(l)}),\\&\mathrm{Pooling}(\Gamma_2*_\tau\mathcal{Z}^{(l)}),\mathrm{Pooling}(\Gamma_3*_\tau\mathcal{Z}^{(l)}))+\mathcal{Z}^{(l)})\end{aligned} \]

Concat操作后,数据维度为\(3M-(S_1+S_2+S_3-3)\),然后通过窗口宽度为W的pooling层,变成\((3M-(S_1+S_2+S_3-3))/W\)。通过调节\(W,S_1,S_2,S_3\),可以让\((3M-(S_1+S_2+S_3-3))/W=M\),最后通过ReLU函数得到\(Z_{out}^{(l)}\in\mathbb R^{N\times M\times C^{(l)}}\)。

不解 总感觉这个论文模型介绍的文字和图片没有对应上。不过我还没有看代码,也不好说那个才是对的。

4 实验结果

图5:对比试验

图6:消融实验

需要说明的点:

  • 基线模型:FC-LSTM(2014), TCN(2018), DCRNN(2017), STGCN(2017), ASTGCN(2019), STSGCN(2020), STFGNN(2021), STGODE(2021), Z-GCNETs(2021), AGCRN(2020)
  • DSTAGNN-G使用的是预定义的时空关系图,DSTAGNN使用的是\(A_{STAG}\) 。可以看出使用\(A_{STAG}\)确实会提升一点效果。这里预定义的时空关系图应该就是数据集中路网结构,而\(A_{STAG}\)实际上是根据节点流量自动生成的空间结构,因此说这个模型可以用于没有空间先验信息的情况,且从实验结果上看效果可能会比有空间先验信息的更好。
  • 打*号的是作者重新跑了代码复现的,其余是抄的对应论文的结果。可以看出复现的主要是STFGNN和Z-GCNETs这两个模型,或许是因为这两个模型的结果相对较好,有复现一遍的价值?
  • 消融实验中,RemSTA去除时空注意力机制,RemM-A去除多头机制,RemM-GTU用传统GTU代替M-GTU,RemRC-OUT去除每个ST块后的残差连接。可以看出时空注意力机制是必不可少的,多头机制是提升最大的,M-GTU和残差连接提升不多。另外,这个消融曲线图画得真好,用颜色深浅表示不同实验,曲线的上下关系和图注的上下对应,看起来很舒服。

图7:注意力可视化

论文把时空关系可视化了。左图是第一个注意力头的全局注意力,右图是通过第2、3个头获得的红星点与其它点的时空依赖。论文说这样就提取了路网中的复杂信息,不过我不理解这两个图代表什么含义。

标签:mathbb,Forecasting,Network,Neural,boldsymbol,times,GTU,mathcal,注意力
From: https://www.cnblogs.com/white514/p/17649839.html

相关文章

  • [KDD 2023] All in One- Multi-Task Prompting for Graph Neural Networks
    [KDD2023]AllinOne-Multi-TaskPromptingforGraphNeuralNetworks总结提出了个多任务prompt学习框架,扩展GNN的泛化能力:统一了NLP和图学习领域的prompt格式,包括prompttoken、tokenstructure、insertingpattern构建诱导子图,将点级和边级任务改造为图级任务,统一不同......
  • SocialLGN Light graph convolution network for social recommendation
    目录概SocialLGN代码LiaoJ.,ZhouW.,LuoF.,WenJ.,GaoM.,LiX.andZengJ.SocialLGN:Lightgraphconvolutionnetworkforsocialrecommendation.InformationSciences,2022.概LightGCN+Social.方法很简单,利于理解socialrecommendation.SocialLGN......
  • Cisco CCNA——Network Design Model And Case Study
    NetworkDesignModelAndCaseStudy园区网分层结构接入层技术:子网划分、vlan划分、trunk汇聚层:trunk、vrp、链路聚合核心层:静态路由、动态路由、默认路由、ospf、vtp、acl出口层:nat转换、PPPoE、stp常见网络模型SMB中小型企业网教育行业模型(中小型网络)教育行业模型(大型网络)金融行......
  • Docker搭建lnmp之network篇
    dockerpullnginx#拉去最新的nginx镜像一、搭建vagrant+VagrantBoxVM环境创建Vagrantfile文件vagrantinit编辑Vagrantfile文件Vagrant.configure("2")do|config|config.vm.box="centos7"#指定BOXconfig.vm.networ......
  • rails3学习系列(二)MVC---NetworkError: 500 Internal Server Error
    当我创建了一个control文件:backup_for_sqlserver_controller.rb              classBackupForSqlServerController<ScreenController                   defconfig_wizard                   end          ......
  • Long-term Forecasting with TiDE: Time-series Dense Encoder 学习笔记
    Long-termForecastingwithTiDE:Time-seriesDenseEncoder论文地址:https://arxiv.org/abs/2304.08424代码地址:https://github.com/google-research/google-research/tree/master/tide摘要:这篇论文提出了时间序列密集编码器(TiDE),一种基于MLP的编码器-解码器模型,用于长期时......
  • CS231n: Convolutional Neural Networks for Visual Recognition
    CS231n:ConvolutionalNeuralNetworksforVisualRecognitionEventTypeDateDescriptionCourseMaterialsLecture1Tuesday April4CourseIntroduction Computervisionoverview Historicalcontext Courselogistics[slides] [video]Lecture2Thursday April6Image......
  • README_network
    目录1、功能2、环境3、配置blog_url:blog_id:username:password4、运行5、添加快捷方式设置「右键→→发送到」的功能注意:Typora偏好设置——》图像设置1、功能一键拖拽上传默认“未发布”,可选择直接发布重复上传,提示是否更新博客2、环境(1)Python3安装pyyaml库:cmd中......
  • Paper Reading: NBDT: Neural-Backed Decision Trees
    目录研究动机文章贡献本文方法推理建立层次结构用WordNet标记决策节点微调和树监督损失实验结果对比实验结果可解释性识别错误的模型预测引导图像分类人更倾向的解释识别有缺陷的数据标签优点和创新点PaperReading是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力......
  • HS-GCN Hamming Spatial Graph Convolutional Networks for Recommendation
    目录概符号说明HS-GCNInitialLayerPropagationLayerHashCodeEncoding矩阵表示PredictionLayerOptimization代码LiuH.,WeiY.,YinJ.andNieL.HS-GCN:Hammingspatialgraphconvolutionalnetworksforrecommendation.IEEETKDE.概二值化的nodeembedding.符......