首页 > 其他分享 >学习笔记:SSTBAN 用于长期交通预测的自监督时空瓶颈注意力

学习笔记:SSTBAN 用于长期交通预测的自监督时空瓶颈注意力

时间:2024-03-23 10:44:38浏览次数:18  
标签:预测 STBA 笔记 SSTBAN 时空 数据 注意力

Self-Supervised Spatial-Temporal Bottleneck Attentive Network for Efficient Long-term Traffic Forecasting
用于高效长期交通预测的自监督时空瓶颈注意力网络
期刊会议:ICDE2023
论文地址:https://ieeexplore.ieee.org/document/10184658
代码地址:https://github.com/guoshnBJTU/SSTBAN

长期交通预测存在的问题:

  • 难以平衡准确率和效率。随着时间跨度增大,要么无法捕捉长期动态性,要么以二次计算复杂度为代价获取全局接受域。
  • 高质量的训练数据需求与模型的泛化能力的矛盾。如何提升数据的利用效率值得思考。
    SSTBAN采用多任务框架,结合自监督学习器对历史交通数据产生鲁棒的潜在表示,从而提高其泛化性能和预测的鲁棒性。此外,作者还设计了一个时空瓶颈注意机制,在编码全局时空动态的同时降低了计算复杂度。

长期预测需求分析:
与有助于及时决策的短期预测相比,长期预测为旅行者和管理员提供了必要的支持信息,以优化旅行计划和运输资源管理。特别是未来几个小时的流量预测信息,有助于用户提前制定路由计划。

选择注意力机制的原因:
目前STGNNs分为RNN-based,CNN-based和attention-based方法。RNN-based存在梯度消失问题,不利于长期预测,且序列顺序的预测方式使得模型训练时间随着预测时间线性增加;CNN-based的kernel大小限制了长期动态性的捕捉能力;Attention-based更灵活,不会受到空间和时间距离的影响,但是存在二次计算复杂度的问题,这是要解决的问题。

数据的利用效率:高质量数据需求与泛化性矛盾
现有方法普遍有较强的高质量数据需求,当训练数据存在噪声时,就会导致过拟合或是学习到虚假的关系,泛化能力不佳。于是引入了常用于NLP和CV中的自监督学习。这要正确认知NLP/CV和时空交通预测的区别:

  • NLP/CV中的基础模式,如形状和语义,在广泛数据集中是通用的;而交通数据集中则鲜有这样的共同特征,比如可能数据的特征都不一样。
  • NLP只需捕捉序列特征——即时间,CV只需要捕捉空间特征,但是STGNN要同时捕捉时空特征。

贡献:

  • 第一次提出了一种采用自监督学习器的时空交通预测模型,满足了泛化和鲁棒需求。
  • 设计了一种时空瓶颈注意力机制,能够高效捕捉长期时空动态,将时间复杂度由二次方降低至线性。(RNN:?)
  • 在九个数据集上进行了实验,证明了在精度和效率上的优势。

模型

图: SSTBAN架构

模型包含两个分支:第一个是时空预测分支,第二个是时空自监督学习分支,因此是个多任务框架。

在训练阶段,两个分支一起工作。在分支一中,原始的数据依次经过ST Encoder、Transformer Attention,最后由ST Forecasting Decoder预测;在分支二中,首先随机mask掉一些数据,将破损的数据经过ST Encoder来用剩余的数据提取特征,经过ST Reconstruction Decoder来补全丢掉的数据,并将补全的数据和分支一的完整数据进行对齐比较(为了避免噪声的影响,这里的比较是放在了潜在空间中的)。训练损失也包含了两个,一个是预测误差的MAE,另一个是对齐的MSE。

在两个分支中,encoder和两个decoder由一样的时空瓶颈注意力模块(STBA)和时空嵌入模块(STE)构成。

STBA目的是捕捉长期的时空动态性,且维持低的计算复杂度。
STE目的是提取不同时间切片和节点的独特性,来弥补基于注意力机制的STBA对顺序的不敏感缺点。我们通过端到端的方式训练空间嵌入\(E_{SP}\in R^{N\times d}\),它在所有时间中共享;通过time-of-day和day-of-week,用one-hot和MLPs得到输入时间嵌入\(E_{TP}\in R^{P\times d}\)和输出时间嵌入\(E_{TP}'\in R^{Q\times d}\),它在所有节点中共享;将它们相加得到输入序列嵌入\(\mathcal{E}\in R^{P\times N\times d}\)和输入序列嵌入\(\mathcal{E}'\in R^{Q\times N\times d}\)。

时空瓶颈注意力 STBA

图: 时空瓶颈注意力STBA,时间瓶颈注意力TBA,空间SBA

图中\(\mathcal{Z}^{(l-1)}=(\mathcal{H}^{(l-1)}||\mathcal{E})\in\mathbb{R}^{P\times N\times2d}\)。

