首页 > 其他分享 >6、关于Medical-Transformer

6、关于Medical-Transformer

时间:2024-09-05 12:51:29浏览次数:14  
标签:Transformer self torch shape 64 关于 print output Medical

6、关于Medical-Transformer

Axial-Attention原文链接:Axial-attention
Medical-Transformer原文链接:Medical-Transformer

Medical-Transformer实际上是Axial-Attention在医学领域的运行,只是在这基础上增加了门机制,实际上也就是在原来Axial-attention基础之上增加权重机制,虚弱位置信息对于数据的影响,发现虚弱之后的效果比Axial-Attention机制效果更好

Axial-Attention

Axial-Attention与传统Transformer的self-attention相比较,将2D计算转成1D计算,Axial-attention机制,对于qkv的计算,做出了简化,仅仅某个点的横竖两个方向上的特殊,同时在qkv的基础上加上了各自位置特征,这些特征都是更新学习的。

Axial-attention模型架构图

左图为传统的self-attention机制,右图为Axial-attention机制,对于qkv都加上rq,rk,rv这样的位置参数,这些参数都是可以更新的,也就是说,每个的q在和自己对应的横竖轴反向进行计算的时候,q会和自己rq先进行权重计算,同样的k和v也会进行同样的计算,随后进行q和k进行计算得到权重,计算过程和原来的self-attention机制是一样的。

在这里插入图片描述

class AxialAttention(nn.Module):
    def forward(self, x):
    # 前向传播函数
    # 如果设置了 width 参数,调整张量维度顺序
    if self.width:
        x = x.permute(0, 2, 1, 3)  # 调整维度顺序
    else:
        x = x.permute(0, 3, 1, 2)  # N, W, C, H  调整为 N, C, H, W
    N, W, C, H = x.shape  # 获取张量形状
    x = x.contiguous().view(N * W, C, H)  # 重新调整形状,合并 N 和 W 维度

    # 通过x获得对应的qkv 批归一化后计算 qkv
    qkv = self.bn_qkv(self.qkv_transform(x))  
    q, k, v = torch.split(
        qkv.reshape(N * W, self.groups, self.group_planes * 2, H),
        [self.group_planes // 2, self.group_planes // 2, self.group_planes], 
        dim=2
    )  # 将 qkv 拆分为 q, k, v

    # 计算位置嵌入
    all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
    q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)  # 拆分嵌入

    # 计算 QR, KR, QK 相似性,分别计算得出rq,rk
    qr = torch.einsum('bgci,cij->bgij', q, q_embedding)  # QR: q 和 q_embedding 的爱因斯坦求和
    kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)  # KR: k 和 k_embedding 的爱因斯坦求和,并转置
    # q和k进行计算,得到最后的权重
    qk = torch.einsum('bgci, bgcj->bgij', q, k)  # QK: q 和 k 之间的点积

    # 将 QR, KR, QK 相似性进行堆叠,连在一起进行计算
    stacked_similarity = torch.cat([qk, qr, kr], dim=1)  # 将 qk, qr, kr 连接起来
    stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)  # 批归一化并调整形状

    # similarity为q和k计算得出权重关系
    similarity = F.softmax(stacked_similarity, dim=3)  # 在第 3 维度上计算 softmax
    # 将q和v计算出来权重和v加权求和
    sv = torch.einsum('bgij,bgcj->bgci', similarity, v)  # 将相似度与 v 进行求和
    # v与位置信息结合
    sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)  # 将similarity与 v_embedding 进行求和

    # 将位置加权后的v和q和k计算结果与v加权的合并,并调整形状输出
    stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)  # 合并 sv 和 sve,并调整形状
    output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)  # 批归一化并调整形状

    # 恢复维度顺序
    if self.width:
        output = output.permute(0, 2, 1, 3)  # 调整维度顺序
    else:
        output = output.permute(0, 2, 3, 1)  # 调整维度顺序

    # 如果步长大于 1,应用池化操作
    if self.stride > 1:
        output = self.pooling(output)  # 池化

    return output  # 返回输出
横竖轴计算过程

先通过卷积把特征图缩小,然后横竖轴计算时,是将横轴一起进行计算,然后再进行纵轴计算的,完成计算后,通过1x1卷积将特征图还原为原来的大小,在传入下一层进行计算。

