
DDPM Face Generation Code


Building on the theory covered in the earlier DDPM introduction, this is a simple DDPM implementation for generating faces. The code follows:
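
As a quick recap, the code below implements the standard DDPM forward process and training objective (Ho et al., 2020):

$$q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)I\right), \qquad \alpha_t = 1 - \beta_t,\quad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s$$

$$L_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\| \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right) \right\|^2\right]$$

The buffers sqrt_alphas_bar and sqrt_one_minus_alphas_bar in diffusion.py below are exactly the coefficients of $x_0$ and $\epsilon$ in the first equation.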

utils.py

import os
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import glob
import cv2


class MyDataset(Dataset):
    def __init__(self, img_path, device):
        super(MyDataset, self).__init__()
        self.device = device  # kept for compatibility; tensors stay on CPU here (see __getitem__)
        self.fnames = glob.glob(os.path.join(img_path, "*.jpg"))
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((32, 32)),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.imread(fname, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
        img = self.transforms(img)
        # Do not move tensors to CUDA here: the DataLoader in train.py uses
        # num_workers=4, and CUDA tensors cannot be created in worker processes.
        # The training loop moves each batch to the device instead.
        return img

    def __len__(self):
        return len(self.fnames)
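
A quick smoke test for the dataset (a sketch; it assumes ../faces/ holds *.jpg face crops, the same path train.py uses):

# Hypothetical check -- adjust img_path to wherever your face images live.
dataset = MyDataset(img_path="../faces/", device="cpu")
print(len(dataset), dataset[0].shape)  # expected: <num_images> torch.Size([3, 32, 32])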

 

model.py

import math
import torch
import torch.nn as nn
from torch.nn import init


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0, "d_model must be even!"
        super(TimeEmbedding, self).__init__()
        # Precompute a fixed sinusoidal table of shape (T, d_model), as in the
        # Transformer positional encoding: sin on even columns, cos on odd ones.
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        X = pos[:, None] * emb[None, :]
        emb = torch.zeros(T, d_model)
        emb[:, 0::2] = torch.sin(X)
        emb[:, 1::2] = torch.cos(X)

        self.time_embedding = nn.Sequential(nn.Embedding.from_pretrained(emb),
                                            nn.Linear(d_model, dim),
                                            Swish(),
                                            nn.Linear(dim, dim))
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.time_embedding(t)
        return emb


class DownSample(nn.Module):
    def __init__(self, in_ch):
        super(DownSample, self).__init__()
        self.down = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.down.weight)
        init.zeros_(self.down.bias)

    def forward(self, x, temb):
        # temb is unused; the argument keeps the interface uniform with ResBlock
        x = self.down(x)
        return x


class UpSample(nn.Module):
    def __init__(self, in_ch):
        super(UpSample, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels=in_ch, out_channels=in_ch, kernel_size=2, stride=2, padding=0)
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.conv.weight)
        init.zeros_(self.conv.bias)

        init.xavier_uniform_(self.up.weight)
        init.zeros_(self.up.bias)

    def forward(self, x, temb):
        x = self.up(x)
        x = self.conv(x)
        return x


class AttnBlock(nn.Module):
    def __init__(self, in_ch):
        super(AttnBlock, self).__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        # near-zero init so the attention block starts close to an identity mapping
        init.xavier_uniform_(self.proj.weight, gain=1e-5)
        init.zeros_(self.proj.bias)

    def forward(self, x):
        N, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        # single-head self-attention over the H*W spatial positions
        q = q.permute(0, 2, 3, 1).view(N, H * W, C)
        k = k.view(N, C, H * W)
        score = q @ k * (C ** -0.5)  # (N, H*W, H*W)
        score = score.softmax(dim=-1)
        v = v.permute(0, 2, 3, 1).view(N, H * W, C)
        h = score @ v                # (N, H*W, C)
        h = h.view(N, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)
        return x + h


# DownBlock = ResBlock + AttnBlock
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, t_dim, dropout, attn=False):
        super(ResBlock, self).__init__()
        self.block1 = nn.Sequential(nn.GroupNorm(32, in_ch),
                                    Swish(),
                                    nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1))

        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(t_dim, out_ch)
        )

        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        )

        if in_ch != out_ch:
            self.short_cut = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=1, padding=0)
        else:
            self.short_cut = nn.Identity()

        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()

        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        h += self.temb_proj(temb)[..., None, None]
        h = self.block2(h)
        h += self.short_cut(x)

        h = self.attn(h)
        return h


class UNet(nn.Module):
    def __init__(self, T, ch, ch_ratio, num_res_block, dropout):
        super(UNet, self).__init__()
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(in_channels=3, out_channels=ch, kernel_size=3, stride=1, padding=1)
        self.down_blocks = nn.ModuleList()
        chs = [ch]
        in_ch = ch
        for i, ratio in enumerate(ch_ratio):
            out_ch = ch * ratio
            for _ in range(num_res_block):
                self.down_blocks.append(ResBlock(in_ch=in_ch, out_ch=out_ch, t_dim=tdim,
                                                 dropout=dropout, attn=True))
                in_ch = out_ch
                chs.append(in_ch)

            if i != len(ch_ratio) - 1:
                self.down_blocks.append(DownSample(in_ch=in_ch))
                chs.append(in_ch)

        self.middle_blocks = nn.ModuleList([ResBlock(in_ch=in_ch, out_ch=in_ch, t_dim=tdim, dropout=dropout, attn=True),
                                            ResBlock(in_ch=in_ch, out_ch=in_ch, t_dim=tdim, dropout=dropout, attn=False)])

        self.up_blocks = nn.ModuleList()

        for i, ratio in reversed(list(enumerate(ch_ratio))):
            out_ch = ch * ratio
            for _ in range(num_res_block+1):
                self.up_blocks.append(ResBlock(in_ch=chs.pop()+in_ch, out_ch=out_ch, t_dim=tdim, dropout=dropout, attn=True))
                in_ch = out_ch

            if i != 0:
                self.up_blocks.append(UpSample(in_ch=in_ch))

        self.tail = nn.Sequential(nn.GroupNorm(32, in_ch),
                                  Swish(),
                                  nn.Conv2d(in_channels=in_ch, out_channels=3, kernel_size=3, stride=1, padding=1))

        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)

        init.xavier_uniform_(self.tail[-1].weight)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        temb = self.time_embedding(t)
        h = self.head(x)
        # down
        hs = [h]
        for layer in self.down_blocks:
            h = layer(h, temb)
            hs.append(h)

        # middle
        for layer in self.middle_blocks:
            h = layer(h, temb)

        # up
        for layer in self.up_blocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)

        h = self.tail(h)
        return h


if __name__ == '__main__':
    batch_size = 8
    net = UNet(T=1000, ch=128, ch_ratio=[1, 2, 2, 2], num_res_block=3, dropout=0.1)
    x = torch.rand(batch_size, 3, 32, 32)  # the head conv expects 3-channel input
    t = torch.randint(1000, (batch_size, ))
    y = net(x, t)
    print(y.shape)  # expected: torch.Size([8, 3, 32, 32])
    torch.save(net.state_dict(), "model_______.pth")
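
A small optional extension of the smoke test above: counting trainable parameters (a sketch; the exact number depends on the hyperparameters passed to UNet):

num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"UNet trainable parameters: {num_params / 1e6:.1f}M")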

 

diffusion.py

import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusion(nn.Module):
    def __init__(self, model, image_size, image_channel, beta_1, beta_T, T):
        super(GaussianDiffusion, self).__init__()
        self.model = model
        self.image_size = image_size
        self.image_channel = image_channel
        self.T = T

        # linear beta schedule, as in Ho et al. (2020)
        betas = torch.linspace(beta_1, beta_T, T).double()
        alphas = 1. - betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        self.register_buffer("betas", betas)
        self.register_buffer("sqrt_alphas_bar", torch.sqrt(alphas_bar))
        self.register_buffer("sqrt_one_minus_alphas_bar", torch.sqrt(1. - alphas_bar))
        # beta_t / sqrt(1 - alpha_bar_t): weight of the predicted noise during sampling
        self.register_buffer("remove_noise_coef", betas / torch.sqrt(1. - alphas_bar))
        self.register_buffer("reciprocal_sqrt_alphas", 1. / torch.sqrt(alphas))
        # sigma_t = sqrt(beta_t): the "fixedlarge" variance choice
        self.register_buffer("sigma", torch.sqrt(betas))

    def forward(self, x_0):
        """
        Algorithm 1 (training): draw a random timestep t and Gaussian noise,
        form x_t in closed form, and regress the predicted noise onto the true noise.
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='mean')
        return loss

    def sample(self, batch_size, device):
        """
        Algorithm 2 (sampling): start from pure Gaussian noise and denoise step by step.
        """
        x = torch.randn(batch_size, self.image_channel, self.image_size, self.image_size, device=device)

        for t in reversed(range(self.T)):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = (x - extract(self.remove_noise_coef, t_batch, x.shape) * self.model(x, t_batch)) * \
                extract(self.reciprocal_sqrt_alphas, t_batch, x.shape)

            if t > 0:  # no noise is added at the final step (t == 0)
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
        return x
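
The update inside sample() is the standard DDPM reverse step (Algorithm 2 in Ho et al., 2020):

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$

with $\sigma_t = \sqrt{\beta_t}$ here (the "fixedlarge" choice): remove_noise_coef and reciprocal_sqrt_alphas store $\beta_t/\sqrt{1-\bar\alpha_t}$ and $1/\sqrt{\alpha_t}$ respectively, and no noise is added at $t = 0$.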

 

train.py

import torch
import argparse
from torch.utils.data import DataLoader
from torch.optim import Adam
from utils import MyDataset
from torchvision.utils import save_image
from tqdm import tqdm
# from unet import UNet
from model import UNet
from diffusion import GaussianDiffusion


def args_parser():
    parser = argparse.ArgumentParser(description="Parameters for training the DDPM model")
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    parser.add_argument("-i", "--in_channels", type=int, default=3)
    parser.add_argument("-d", "--latent_dim", type=int, default=64)
    parser.add_argument("-l", "--lr", type=float, default=1e-4)
    parser.add_argument("-w", "--weight_decay", type=float, default=1e-5)
    parser.add_argument("-e", "--epoch", type=int, default=500)
    parser.add_argument("-v", "--snap_epoch", type=int, default=1)
    parser.add_argument("-n", "--num_samples", type=int, default=64)
    parser.add_argument("-p", "--path", type=str, default="./results_linear")
    parser.add_argument("--T", type=int, default=1000)
    parser.add_argument("--ch", type=int, default=32)
    parser.add_argument("--ch_ratio", type=list, default=[1, 2, 2, 2])
    parser.add_argument("--num_res_block", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)

    parser.add_argument('--beta_1', type=float, default=1e-4, help='start beta value')
    parser.add_argument('--beta_T', type=float, default=0.02, help='end beta value')
    parser.add_argument('--mean_type', type=str, choices=['xprev', 'xstart', 'epsilon'], default='epsilon', help='predict variable')
    parser.add_argument('--var_type', choices=['fixedlarge', 'fixedsmall'], default='fixedlarge', help='variance type')
    parser.add_argument("--image_size", type=int, default=32)
    parser.add_argument("--image_channels", type=int, default=3)

    return parser.parse_args()


if __name__ == '__main__':
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt = args_parser()

    dataset = MyDataset(img_path="../faces/", device=DEVICE)
    # trans = transforms.Compose([transforms.ToTensor(),transforms.Resize((32, 32))])
    # dataset = MNIST("./mnist_data", download=True, transform=trans)
    train_loader = DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
    model = UNet(T=opt.T, ch=opt.ch, ch_ratio=opt.ch_ratio, num_res_block=opt.num_res_block, dropout=opt.dropout)
    # model = UNet(opt.in_channels, opt.latent_dim)

    diffusion = GaussianDiffusion(model, opt.image_size, opt.image_channels, opt.beta_1, opt.beta_T, opt.T).to(DEVICE)
    optimizer = Adam(diffusion.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)

    for epoch in range(opt.epoch):
        diffusion.train()
        # one progress bar per epoch (the original created a new bar every step)
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{opt.epoch}")

        for step, x_0 in enumerate(pbar):
            optimizer.zero_grad()
            loss = diffusion(x_0.to(DEVICE))
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        if epoch % opt.snap_epoch == 0 or epoch == opt.epoch - 1:
            diffusion.eval()
            with torch.no_grad():
                images = diffusion.sample(opt.num_samples, device=DEVICE)
                fname = './my_generated-images-epoch_{0:0=4d}.png'.format(epoch)
                save_image(images, fname, nrow=8)
                torch.save(diffusion.state_dict(), f"./model_step_{epoch}.pth")
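
To sample from a trained model without re-running training, here is a minimal inference sketch (an assumption-laden example: the checkpoint name model_step_499.pth is a placeholder, and the hyperparameters must match those used in train.py):

import torch
from torchvision.utils import save_image
from model import UNet
from diffusion import GaussianDiffusion

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters must match the ones used at training time.
model = UNet(T=1000, ch=32, ch_ratio=[1, 2, 2, 2], num_res_block=2, dropout=0.1)
diffusion = GaussianDiffusion(model, image_size=32, image_channel=3,
                              beta_1=1e-4, beta_T=0.02, T=1000).to(DEVICE)

# "model_step_499.pth" is a placeholder -- use whichever checkpoint train.py saved.
diffusion.load_state_dict(torch.load("./model_step_499.pth", map_location=DEVICE))
diffusion.eval()

with torch.no_grad():
    samples = diffusion.sample(batch_size=16, device=DEVICE)

# Training images were in [0, 1] (ToTensor without Normalize), so clamp before saving.
save_image(samples.clamp(0, 1), "samples.png", nrow=4)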

 

From: https://www.cnblogs.com/xjlearningAI/p/18288550
