A PyTorch diffusion UNet with sinusoidal noise-level conditioning, self-attention, and a frequency compensation block (FCB)

import math
import torch
from torch import nn
import torch.nn.functional as F
from inspect import isfunction
from kornia.filters import gaussian_blur2d


def exists(x):
    return x is not None


# Return `val` if it is set, otherwise the default `d` (called if callable).
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# Sinusoidal embedding of the continuous noise level: half the channels are
# sine, half cosine, over log-spaced frequencies (as in transformer
# positional encodings).
class PositionalEncoding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        count = self.dim // 2
        step = torch.arange(count, dtype=noise_level.dtype,
                            device=noise_level.device) / count
        # (batch, 1) * (1, count) -> (batch, count) phase matrix
        encoding = noise_level.unsqueeze(
            1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
        encoding = torch.cat(
            [torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding
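

# The demo below is a minimal, hypothetical shape check for PositionalEncoding,
# not part of the original listing; dim=32 and the batch size of 4 are
# arbitrary illustration values.
def _demo_positional_encoding():
    enc = PositionalEncoding(dim=32)
    noise_level = torch.rand(4)        # one scalar noise level per sample
    out = enc(noise_level)
    assert out.shape == (4, 32)        # sin and cos halves concatenated
    return out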


# Injects the noise embedding into a feature map, either additively or as a
# FiLM-style affine transform (per-channel gamma and beta).
class FeatureWiseAffine(nn.Module):

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super(FeatureWiseAffine, self).__init__()
        self.use_affine_level = use_affine_level
        # With use_affine_level=True the linear layer emits twice the output
        # channels, split into gamma and beta in forward().
        self.noise_func = nn.Sequential(
            nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))
        )

    def forward(self, x, noise_embed):
        batch = x.shape[0]
        if self.use_affine_level:
            gamma, beta = self.noise_func(noise_embed).view(
                batch, -1, 1, 1).chunk(2, dim=1)
            x = (1 + gamma) * x + beta
        else:
            x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
        return x
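

# A hypothetical usage sketch for FeatureWiseAffine (illustration only): a
# 64-channel feature map conditioned on a 32-dim noise embedding, in both the
# additive and the FiLM-style affine mode.
def _demo_feature_wise_affine():
    x = torch.randn(2, 64, 16, 16)
    noise_embed = torch.randn(2, 32)
    additive = FeatureWiseAffine(32, 64, use_affine_level=False)
    affine = FeatureWiseAffine(32, 64, use_affine_level=True)
    assert additive(x, noise_embed).shape == x.shape  # x + shift
    assert affine(x, noise_embed).shape == x.shape    # (1 + gamma) * x + beta
    return affine(x, noise_embed)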


class Swish(nn.Module):
    # x * sigmoid(x); equivalent to nn.SiLU.
    def forward(self, x):
        return x * torch.sigmoid(x)


class Upsample(nn.Module):
    # Nearest-neighbor 2x upsampling followed by a 3x3 convolution.
    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))


class Downsample(nn.Module):
    # Strided 3x3 convolution that halves the spatial resolution.
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(groups, dim),
            Swish(),
            nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
            nn.Conv2d(dim, dim_out, 3, padding=1)
        )

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        self.noise_func = FeatureWiseAffine(
            noise_level_emb_dim, dim_out, use_affine_level)

        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        h = self.block1(x)
        h = self.noise_func(h, time_emb)  # inject the noise-level embedding
        h = self.block2(h)
        return h + self.res_conv(x)


class SelfAttention(nn.Module):
    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()

        self.n_head = n_head

        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        batch, channel, height, width = input.shape
        n_head = self.n_head
        head_dim = channel // n_head

        norm = self.norm(input)
        qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
        query, key, value = qkv.chunk(3, dim=2)  # each (batch, n_head, head_dim, h, w)

        # Scaled dot-product attention between every pair of spatial positions;
        # the attention map has shape (batch, n_head, h, w, h, w), so memory
        # grows as (h * w) ** 2 and this is only used at low resolutions.
        attn = torch.einsum(
            "bnchw, bncyx -> bnhwyx", query, key
        ).contiguous() / math.sqrt(channel)
        attn = attn.view(batch, n_head, height, width, -1)
        attn = torch.softmax(attn, -1)
        attn = attn.view(batch, n_head, height, width, height, width)

        out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
        out = self.out(out.view(batch, channel, height, width))

        return out + input
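

# A hypothetical sanity check for SelfAttention (illustration only). The
# 8x8 input mirrors the attn_res setting used by UNet below; larger maps
# would make the (h*w) x (h*w) attention matrix expensive.
def _demo_self_attention():
    attn = SelfAttention(in_channel=64, n_head=1, norm_groups=32)
    x = torch.randn(1, 64, 8, 8)
    assert attn(x).shape == x.shape    # residual connection preserves shape
    return attn(x)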


class ResnetBlocWithAttn(nn.Module):
    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb):
        x = self.res_block(x, time_emb)
        if self.with_attn:
            x = self.attn(x)
        return x


# Frequency compensation block: decomposes the input into three
# difference-of-Gaussians band-pass residuals plus the input itself, then
# recombines them with four learnable scalar weights.
class FCB(nn.Module):
    def __init__(self, channel, kernel_size=3):
        # `channel` is accepted for interface parity with the call sites but
        # is not used.
        super().__init__()
        self.ks = kernel_size
        self.sigma_rate = 1

        # Learnable mixing weights for [R1, R2, R3, x].
        self.params = nn.Parameter(torch.ones((4, 1)))

    def forward(self, x):
        # Gaussian pyramid with increasing kernel size and sigma.
        x1 = gaussian_blur2d(x, (self.ks, self.ks), (1 * self.sigma_rate, 1 * self.sigma_rate))
        R1 = x - x1  # finest band-pass residual

        x2 = gaussian_blur2d(x, (self.ks * 2 - 1, self.ks * 2 - 1), (2 * self.sigma_rate, 2 * self.sigma_rate))
        x3 = gaussian_blur2d(x, (self.ks * 4 - 1, self.ks * 4 - 1), (4 * self.sigma_rate, 4 * self.sigma_rate))
        R2 = x1 - x2
        R3 = x2 - x3

        # Stack the residuals and the input along a new trailing axis,
        # (b, c, h, w, 4), then mix with the learned (4, 1) weight vector.
        R_cat = torch.stack([R1, R2, R3, x], dim=-1)

        return torch.matmul(R_cat, self.params).squeeze(dim=-1)
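

# A hypothetical sketch of what FCB computes (illustration only). With the
# initial all-ones weights the output is R1 + R2 + R3 + x = x + (x - x3): the
# input plus its high-frequency detail relative to the coarsest blur.
def _demo_fcb():
    fcb = FCB(channel=64)              # `channel` is unused by FCB itself
    x = torch.randn(1, 64, 32, 32)
    out = fcb(x)
    assert out.shape == x.shape
    return out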


class UNet(nn.Module):
    def __init__(
            self,
            in_channel=6,
            out_channel=3,
            inner_channel=32,
            norm_groups=32,
            channel_mults=(1, 2, 4, 8, 8),
            attn_res=(8,),
            res_blocks=3,
            dropout=0,
            with_noise_level_emb=True,
            image_size=128,
            fcb=True
    ):
        super().__init__()

        self.fcb = fcb

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel)
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        num_mults = len(channel_mults)
        pre_channel = inner_channel
        feat_channels = [pre_channel]
        now_res = image_size
        downs = [nn.Conv2d(in_channel, inner_channel,
                           kernel_size=3, padding=1)]
        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlocWithAttn(
                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                    dropout=dropout, with_attn=use_attn))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups,
                               dropout=dropout, with_attn=True),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups,
                               dropout=dropout, with_attn=False)
        ])

        ups = []
        fbs = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks + 1):
                ups.append(ResnetBlocWithAttn(
                    pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel,
                    norm_groups=norm_groups,
                    dropout=dropout, with_attn=use_attn))
                pre_channel = channel_mult
                # nn.ModuleList may only hold Modules, so use nn.Identity()
                # as a placeholder when FCB is disabled.
                fbs.append(FCB(pre_channel) if self.fcb else nn.Identity())
            if not is_last:
                ups.append(Upsample(pre_channel))
                fbs.append(FCB(pre_channel) if self.fcb else nn.Identity())
                now_res = now_res * 2

        self.ups = nn.ModuleList(ups)
        self.fbs = nn.ModuleList(fbs)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)

    def forward(self, x, time):
        # `time` is the continuous noise level; it is embedded once and passed
        # to every ResnetBlocWithAttn.
        t = self.noise_level_mlp(time) if exists(
            self.noise_level_mlp) else None

        feats = []
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)
            feats.append(x)

        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)

        for layer, fb in zip(self.ups, self.fbs):
            if isinstance(layer, ResnetBlocWithAttn):
                tmp = feats.pop()
                if self.fcb:
                    tmp = fb(tmp)
                x = layer(torch.cat((x, tmp), dim=1), t)
            else:
                x = layer(x)

        return self.final_conv(x)
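

# A hypothetical end-to-end smoke test, not part of the original post. The
# default in_channel=6 suggests a conditioning image concatenated with the
# noisy image along the channel axis (as in SR3-style super-resolution); that
# interpretation is an assumption.
if __name__ == "__main__":
    model = UNet(image_size=128)
    x = torch.randn(1, 6, 128, 128)    # e.g. cat([condition, noisy], dim=1)
    noise_level = torch.rand(1)        # continuous noise level per sample
    out = model(x, noise_level)
    print(out.shape)                   # expected: torch.Size([1, 3, 128, 128])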

 