在这里插入图片描述

Medical-Transformer
Medical-Transformer架构图

Medical-Transformer实际就是Axial-attention在医学图像分割领域的应用,medical-tranformer大模型架构采用整个图像进行Axial-attention特征提取,同时也将图像分成多个窗口,对每个窗口进行axial-attention特征提取,窗口由于计算量小,可以多进行几层Axial-attention,最终将整个图像特征和窗口特征融合,完成整个的特征提取,值得一提的是在进行窗口Axial-attention时,qkv都没有加上位置编码(也就是下面部分的图像)。

在这里插入图片描述

主体架构
class medt_net(nn.Module):
    def _forward_impl(self, x):
    xin = x.clone()  # 保存输入数据的副本
    x = self.conv1(x)  # 第一个卷积层
    x = self.bn1(x)  # 第一个批归一化层
    x = self.relu(x)  # ReLU 激活函数
    x = self.conv2(x)  # 第二个卷积层
    x = self.bn2(x)  # 第二个批归一化层
    x = self.relu(x)  # ReLU 激活函数
    x = self.conv3(x)  # 第三个卷积层
    x = self.bn3(x)  # 第三个批归一化层
    x = self.relu(x)  # ReLU 激活函数

    x1 = self.layer1(x)  # 第一个残差层 实际上就是 Gated Axial Attention Layer
    x2 = self.layer2(x1)  # 第二个残差层 同样是 Gated Axial Attention Layer

    # 对输入进行插值放大,并通过解码器处理
    x = F.relu(F.interpolate(self.decoder4(x2), scale_factor=(2, 2), mode='bilinear'))
    x = torch.add(x, x1)  # 将放大的特征图与 x1 相加
    x = F.relu(F.interpolate(self.decoder5(x), scale_factor=(2, 2), mode='bilinear'))
    # 以上完成就是图上方整个图像的卷积过程
	# -------------------------------------------------------------------------------------------
    
    x_loc = x.clone()  # 生成一个本地副本
    # 下面对图像进行切分,分别对每个窗口进行局部处理,实际上是16个窗口
    for i in range(0, 4):
        for j in range(0, 4):
            x_p = xin[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)]  # 提取32x32的局部patch

            # 逐层卷积处理patch
            x_p = self.conv1_p(x_p)
            x_p = self.bn1_p(x_p)
            x_p = self.relu(x_p)

            x_p = self.conv2_p(x_p)
            x_p = self.bn2_p(x_p)
            x_p = self.relu(x_p)

            x_p = self.conv3_p(x_p)
            x_p = self.bn3_p(x_p)
            x_p = self.relu(x_p)
			# 进行四个
            x1_p = self.layer1_p(x_p)  # 第一个残差层(patch-wise) 这里进行的axial-attention在进行qkv计算时,qkv都没有加入位置信息计算
            x2_p = self.layer2_p(x1_p)  # 第二个残差层(patch-wise)
            x3_p = self.layer3_p(x2_p)  # 第三个残差层(patch-wise)
            x4_p = self.layer4_p(x3_p)  # 第四个残差层(patch-wise)

            # 对patch进行插值放大并通过解码器处理
            x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2, 2), mode='bilinear'))
            x_p = torch.add(x_p, x4_p)  # 将放大的特征图与 x4_p 相加
            x_p = F.relu(F.interpolate(self.decoder2_p(x_p), scale_factor=(2, 2), mode='bilinear'))
            x_p = torch.add(x_p, x3_p)  # 将放大的特征图与 x3_p 相加
            x_p = F.relu(F.interpolate(self.decoder3_p(x_p), scale_factor=(2, 2), mode='bilinear'))
            x_p = torch.add(x_p, x2_p)  # 将放大的特征图与 x2_p 相加
            x_p = F.relu(F.interpolate(self.decoder4_p(x_p), scale_factor=(2, 2), mode='bilinear'))
            x_p = torch.add(x_p, x1_p)  # 将放大的特征图与 x1_p 相加
            x_p = F.relu(F.interpolate(self.decoder5_p(x_p), scale_factor=(2, 2), mode='bilinear'))

            x_loc[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)] = x_p  # 将局部处理后的结果放回原始位置
	# 将整个图片的axial-attention,和每个窗口得出的结果进行结合
    x = torch.add(x, x_loc)  # 将全局和局部特征进行融合
    x = F.relu(self.decoderf(x))  # 通过最终的解码器层
    x = self.adjust(F.relu(x))  # 调整输出
    return x  # 返回最终输出
