首页 > 其他分享 >T-GCN解读(论文+代码)

T-GCN解读(论文+代码)

时间:2024-10-29 15:18:10浏览次数:6  
标签:self 论文 batch GCN 解读 num nodes size

一、引言 

        提出交通预测是一个具有挑战性的任务,原因在于其复杂的时空依赖性。

        首先,交通流量随着时间动态变化,主要体现在周期性和趋势性上。左图是交通流量一周内的周期变化,右图是交通流量在一天内随着时间推移发生的变化。

  

        除了随时间动态变化外,还受城市道路网络拓扑结构的影响。上游道路的交通状况通过传递效应影响下游道路的交通状况,而下游道路的交通状况则通过反馈效应影响上游道路的交通状况。如下图所示,由于相邻道路之间的强烈影响,短期的交通流相似性会从状态1(上游道路与中游道路相似)变化为状态2(上游道路与下游道路相似)。

        在引言的后面部分,讲述了以往的许多模型考虑了交通条件的动态变化,但忽略了空间依赖性,导致交通条件的变化不受道路网络的限制,因此无法准确预测交通数据的状态。最终提出了一种新的交通预测方法,称为时空图卷积网络(T-GCN),用于基于城市道路网络的交通预测任务。 

二、相关工作

        这一部分主要讲述了在过去几十年中所使用的交通流预测方法,如下图所示:

        虽然之前已经有很多模型开始充分考虑交通流预测中的时空特性,但 CNN 本质上适用于欧几里得空间(如图像、规则网格等),对具有复杂拓扑结构的交通网络具有局限性,因此无法从根本上表征空间依赖性。因此,这类方法也存在一定的缺陷。

        随着图卷积网络模型的发展,Li等人提出了 DCRNN 模型,该模型通过图上的随机游走捕捉空间特征,并通过编码器-解码器架构捕捉时间特征。因此,在 DCRNN 模型的背景下,论文作者提出了 T-GCN 模型,希望从交通数据中捕捉复杂的时空特性,然后用于基于城市道路网络的交通预测任务。

三、方法论

        与论文《Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting》(STGCN)不同的是,T-GCN在时间依赖的建模上选择了门控循环单元(GRU),而STGCN选择的是时间卷积网络(TCN)。对于空间依赖的建模上,两者都选择了图卷积网络(GCN)

        因此,STGCN 更强调并行计算和短期时间依赖处理,而 T-GCN 更强调长时间依赖的精确预测。

1. 空间依赖建模(GCN)

        T-GCN 和 STGCN 一样,使用图卷积网络来捕捉空间依赖性。因此,两个模型都使用邻接矩阵(A)来表示道路网络的连接性,邻接矩阵的元素表示两个节点之间是否存在连接。这一矩阵使得模型能够在图的结构上进行卷积运算,获得交通网络的空间特征。

        在 STGCN 中,图卷积和时间卷积组成一个“三明治”结构(时空卷积块),使得空间和时间依赖被同时捕捉。然而在 T-GCN 中,图卷积 GCN 只是被简单地堆叠(论文提到共两层GCN堆叠),用来捕捉节点邻域的一阶邻接关系。而时间序列特征则完全有 GRU 来处理。

        如下图所示,假设节点1是一个中心道路,GCN模型可以获取中心道路与周围道路之间的拓扑关系,编码道路网络的拓扑结构和道路上的属性,从而获取空间依赖。

         接下来,作者提出了 GCN 获取空间依赖的表达式:

         其中Â表示对邻接矩阵A做一些预处理步骤,形成包含自环的拉普拉斯矩阵;X就是特征矩阵,W0表示从输入层到隐藏层的权重矩阵。然后利用激活函数并把 ReLU(ÂXW0) 当做下一次的特征输入,再做一次卷积就可得到上述表达式。可用如下代码(gcn.py)表示:

