首页 > 其他分享 >【前沿模型解析】潜在扩散模型 2-3 | 手撕感知图像压缩 基础块 自注意力块

【前沿模型解析】潜在扩散模型 2-3 | 手撕感知图像压缩 基础块 自注意力块

时间:2024-04-09 20:31:27浏览次数:27  
标签:heads nn 模型 torch channels 图像压缩 感知 self out

1 注意力机制回顾

同ResNet一样,注意力机制应该也是神经网络最重要的一部分了。

想象一下你在观看一场电影,但你的朋友在给你发短信。虽然你正在专心观看电影,但当你听到手机响起时,你会停下来查看短信,然后这时候电影的内容就会被忽略。这就是注意力机制的工作原理。

在处理输入序列时,比如一句话中的每个单词,注意力机制允许模型像你一样,专注于输入中的不同部分。模型可以根据输入的重要性动态地调整自己的注意力,注意自己觉得比较重要的部分,忽略一些不太重要的部分,以便更好地理解和处理序列数据。

具体来说,是通过q,k,v实现的

q(查询),k(键值)之间先进行计算,获得重要性权重w,w再作用于v

利用卷积操作确定q,k,v

q,k做运算得到w,缩放w

w和v做运行

最后残差

得到

2 Atten块的实现

在这里插入图片描述

2.1 初始化函数

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

2.2 前向传递函数

def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
  1. b,c,h,w = q.shape:假设q是一个四维张量,其中b表示batch size,c表示通道数,hw表示高度和宽度。

  2. q = q.reshape(b,c,h*w):将q张量重新形状为三维张量,其中第三维是原高度和宽度的乘积。这样做是为了方便后续计算。

  3. q = q.permute(0,2,1):交换张量维度,将第三维移动到第二维,这是为了后续计算方便。

  4. k = k.reshape(b,c,h*w):对k做和q类似的操作,将其形状改为三维张量。

  5. w_ = torch.bmm(q,k):计算qk的批次矩阵乘积(batch matrix multiplication),得到注意力权重的初始矩阵。这里的w_是一个b x (h*w) x (h*w)的张量,表示每个位置对应的注意力权重。

  6. w_ = w_ * (int(c)**(-0.5)):对初始注意力权重进行缩放,这里使用了一个缩放因子,通常是通道数的倒数的平方根。这个缩放是为了确保在计算注意力时不会因为通道数过大而导致梯度消失或梯度爆炸。

  7. w_ = torch.nn.functional.softmax(w_, dim=2):对注意力权重进行softmax操作,将其归一化为概率分布,表示每个位置的重要性。

这段代码的作用是实现自注意力机制中计算注意力权重的过程,其中qk分别代表查询(query)和键(key),通过计算它们的相似度得到注意力权重。

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_
  1. v = v.reshape(b,c,h*w):将值(value)张量v重新形状为三维张量,其中第三维是原高度和宽度的乘积。这样做是为了方便后续计算。

  2. w_ = w_.permute(0,2,1):交换注意力权重w_张量的维度,将第三维移动到第二维,这是为了后续计算方便。

  3. h_ = torch.bmm(v,w_):计算值v和经过缩放的注意力权重w_的批次矩阵乘积(batch matrix multiplication),得到自注意力的输出。这里的h_是一个b x c x (h*w)的张量,表示每个位置经过注意力计算后的输出。

  4. h_ = h_.reshape(b,c,h,w):将h_张量重新形状为四维张量,恢复其原始的高度和宽度。

  5. h_ = self.proj_out(h_):通过一个全连接层proj_out对自注意力的输出h_进行线性变换和非线性变换,这个操作有助于提取特征并保持网络的表达能力。

最后,将输入x和自注意力的输出h_相加,得到最终的自注意力输出。这样做是为了在保留原始输入信息的同时,加入了经过自注意力计算后的新信息,从而使模型能够更好地理解输入序列的语义信息。

2.3 Atten注意力完整代码

from torch import nn
import torch
from einops import rearrange


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_

def make_attn(in_channels, attn_type="vanilla"):
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    else:
        return nn.Identity(in_channels)
    

atten_block=make_attn(12)
x=torch.ones(4,12,32,32)
y=atten_block(x)
print(y.shape)

3 源代码中的另一种注意力实现

源代码中还实现了LinearAttention,是另一种注意力机制

可以看看

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)  
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)

对于forward函数

  1. b, c, h, w = x.shape:假设输入张量x是一个四维张量,其中b表示batch size,c表示通道数,hw表示高度和宽度。

  2. qkv = self.to_qkv(x):将输入张量x通过一个线性变换(可能包括分别计算查询(query)、键(key)和值(value))得到qkv张量,其形状为b x (3*heads*c) x h x w,其中heads是多头注意力的头数。

  3. q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3):将qkv张量重新排列为三个张量qkv,分别表示查询、键和值,形状为b x heads x c x (h*w)

  4. k = k.softmax(dim=-1):对键张量k进行softmax操作,将其归一化为概率分布,以便计算注意力权重。

  5. context = torch.einsum('bhdn,bhen->bhde', k, v):使用torch.einsum函数计算注意力权重与值的加权和,得到上下文张量context,形状为b x heads x c x (h*w)

  6. out = torch.einsum('bhde,bhdn->bhen', context, q):使用torch.einsum函数计算上下文张量与查询张量的加权和,得到输出张量out,形状为b x heads x c x (h*w)

  7. out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w):将输出张量out重新排列为形状b x (heads*c) x h x w,恢复其原始形状。

  8. return self.to_out(out):将输出张量out通过一个线性变换得到最终的输出。