Gated Axial Attention Layer

从架构图中可以看出,就是在Axial-attention的基础上,加上了门机制,说白了,也就是在qkv和各自的rq,rk,rv计算完成后,再进行下一步计算之前,进行了一个加权计算,虚弱了位置变量对特征提取结果的影响。

在这里插入图片描述

横向或纵向Gated Axial-attention过程

注意里面qr,kr实际上就是图片中的rq,rk,而

class AxialAttention_dynamic(nn.Module):
    def forward(self, x):
    # 判断是否需要对宽度维度进行变换
    if self.width:
        x = x.permute(0, 2, 1, 3)  # 交换维度顺序,形状变为 [N, C, W, H]
    else:
        x = x.permute(0, 3, 1, 2)  # 交换维度顺序,形状变为 [N, W, C, H]

    N, W, C, H = x.shape  # 获取输入张量的形状
    x = x.contiguous().view(N * W, C, H)  # 将张量变形为 [N * W, C, H]
    print(x.shape)  # 输出形状: [64, 16, 64]

    # 变换操作
    qkv = self.bn_qkv(self.qkv_transform(x))  # 对qkv进行批归一化
    print(qkv.shape)  # 输出形状: [64, 32, 64]
    
    # 将qkv张量拆分为q、k、v,分别表示查询、键和值
    q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), 
                          [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
    print(q.shape)  # 输出q的形状: [64, 8, 1, 64]
    print(k.shape)  # 输出k的形状: [64, 8, 1, 64]
    print(v.shape)  # 输出v的形状: [64, 8, 2, 64],v有两份

    # 计算位置嵌入
    all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
    print(all_embeddings.shape)  # 输出嵌入的形状: [4, 64, 64],共有4份
    q_embedding, k_embedding, v_embedding =
    								torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
    print(q_embedding.shape)  # 输出q的位置嵌入形状: [1, 64, 64]
    print(k_embedding.shape)  # 输出k的位置嵌入形状: [1, 64, 64]
    print(v_embedding.shape)  # 输出v的位置嵌入形状: [2, 64, 64],v有两份位置编码

    # 计算q与位置嵌入的乘积
    qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
    print(qr.shape)  # 输出qr的形状: [64, 8, 64, 64]

    # 计算k与位置嵌入的乘积,并进行转置
    kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
    print(kr.shape)  # 输出kr的形状: [64, 8, 64, 64]

    # 计算q和k的点积
    qk = torch.einsum('bgci, bgcj->bgij', q, k)
    print(qk.shape)  # 输出qk的形状: [64, 8, 64, 64]

    # 对qr和kr进行初始化,使用self.f_qr和self.f_kr作为初始化的权重
    qr = torch.mul(qr, self.f_qr)
    print(qr.shape)  # 输出qr的形状: [64, 8, 64, 64]
    kr = torch.mul(kr, self.f_kr)
    print(kr.shape)  # 输出kr的形状: [64, 8, 64, 64]

    # 将qk、qr和kr拼接起来
    stacked_similarity = torch.cat([qk, qr, kr], dim=1)
    print(stacked_similarity.shape)  # 输出拼接后的形状: [64, 24, 64, 64]

    # 进行批归一化,重新变形为[N * W, 3, groups, H, H],并对维度1求和
    stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
    print(stacked_similarity.shape)  # 输出归一化后的形状: [64, 8, 64, 64]

    # 计算相似度
    similarity = F.softmax(stacked_similarity, dim=3)
    print(similarity.shape)  # 输出相似度的形状: [64, 8, 64, 64]

    # 使用相似度与v相乘,获得加权后的值
    sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
    print(sv.shape)  # 输出加权后的形状: [64, 8, 2, 64]

    # 使用相似度与v的位置嵌入相乘
    sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
    print(sve.shape)  # 输出位置嵌入加权后的形状: [64, 8, 2, 64]

    # 对sv和sve进行初始化
    sv = torch.mul(sv, self.f_sv)
    print(sv.shape)  # 输出sv的形状: [64, 8, 2, 64]
    sve = torch.mul(sve, self.f_sve)
    print(sve.shape)  # 输出sve的形状: [64, 8, 2, 64]

    # 将sv和sve拼接在一起,并重新变形为[N * W, out_planes * 2, H]
    stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
    print(stacked_output.shape)  # 输出拼接后的形状: [64, 32, 64]

    # 进行批归一化,并变形为[N, W, out_planes, 2, H],对维度-2求和
    output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)
    print(output.shape)  # 输出归一化后的形状: [1, 64, 16, 64]

    # 根据宽度调整维度顺序
    if self.width:
        output = output.permute(0, 2, 1, 3)
    else:
        output = output.permute(0, 2, 3, 1)
    print(output.shape)  # 输出最终的形状: [1, 16, 64, 64]

    # 如果步幅大于1,进行池化操作
    if self.stride > 1:
        output = self.pooling(output)

    return output