def forward(self, inputs):
    # (batch_size, seq_len, num_nodes),seq_len 表示输入特征的长度,num_nodes 表示图的节点数量
    batch_size = inputs.shape[0]

    # 1. 转置数据维度: (num_nodes, batch_size, seq_len)
    inputs = inputs.transpose(0, 2).transpose(1, 2)

    # 2. 改变数据形状: (num_nodes, batch_size * seq_len)
    inputs = inputs.reshape((self._num_nodes, batch_size * self._input_dim))

    # 3. 图卷积操作: AX (num_nodes, batch_size * seq_len)
    ax = self.laplacian @ inputs

    # 4. 恢复数据形状: (num_nodes, batch_size, seq_len)
    ax = ax.reshape((self._num_nodes, batch_size, self._input_dim))

    # 5. 进一步变换数据形状: (num_nodes * batch_size, seq_len)
    ax = ax.reshape((self._num_nodes * batch_size, self._input_dim))

    # 6. 应用权重矩阵并激活: act(AXW) (num_nodes * batch_size, output_dim)
    outputs = torch.tanh(ax @ self.weights)

    # 7. 调整输出形状: (num_nodes, batch_size, output_dim)
    outputs = outputs.reshape((self._num_nodes, batch_size, self._output_dim))

    # 8. 最后转置形状: (batch_size, num_nodes, output_dim)
    outputs = outputs.transpose(0, 1)

    return outputs

        其中在改变数据形状过程中,(num_nodes, batch_size, seq_len) 会变形为 (num_nodes, batch_size * seq_len) ,也就是合并 batch_size 和 seq_len 为一个维度。比如,假设 inputs 的初始形状为 (2, 4, 3) ,那么经过转置、合并后就会变成 (3, 8) 。这样做原始数据并不会发生改变,只是为了后续与拉普拉斯矩阵相乘时计算更加方便。

2. 时间依赖建模(GRU)

        这一部分,论文作者首先提出了原始的循环神经网络(RNN)所存在的问题,即梯度消失、梯度爆炸。而长短期记忆网络(LSTM)门控循环单元(GRU)的基本原理大致相同,但是GRU 结构相对简单,参数更少,训练速度更快。因此,作者选择了 GRU 模型来从交通数据中获取时间依赖性。

        如下图所示,ht-1 表示时刻 t-1 的隐藏状态,xt 表示时刻t的交通信息,rt 是重置门,用于控制忽略前一时刻状态信息的程度,ut 是更新门,用于控制前一时刻状态信息传递到当前状态的程度,ct 是时刻t存储的记忆内容,ht 是时刻 t 的输出状态。

        接下来,看作者的代码(gru.py)是如何构建 GRU 的。首先,gru.py 里有三个类,分别是 GRULinear、GRUCell 和 GRU。

        GRULinear 是一个自定义的线性变换层,和神经网络中的 nn.Linear 类似,用来连接输入 xt 和隐藏状态 ht-1 。GRUCell 用来实现 GRU 单元的单步更新逻辑,包括计算重置门 rt 和更新门 ut ,然后用它们更新隐藏状态。最后 GRU 会定义一个完整的多步时间序列模型,将输入序列中的每个时间步通过 GRUCell 单步更新依次处理,输出序列的最后一个时间步的作为最终的 GRU 计算结果。

        GRUCell 的前向传播代码如下:

def forward(self, inputs, hidden_state):
    # [r, u] = sigmoid([x, h]W + b)
    # [r, u] (batch_size, num_nodes * (2 * num_gru_units))
    concatenation = torch.sigmoid(self.linear1(inputs, hidden_state))  # 将 inputs 和 hidden_state 组合进行线性变换

    # r (batch_size, num_nodes * num_gru_units)
    # u (batch_size, num_nodes * num_gru_units)
    r, u = torch.chunk(concatenation, chunks=2, dim=1)  # 拆分线性变换结果,将其分成两个部分分别作为重置门 r 和更新门 u

    # c = tanh([x, (r * h)]W + b)
    # c (batch_size, num_nodes * num_gru_units)
    c = torch.tanh(self.linear2(inputs, r * hidden_state))  # 计算候选隐藏状态 c

    # h := u * h + (1 - u) * c
    # h (batch_size, num_nodes * num_gru_units)
    new_hidden_state = u * hidden_state + (1 - u) * c  # 计算新的隐藏状态 h
    return new_hidden_state, new_hidden_state

        代码中的 linear1 是用于计算门控的线性层,对应于上图的σ;而 linear2 是用于计算候选隐藏状态的线性层,对应于上图的 tanh。具体的一些计算公式如下,其中 ut 和 rt 的计算方式相同,所以可以先用 concatenation 表示,然后再做拆分。代码最后的 new_hidden_state 代表最终隐状态 ht,其计算方法也和 GRU 模型的原本计算方法相符。

         论文中 GRU 类的前向传播代码如下:

