首页 > 其他分享 >2023CVPR_Spatial-Frequency Mutual Learning for Face Super-Resolution

2023CVPR_Spatial-Frequency Mutual Learning for Face Super-Resolution

时间:2023-11-06 16:13:18浏览次数:58  
标签:dim Mutual nn self Face channels fuse Resolution out

一. Network:SFMNet

1.网络采用U-Net结构,其中SFMLM-i是不同分辨率的每层结构

2.SPB是空域分支,FRB是频域分支,分别经过FRB和SPB的两个分支信息经过FSIB分支进行信息的融合

3. FRB结构:

class FreBlock9(nn.Module):
    def __init__(self, channels, args):
        super(FreBlock9, self).__init__()

        self.fpre = nn.Conv2d(channels, channels, 1, 1, 0)
        self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True),
                                      nn.Conv2d(channels, channels, 3, 1, 1))
        self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True),
                                      nn.Conv2d(channels, channels, 3, 1, 1))
        self.post = nn.Conv2d(channels, channels, 1, 1, 0)


    def forward(self, x):
        # print("x: ", x.shape)
        _, _, H, W = x.shape
        msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward')

        msF_amp = torch.abs(msF)
        msF_pha = torch.angle(msF)
        # print("msf_amp: ", msF_amp.shape)
        amp_fuse = self.amp_fuse(msF_amp)
        # print(amp_fuse.shape, msF_amp.shape)
        amp_fuse = amp_fuse + msF_amp
        pha_fuse = self.pha_fuse(msF_pha)
        pha_fuse = pha_fuse + msF_pha

        real = amp_fuse * torch.cos(pha_fuse)+1e-8
        imag = amp_fuse * torch.sin(pha_fuse)+1e-8
        out = torch.complex(real, imag)+1e-8
        out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward'))
        out = self.post(out)
        out = out + x
        out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5)
        # print("out: ", out.shape)
        return out
FreBlock

4. FSIB结构:

class Attention(nn.Module):
    def __init__(self, dim=64, num_heads=8, bias=False):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
        self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        b, c, h, w = x.shape

        kv = self.kv_dwconv(self.kv(y))
        k, v = kv.chunk(2, dim=1)
        q = self.q_dwconv(self.q(x))

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

class FuseBlock7(nn.Module):
    def __init__(self, channels):
        super(FuseBlock7, self).__init__()
        self.fre = nn.Conv2d(channels, channels, 3, 1, 1)
        self.spa = nn.Conv2d(channels, channels, 3, 1, 1)
        self.fre_att = Attention(dim=channels)
        self.spa_att = Attention(dim=channels)
        self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid())


    def forward(self, spa, fre):
        ori = spa
        fre = self.fre(fre)
        spa = self.spa(spa)
        fre = self.fre_att(fre, spa)+fre
        spa = self.fre_att(spa, fre)+spa
        fuse = self.fuse(torch.cat((fre, spa), 1))
        fre_a, spa_a = fuse.chunk(2, dim=1)
        spa = spa_a * spa
        fre = fre * fre_a
        res = fre + spa

        res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5)
        return res
Fuse Block

 


二. 表达:

1. For PSNR-oriented model, both pixel-level and frequency-level loss functions are adopted to guide the learning of the network.

 

标签:dim,Mutual,nn,self,Face,channels,fuse,Resolution,out
From: https://www.cnblogs.com/yyhappy/p/17812974.html

相关文章

  • 互信息(Mutual Information)的介绍
    互信息指的是两个随机变量之间的关联程度,即给定一个随机变量后,另一个随机变量不确定性的削弱程度,因而互信息取值最小为0,意味着给定一个随机变量对确定一另一个随机变量没有关系,最大取值为随机变量的熵,意味着给定一个随机变量,能完全消除另一个随机变量的不确定性。 互信息(MutualI......
  • face-api基于tensorflow 的人像检测npm 包
    face-api基于tensorflow的人像检测npm包,原始项目为justadudewhohacks/face-api.js但是因为缺少维护,社区有人自己fork了一个新的vladmandic/face-api,可以更好的支持tensorflow新版本,当然很不错还可以支持基于wasm的backend(@tensorflow/tfjs-backend-wasm)参考使用demo.......
  • [win]Surface book2 添加自定义分辨率
    surfacebook213.5英寸 是3:2的屏幕,因为默认分辨率3000*2000实在是太高了,看字的时候眼睛有点吃不消 即使开启windows的自定义缩放也有点难受,同时在个性化里面,内置的分辨率居然没有3:2的了...加上windows的文字渲染机制,在高分辨率下开启cleartype后汉字开始有虚边了,所以决定......
  • MITK编译错误C2220 mitkLabelSetImageToSurfaceFilter.cpp
    错误 C2220 以下警告被视为错误(编译源文件E:\0_MITK\MITK\Modules\Multilabel\mitkLabelSetImageToSurfaceFilter.cpp)[E:\0_MITK\MITK\SuperBuild\MITK-build\Modules\Multilabel\MitkMultilabel.vcxproj] MITK-build E:\0_MITK\MITK\SuperBuild\ep\include\ITK-5.2\i......
  • interface
    2023.10.291.接口中的成员变量默认是publicstaticfinal修饰的2.成员变量不可用private、default、protected修饰3.因为不能属于对象实例的定义方法体,所以不可能有成员变量的getter、setter方法,可见,成员变量属于类(static)4.接口中可以实现default方法......
  • HuggingFace机器视觉学习
    HuggingFace中计算机视觉的现状:https://huggingface.co/blog/zh/cv_state从0开始timm库的quickstarthttps://huggingface.co/docs/timm/quickstart例子中通过调用模型mobilenetv3_large_100识别图像mobilenetv3_large_100模型的说明页https://huggingface.co/timm/mobi......
  • 2023ACMMM_Mutual Information-driven Triple Interaction Network for Efficient Ima
    一.Motivation之前网络存在的缺点:1.使用的有限的频域信息 2. 不充足的信息交互:(1)第一阶段的输出直接作为第二阶段的输入,忽略了中间特征从早期到后期的传播(2)在编码器解码器结构同尺度之间进行特征融合,忽略了阶段内和跨阶段的跨尺度信息交换3. 严重的特征......
  • # 由于我只能访问hugginface网站,但是不能下载里面的数据,所以编写下面的代码,获取从hugg
    #由于我只能访问hugginface网站,但是不能下载里面的数据,所以编写下面的代码,获取从huggingface下载数据的链接。在从其它路径下载数据。#获取huggingface某个模型所有要下载数据的命令行。#可以把结果复制到autodl里,进行执行。速度可以达到13M/s#然后在autodl里进行训练推理......
  • Jlink V8 Interface Description
     JTAGInterfaceConnection(20pin) J-LinkandJ-TracehaveaJTAGconnectorcompatibletoARM'sMulti-ICE.TheJTAGconnectorisa20wayInsulationDisplacementConnector(IDC)keyedboxheader(2.54mmmale)thatmateswithIDCsocketsmou......
  • unity 使用interface 判断 null错误的问题
     在使用Interface,并且由Monobehaviour继承Interface情况下,判断interface的实际UnityEngine.Object是否null,出现错误,没有成功的判断出已经Destroy https://gamedev.stackexchange.com/questions/128971/unity-c-interface-object-never-equals-null解决方案:https://discuss......