标签:Transformer,self,torch,shape,64,关于,print,output,Medical
From: https://blog.csdn.net/sgr011215/article/details/141928040

相关文章

  • Towards Robust Blind Face Restoration with Codebook Lookup Transformer(NeurIPS 2
    TowardsRobustBlindFaceRestorationwithCodebookLookupTransformer(NeurIPS2022)这篇论文试图解决的是盲目面部恢复(blindfacerestoration)问题,这是一个高度不确定的任务,通常需要辅助指导来改善从低质量(LQ)输入到高质量(HQ)输出的映射,或者补充输入中丢失的高质量细节。具体......
  • 【深度学习 transformer】使用pytorch 训练transformer 模型,hugginface 来啦
    HuggingFace是一个致力于开源自然语言处理(NLP)和机器学习项目的社区。它由几个关键组件组成:Transformers:这是一个基于PyTorch的库,提供了各种预训练的NLP模型,如BERT、GPT、RoBERTa、DistilBERT等。它还提供了一个简单易用的API来加载这些模型,并进行微调以适应特定的下游任务......
  • 【HuggingFace Transformers】OpenAIGPTModel源码解析
    OpenAIGPTModel源码解析1.GPT介绍2.OpenAIGPTModel类源码解析说到ChatGPT,大家可能都使用过吧。2022年,ChatGPT的推出引发了广泛的关注和讨论。这款对话生成模型不仅具备了强大的语言理解和生成能力,还能进行非常自然的对话,给用户带来了全新的互动体验。然而,ChatGPT......
  • 关于对“像素”的误解
    你很可能理解错了“像素”被误会是表达者的宿命,却也不必因此就把别人都当无可救药的傻瓜或一概斥为别有用心。——王朔无论是日常生活中的电子产品,还是在某些专业领域,我们都经常会使用像素(Pixel)这个词。但是很可惜,很多时候人们都误解了它。同样说“像素”,在不同语境下,表......
  • LSTM+transformer+稀疏注意力机制(ASSA)时间序列预测(pytorch框架)
    LSTM+transformer+稀疏注意力机制transformer,LSTM,ASSA注意力首发原创!纯个人手打代码,自己研究的创新点,超级新。可以发刊,先发先的,高精度代码。需知:好的创新性模型可以事半功倍。目前太多流水paper,都是旧模型,老师已经审美疲劳,很难发好一点的刊,这种模型很新,让paper审核老师眼......
  • 使用zig语言制作简单博客网站(八)归档页和关于页
    后端代码注册路由//归档文章router.get("/api/article/archive",&articleController.getArchiveArticles);model/article.zig增加以下代码///用于存放归档文章信息pubconstArchiveArticle=struct{id:u32,title:[]constu8,cate_name:......
  • 高创新 | Matlab实现Transformer-GRU-SVM多变量时间序列预测
    高创新|Matlab实现Transformer-GRU-SVM多变量时间序列预测目录高创新|Matlab实现Transformer-GRU-SVM多变量时间序列预测效果一览基本介绍程序设计参考资料效果一览基本介绍1.Matlab实现Transformer-GRU-SVM多变量时间序列预测,Transformer+门控循环单......