def forward(self, inputs):
    batch_size, seq_len, num_nodes = inputs.shape

    # 检查输入节点数量是否符合GRU的预期输入维度
    assert self._input_dim == num_nodes  # assert用来测试表示式,其返回值为假,就会触发异常。

    # 初始化输出列表和隐藏状态
    outputs = list()
    hidden_state = torch.zeros(batch_size, num_nodes * self._hidden_dim).type_as(
        inputs
    )

    # 遍历每个时间步
    for i in range(seq_len):
        output, hidden_state = self.gru_cell(inputs[:, i, :], hidden_state)  # 使用GRU单元更新隐藏状态
        output = output.reshape((batch_size, num_nodes, self._hidden_dim))
        outputs.append(output)  # 保存每个时间步的输出
    last_output = outputs[-1]  # 获取最后一个时间步的输出,作为整个GRU的输出结果
    return last_output

        这一段代码的核心点在于遍历每个时间步,对每个时间步的输入处理 。处理的方法就是利用上面的 GRUCell 类,使得旧的隐藏状态得到更新,并把最后一个更新的隐藏状态作为整个 GRU 的输出。

3. 时空图卷积网络(T-GCN)

        这一部分是对时空图卷积网络模型(T-GCN)的详细解释。如下图所示,左侧是时空交通预测的过程,右侧展示了一个 T-GCN 单元的具体结构,ht-1 表示时刻 t-1 的输出,GC表示图卷积过程,ut和rt分别是时刻t的更新门和重置门,ht 表示时刻 t 的输出。

        上面谈到两个代码文件,即 gcn.py 和 gru.py。这两个 python 文件是 gcn 和 gru 模型的通用代码,而作者所做的创新是:

        T-GCN 通过图卷积直接计算门控单元的更新,从而实现了图结构数据的时空依赖性建模也就是将图卷积(GCN)和门控循环单元(GRU)深度融合。

        那么接下来看真正的 T-GCN 代码(tgcn.py)。这里同样分为了三个类,分别是TGCNGraphConvolution、TGCNCell 和 TGCN。

        TGCNGraphConvolution 类使用图卷积计算更新每个节点的新特征。前面我们知道,在时间依赖建模中将输入 x (inputs)和 隐藏状态 ht (hidden_state)通过线性变化(sigmoid),合并为 concatenation 。而论文作者将会利用图卷积实现门控机制,在图卷积中就融合输入和隐藏状态。那么这样一来,GRU 的重置门(reset gate)和更新门(update gate)的计算就不需要额外的门控结构,直接可以通过一个卷积核完成。其前向传播代码如下:

def forward(self, inputs, hidden_state):
    batch_size, num_nodes = inputs.shape

    # 调整输入和隐藏状态的形状
    # inputs (batch_size, num_nodes) -> (batch_size, num_nodes, 1)
    inputs = inputs.reshape((batch_size, num_nodes, 1))
    # hidden_state (batch_size, num_nodes, num_gru_units)
    hidden_state = hidden_state.reshape(
        (batch_size, num_nodes, self._num_gru_units)
    )

    # 1.先拼接 inputs 和 hidden_state (实现门控机制),再调整形状
    # [x, h] (batch_size, num_nodes, num_gru_units + 1)
    concatenation = torch.cat((inputs, hidden_state), dim=2)
    # [x, h] (num_nodes, num_gru_units + 1, batch_size)
    concatenation = concatenation.transpose(0, 1).transpose(1, 2)
    # [x, h] (num_nodes, (num_gru_units + 1) * batch_size)
    concatenation = concatenation.reshape(
        (num_nodes, (self._num_gru_units + 1) * batch_size)
    )

    # 2.先图卷积操作,再恢复维度
    # A[x, h] (num_nodes, (num_gru_units + 1) * batch_size)
    a_times_concat = self.laplacian @ concatenation
    # A[x, h] (num_nodes, num_gru_units + 1, batch_size)
    a_times_concat = a_times_concat.reshape(
        (num_nodes, self._num_gru_units + 1, batch_size)
    )
    # A[x, h] (batch_size, num_nodes, num_gru_units + 1)
    a_times_concat = a_times_concat.transpose(0, 2).transpose(1, 2)
    # A[x, h] (batch_size * num_nodes, num_gru_units + 1)
    a_times_concat = a_times_concat.reshape(
        (batch_size * num_nodes, self._num_gru_units + 1)
    )

    # 3. 应用权重矩阵和偏置计算图卷积结果
    # A[x, h]W + b (batch_size * num_nodes, output_dim)
    outputs = a_times_concat @ self.weights + self.biases
    # A[x, h]W + b (batch_size, num_nodes, output_dim)
    outputs = outputs.reshape((batch_size, num_nodes, self._output_dim))
    # A[x, h]W + b (batch_size, num_nodes * output_dim)
    outputs = outputs.reshape((batch_size, num_nodes * self._output_dim))
    return outputs

        这一段代码大致完成了三件事情:1. 实现门控机制;2. 图卷积;3. 计算结果。其中需要注意第2步的操作,这里的图卷积和 GCN 有区别,拉普拉斯矩阵不是和输入矩阵点乘,而是和 concatenation (输入和隐藏状态的拼接)点乘。这样乘积后的结果既包含节点本身的信息,也反映了节点邻居的特征,进而在 T-GCN 中为每个时间步的更新提供了空间依赖性。

        而 TGCNCell 类实现了之前的 GRUCell 类类似的功能,TGCN 类实现了之前 GRU 类似的功能。这里不再赘述。