如果注意力机制type=None的话,则不进行注意力机制的计算~

用一个torch函数

nn.Identity 这是一个恒等变化的一个函数,不做任何处理

4 完整代码及其测试

from torch import nn
import torch
from einops import rearrange

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)  
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_

def make_attn(in_channels, attn_type="vanilla"):
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type=="line":
        return LinAttnBlock(in_channels)
    else:
        return nn.Identity(in_channels)
    

atten_block=make_attn(12)
x=torch.ones(4,12,32,32)
y=atten_block(x)
print(y.shape)

标签:heads,nn,模型,torch,channels,图像压缩,感知,self,out
From: https://blog.csdn.net/Q52099999/article/details/137566236

相关文章

  • 通过termux tailscale huggingface 来手把手一步一步在手机上部署LLAMA2-7b和LLAMA2-7
    前言首先截图一下我的小米手机的配置我这个配置其实一般,当时主要为了存储空间大,而我对配置要求又不高,买的。在手机上安装termuxapp,然后在termux上就可以使用各种命令进行安装和使用,就像在linux操作系统上一样了。再通过termux安装上openssh,这样你就可以在window、mac等电......
  • CSS -层叠性、继承性、盒子模型、盒子模型表格、盒子模型margin、盒子阴影
    层叠性CSS层叠性(Cascading)是指在网页中应用多个样式规则时,根据一定的规则来确定最终应用的样式。层叠性使得样式可以按照一定的优先级和规则进行组合和覆盖,从而实现对元素的样式控制。层叠性的影响因素:选择器的特殊性(Specificity):选择器的特殊性决定了样式规则的优先级。......
  • 54、C++内存模型
    在 C++ 中,程序运行时,内存主要分成四个区,分别是栈、堆、数据段和代码段。                栈:存储局部变量、函数参数和返回值。堆:存储动态开辟内存的变量。数据段:存储全局变量和静态变量。代码段:存储可执行程序的代码和常量(例如字符常量),此存储区不可修......
  • 决策树模型(4)Cart算法
    Cart算法Cart是Classificationandregressiontree的缩写,即分类回归树。它和前面的ID3,C4.5等算法思想一致都是通过对输入空间进行递归划分并确定每个单元上预测的概率分布,进而进行回归和分类任务。只不过由于任务的不同,所以回归树和分类树的划分准则并不相同。Cart生成回归......
  • 运用预训练 Keras 模型来处理图像分类请求,学习如何使用从 Keras 创建 SavedModel
    前置import'''importosimporttempfilefrommatplotlibimportpyplotaspltimportnumpyasnpimporttensorflowastftmpdir=tempfile.mkdtemp()'''介绍如何用keras检测自己找的图片'''file=tf.keras.utils.get_file(&quo......
  • R语言多元Copula GARCH 模型时间序列预测|附代码数据
    原文链接  http://tecdat.cn/?p=2623原文出处:拓端数据部落公众号 最近我们被要求撰写关于CopulaGARCH的研究报告,包括一些图形和统计输出。和宏观经济数据不同,金融市场上多为高频数据,比如股票收益率序列。直观的来说,后者是比前者“波动”更多且随机波动的序列,在一元或多元......
  • HSPF(Hydrological Simulation Program Fortran)模型
    HSPF模型与SWAT模型一样都是著名的水文模型软件,在世界各地的水文模拟中得到广泛的应用。由于种种原因,HSPF模型在国内的影响力不如SWAT;但是,HSPF模型也有其自身的优势,比如:1.它有很高集成度的前后处理软件,减轻建模的负担;2.它可以自主调节水文响应单元的大小,模型有更好的灵活性;3.它......
  • R+VIC模型融合实践技术应用及未来气候变化模型预测
    在气候变化问题日益严重的今天,水文模型在防洪规划,未来预测等方面发挥着不可替代的重要作用。目前,无论是工程实践或是科学研究中都存在很多著名的水文模型如SWAT/HSPF/HEC-HMS等。虽然,这些软件有各自的优点;但是,由于适用的尺度主要的是中小流域,所以在预测气候变化对水文过程影响......
  • 【最新】Claude Pro订阅充值教程,超大杯模型Claude 3 Opus模型体验方法
    一、关于ClaudePro|Claude3OpusClaude3系列包含三个大模型,按能力由弱到强别是:Claude3Haiku(最小/速度最快)Claude3Sonnet(标准/免费使用)Claude3Opus(最强/需要付费订阅)其中,最强的Opus在多项基准测试中得分都超过了GPT-4和Gemini1.0Ultra,......
  • visionOS 专门应用提交数大幅下降;Kimi 不断「吊打」国内各大厂 AI 模型丨 RTE 开发者
       开发者朋友们大家好: 这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(RealTimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代表编辑的个人观......