STBA包含了空间注意力(SBA)和时间注意力(TBA)。它们并没有直接和其他点相连,而是和参考点相连,而参考点的数量远小于时间点和空间点。我们还希望参考点能够编码通用的全局信息。由于整体形状像瓶颈而得名。

TBA平行地处理每个点的输入。(这里的过程还不是很懂。)

STBA具有以下特点:

  • 由于参考点的设置,运算复杂度从\(O(N^2)\)降低到了\(O(NN')\),因为\(N'\)是个小的超参数。相比于GCN,STBA不需要预定义的图结构,同时能动态调整节点间关系强度。
  • 参考点起到了编码全局模式的作用,可以理解输入,如用来聚类。

时空预测分支

图:分支一 时空预测分支结构

组成部分:
(1)时空编码器:由时空瓶颈注意力组成,映射到潜在表征空间
(2)Transformer attention:将潜在空间下的历史信号适配于预测信号尺寸。为了缓解长期预测存在的比较严重的误差传播问题,我们通过自适应地融合历史中的不同特征,用注意力机制直接把每步的历史信号和预测信号连接起来。即

\[\mathcal{H^{\prime}}_{:,v}^{(0)}=\mathrm{MHSA}(\mathcal{E}_{:,v}^{\prime},\mathcal{E}_{:,v},\mathcal{H}_{:,v}^{(L)})\in\mathbb{R}^{Q\times d} \]

(3)时空预测解码器:由若干层时空瓶颈注意力,最后加上全连接层组成。

时空自监督学习分支

这一分支从mask掉部分信号的不完整数据中,理解时空关系,并在潜在空间中重构缺失的信号。目的是训练潜在空间表征能力。包括如下部分:
(1)Masking:考虑到mask掉单独一个时间点的数据,很容易通过前后数据算出来,因此在时间或空间维度上mask掉连续的段,以此来学习趋势模式。Mask策略是,将输入数据分成若干patch,并将一定比例的patch全部清0。

图:Masking算法

(2)时空编码器:和分支一中的一样。只是,被mask掉的数据不参与时空瓶颈注意力的计算。
(3)时空重构编码器:输入(2)提供的残缺潜在表征,以及指示mask位置的token向量。由若干时空瓶颈组成,并将重构后的表征与分支一的完整表征匹配。

实验

数据集
分别是Seattle Loop,PEMS04, PEMS08
这个Seattle Loop我还是第一次见。

Loop Seattle 数据集由部署在西雅图地区高速公路(I-5、I-405、I-90和SR-520)上的感应环路探测器收集,包含来自323个传感器站的交通状态数据。

图:数据集信息

超参数表
\(L\):ST Encoder中STBA的数量
\(L'\):STF Decoder中STBA的数量
\(d\):多头注意力机制的维数
\(h\):多头注意力机制的头数
\(l_m\):Masking过程的patch length
\(\alpha_m\):Masking过程的mask率
\(\lambda\):预测损失和对齐损失的权重。越大代表对齐损失占比越大。
时间和空间参考点的数量都是3。

图:超参数设置

对比试验

实际上作者选的这些基线模型都比较老,说服力比较差。我在本文的最后放上了自己做的一点对比实验,可以当作参考。

图:PEMS对比试验
图:SeattleLoop对比实验

随着时间跨度增加,SSTBAN的优势也在增加。

图:预测表现与预测长度的关系

鲁棒测试

作者还进行了以下两个实验。可以看到,模型在这两个方面还是有优点的。遗憾就是,对比的模型只有两个,且比较古老,依然是缺乏说服力。
注:GMAN AAAI2020,DMSTGCN KDD2021

图:减少训练数据
图:随机添加噪声

消融实验

将STBA与普通注意力网络进行对比。也觉是说,STBA在减小时间复杂度的同时,还能增加准确率。

图:消融实验

算力消耗

可以看出,时间消耗和空间小号还是比较小的。不过在实验中,模型的时间和空间占用与batch size等参数设置有关,所以这个只能做参考吧。在我用PEMS08做复现,预测48步时,设置batch size=8,显存占用20G左右。

图:算力实验

复现

以下是我的复现结果。以下实验每个仅做了一次,并没有重复实验,所以仅作参考。

batch size对时间的影响较大。越大,时间越短,但相应的占用显存越多。在SSTBAN用PEMS08预测48步的实验中,设置batch size=16时,40G显存的A100就已经跑不动了,可以看出SSTBAN的空间占用还是比较大的。SSTBAN的特点是,训练一个epoch耗时比较长,但是epoch数量少,就触发早停了。

注:TrendGCN属于CIKM2023

模型 数据集 步数 epoch MAE MAPE RMSE 时间
TrendGCN 08 12,12 120(batch64) 15.11 9.68 24.25 0h41
TrendGCN 08 24,24 120(batch128) 16.84 10.77 27.14 0h40
TrendGCN 08 36,36 120(batch64) 17.70 11.95 28.63 1h30
TrendGCN 08 48,48 120(batch128) 18.86 12.91 29.97 1h20
模型 数据集 步数 epoch MAE MAPE RMSE 时间
SSTBAN 08 12,12 71(batch32) 15.36 10.79 24.26 0h50
SSTBAN 08 24,24 16(batch32) 15.40 10.68 26.20 1h30
SSTBAN 08 36,36 18(batch4) 16.56 11.65 29.33 3h30
SSTBAN 08 48,48 15(batch8) 17.29 15.10 29.04 2h40
SSTBAN 04 36,36 15(batch8) 21.09 15.29 37.42 2h

标签:预测,STBA,笔记,SSTBAN,时空,数据,注意力
From: https://www.cnblogs.com/white514/p/18090811

相关文章

  • 《自动机理论、语言和计算导论》阅读笔记:p1-p4
    《自动机理论、语言和计算导论》学习第1天,p1-p4,总计4页。这只是个人的学习记录,因为很多东西不懂,难免存在理解错误的地方。一、技术总结1.有限自动机(finiteautomata)示例1.softwareforcheckingdigitalcircuits。2.lexicalanalyzerofcompiler。3.softwareforscannin......
  • web CSS笔记
    CSS(CascadingStyleSheets)美化样式CSS通常称为CSS样式表或层叠样式表(级联样式表),主要用于设置HTML页面中的文本内容(字体、大小、对齐方式等)、图片的外形(宽高、边框样式、边距等)以及版面的布局等外观显示样式。CSS以HTML为基础,提供了丰富的功能,如字体、颜色、背景的控制及整......
  • iOS模拟器 Unable to boot the Simulator —— Ficow笔记
     本文首发于FicowShen'sBlog,原文地址:iOS模拟器UnabletoboottheSimulator——Ficow笔记。内容概览前言终结模拟器进程命令行改权限清除模拟器缓存总结 前言 iOS模拟器和Xcode一样不靠谱,问题也不少。......
  • 生成函数学习笔记
    生成函数(generatingfunction,简称GF),一般只应用两种:OGF和EGF。OGF和EGF都是定义在一个数列上的。【OGF】【定义】对于一个有限序列\(\{a_i\}(i=0\simN)\),其OGF为\(f(x)=\displaystyle\sum_{i=0}^Na_i\cdotx^i\)。对于一个无限序列\(\{a_i\}\),其OGF为\(f(x)=\d......
  • Android开发笔记[16]-简单使用wasmedge运行时
    摘要使用wasmedge运行时在Android实现"容器化"运行,将fibonacci计算函数打包进入wasm然后打包进入APK中.关键信息AndroidStudio:Iguana|2023.2.1Gradle:distributionUrl=https://services.gradle.org/distributions/gradle-8.4-bin.zipjvmTarget='1.8'minSdk24targe......
  • Java笔记
    Java背景1.官网oraclejava文档https://docs.oracle.com/en/java/index.htmloraclejdk下载https://www.oracle.com/java/technologies/downloads/openjdkhttps://openjdk.org/2.java和JVMjava是基于类的、纯粹的面向对象编程语言java是解释执行类的语言WOR......
  • 数据结构笔记
    数据结构数据在内存中的存储方式(存储结构)程序=数据+算法算法=操作的步骤算法的时间复杂度动态数组链表迭代器栈和队列二叉搜索树二叉平衡树哈希表1.时间复杂度考量算法的时间复杂度有一个前提就是控制变量,语句的执行时间相同,数据的样本量相同……......
  • Android开发笔记[15]-设置页
    摘要使用MMKV数据框架实现设置页数据同步,设置页可以对其他页面进行设置;设置页数据通过MMKV框架持久化存储,重启APP不丢失.关键信息AndroidStudio:Iguana|2023.2.1Gradle:distributionUrl=https://services.gradle.org/distributions/gradle-8.4-bin.zipjvmTarget='1.......
  • [计算化学]分子动力学笔记
    本文为某计算机本科生的分子动力学学习笔记,在gpt4的辅助下,非体系化地整理相关生物、化学、统计力学知识。所有生成内容经过检查和调整,均直接代表本人观点。有科学性错误的话欢迎指教。什么是分子动力学定义:分子动力学是一门结合物理,数学和化学的综合技术。分子动力学是一套分子......
  • JavaWeb学习笔记——第一天
    Web开发什么是WebWeb:全球广域网,也称为万维网(wwwWorldWideWeb),能够通过浏览器访问的网站。Web网站的工作流程用户通过浏览器访问Web网站服务端的程序分为三部分:运行前端程序的前端服务器、运行Java后端程序的后端服务器和数据库服务器。用户通过浏览器对网站发起请求后,......