四、实验

1. 数据

        作者通过两个实际交通数据集(SZ-taxi数据集和Los-loop数据集)来评估T-GCN模型的预测性能。并且将交通速度作为作为主要的交通信息。adj 为邻接矩阵,speed 为特征矩阵。对于数据集,将80%的数据用于训练集,其余20%用于测试集。对数据集 csv 文件的处理在 SpatioTemporalCSVDataModule 类中。

# spatiotemporal_csv_data.py中定义训练与测试比例
seq_len: int = 12,  # 历史时间步数
pre_len: int = 3,  # 预测序列长度
split_ratio: float = 0.8,  # 训练集和验证集的分割比例

# main.py中定义数据集
DATA_PATHS = {
    "shenzhen": {"feat": "data/sz_speed.csv", "adj": "data/sz_adj.csv"},
    "losloop": {"feat": "data/los_speed.csv", "adj": "data/los_adj.csv"},
}

2. 评估指标

        在 SupervisedForecastTask 类中有一个 validation_step 的函数方法用来评估 T-GCN 模型的性能。其中用了5个指标,分别是:

  1. 均方根误差(RMSE):反映预测值与真实值的差异,单位与原始数据一致,越小越好。
  2. 平均绝对误差(MAE):同样表示预测值与真实值的偏差,避免了平方操作带来的放大效果,数值越小越好。
  3. 准确率(Accuracy):根据应用场景定义了自定义的准确率计算方式,适用于分类任务或特定要求。
  4. 确定系数(R²):衡量模型对数据的拟合程度,1 表示完美拟合,0 表示完全无法解释的模型。
  5. 解释方差分数(var):表征预测值对实际值的方差解释能力。该值接近 1 表示模型预测的方差接近真实值的方差。
def validation_step(self, batch, batch_idx):
    # 1. 获取预测结果和真实标签
    predictions, y = self.shared_step(batch, batch_idx)

    # 2. 将预测结果和标签缩放回原始值范围
    predictions = predictions * self.feat_max_val
    y = y * self.feat_max_val

    # 3. 计算损失
    loss = self.loss(predictions, y)

    # 4. 计算多种性能度量
    rmse = torch.sqrt(torchmetrics.functional.mean_squared_error(predictions, y))  # 均方根误差
    mae = torchmetrics.functional.mean_absolute_error(predictions, y)  # 平均绝对误差
    accuracy = utils.metrics.accuracy(predictions, y)  # 准确率
    r2 = utils.metrics.r2(predictions, y)  # 决定系数
    explained_variance = utils.metrics.explained_variance(predictions, y)  # 解释方差

    # 5. 将度量日志化,便于训练过程中跟踪模型性能
    metrics = {
        "val_loss": loss,
        "RMSE": rmse,
        "MAE": mae,
        "accuracy": accuracy,
        "R2": r2,
        "ExplainedVar": explained_variance,
    }
    self.log_dict(metrics)

    # 6. 返回调整后的预测值和真实标签
    return predictions.reshape(batch[1].size()), y.reshape(batch[1].size())

3. 模型参数选择

        下面两张图表示,在 SZ-taxi 数据集的实验中隐藏单元数量设置为100,Los_loop 数据集的实验中隐藏单元数量设置为64时,在多种指标下都是预测精度最高,预测误差最低。

         因此,作者的代码是这样设计的:

# 在 main.py 中设置了默认使用 los_loop 数据集
parser.add_argument(
        "--data", 
        type=str, 
        help="The name of the dataset", 
        choices=("shenzhen", "losloop"), 
        default="losloop"
    )

# 在 tgcn.py 中设置隐藏单元数量为64
@staticmethod
def add_model_specific_arguments(parent_parser):
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument("--hidden_dim", type=int, default=64)
    return parser

4. 实验结果

        T-GCN 相比 STGCN 其优势在于长期预测能力。因为GRU具备记忆机制,能够保留长期的历史信息,避免了传统RNN在长序列数据上出现的梯度消失或梯度爆炸问题。

        其次,作者还将T-GCN模型与GCN模型和GRU模型进行了比较。发现均方根误差(RMSE)都是T-GCN最小,即预测值和真实值的差距最小。

        作者在代码中将默认的模型设置为了 GCN ,在运行测试时可以调整修改。

# 注意修改 main.py 中的main函数的模型设置
parser.add_argument(
        "--model_name",
        type=str,
        help="The name of the model for spatiotemporal prediction",
        choices=("GCN", "GRU", "TGCN"),
        default="GCN",
    )

 

         

标签:self,论文,batch,GCN,解读,num,nodes,size
From: https://blog.csdn.net/weixin_51418964/article/details/143206119

相关文章

  • java计算机毕业设计翰明教育教学管理系统(开题+程序+论文)
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容一、研究背景随着教育事业的不断发展,学校规模的扩大以及教育管理要求的日益提高,传统的教育教学管理方式已经难以满足需求。在当今数字化时代,教育领域也迫切需......
  • java计算机毕业设计成都医学院考研信息交流平台(开题+程序+论文)
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着高等教育的普及和就业压力的增加,越来越多的本科毕业生选择继续深造,攻读硕士学位。成都医学院作为四川省的一所重要医学院校,吸引了大量学生报考其......
  • java计算机毕业设计宠物社区app(开题+程序+论文)
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着社会经济的快速发展和人们生活水平的提高,宠物已经成为许多家庭的重要成员。宠物社区app应运而生,旨在为宠物主人提供一个交流平台,分享养宠心得和......
  • 【Python原创毕设|课设】基于Python、机器学习的垃圾邮件分类与安全宣传网站-文末附下
    基于Python、机器学习的垃圾邮件分类与安全宣传网站-(获取方式访问文末官网)一、项目简介二、开发环境三、项目技术四、功能结构五、运行截图六、数据库设计七、功能实现八、源码获取一、项目简介该该系统是一个基于Python的邮件分类和安全宣传网站,结合了机器学习和数......
  • 阿里云消息团队创新论文被软件工程顶会 FM 2024 录用
    近日,由阿里云消息队列团队发表的关于RocketMQ锁性能优化论文被CCF-A类软件工程顶级会议FM2024录用。FM2024是由欧洲形式化方法协会(FME)组织的第24届国际研讨会,会议汇聚了来自各国的形式化研究学者,是形式化方法领域的顶级会议。FM2021强调形式化方法在广泛领域的开发......
  • 深度解读RDS for MySQL 审计日志功能和原理
    本文分享自华为云社区《【华为云MySQL技术专栏】RDSforMySQL审计日志功能介绍》,作者:GaussDB数据库。1.背景在生产环境中,当数据库出现故障或问题时,运维人员需要快速定位出异常或者高危的SQL语句。这时,审计日志能够提供详细的记录,帮助追踪每个数据库操作的执行者、执行时间以......
  • 【论文精读】On the Relationship Between Self-Attention and Convolutional Layers
    【论文精读】OntheRelationshipBetweenSelf-AttentionandConvolutionalLayers作者:Jean-BaptisteCordonnier,AndreasLoukas,MartinJaggi发表会议:ICLR2020论文地址:arXiv:1911.03584v2目录【论文精读】OntheRelationshipBetweenSelf-AttentionandConv......
  • GaussDB数据库技术解读——高性能关键技术
    GaussDB数据库技术解读——高性能关键技术内容概要:本章节介绍GaussDB中实现的高性能关键技术,内容涉及优化器、执行器、分布式数据库、存储引擎等多个方面。目的:通过对GaussDB数据库关键高性能技术的学习,能够让读者更加清晰的理解数据库内核哪些优化是性能关键点同时也为类似的应......
  • GaussDB技术解读——GaussDB架构介绍之OM运维管理关键技术方案
    ​GaussDBKernelV5OM运维管理关键模块如下。OM运维主要功能有:安装升级节点替换扩容、缩容自动告警巡检备份恢复、容灾日志分析系统在华为云的部署模式下,OM相关组件部署示意图如下:图7华为云OM运维管理用户登录华为云Console,访问GaussDBKernelV5的管控页面,输入......
  • SpringBoot农村人居环境治理监管系统5aul2 带论文文档
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表系统内容:员工,部门信息,土壤质量,水质监管,空气监管,卫生监管,人文环境,设施监管开题报告内容一、选题背景及意义随着我国乡村振兴战略的深入实施,农村人居环境......