首页 > 其他分享 >diffusion model 代码

diffusion model 代码

时间:2024-07-08 23:19:25浏览次数:10  
标签:diffusion alphas tensor 代码 torch num model grad fn

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch

# Toy dataset: 10k points on an S-curve; keep only the (x, z) plane and
# shrink by 10x so the coordinates fit the diffusion noise scale used below.
s_curve,_ = make_s_curve(10**4,noise=0.1)
s_curve = s_curve[:,[0,2]]/10.0

print("shape of s:",np.shape(s_curve))

# Transposed (2, 10000) view, only for unpacking into the scatter plot.
data = s_curve.T

fig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');

ax.axis('off')

# Training data as a float32 tensor of shape (10000, 2).
dataset = torch.Tensor(s_curve).float()
shape of s: (10000, 2)

png

2、确定超参数的值

num_steps = 100

# Sigmoid-shaped beta schedule: values rise smoothly from ~1e-5 to ~0.5e-2.
betas = torch.sigmoid(torch.linspace(-6, 6, num_steps)) * (0.5e-2 - 1e-5) + 1e-5

# Quantities derived from the schedule, used by both the forward process
# (q_x) and the reverse sampler (p_sample).
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, dim=0)
# alpha_bar shifted right by one step (alpha_bar_{t-1}), with alpha_bar_0 := 1.
alphas_prod_p = torch.cat([torch.ones(1), alphas_prod[:-1]], dim=0)
alphas_bar_sqrt = alphas_prod.sqrt()
one_minus_alphas_bar_log = (1 - alphas_prod).log()
one_minus_alphas_bar_sqrt = (1 - alphas_prod).sqrt()

# All schedule tensors must share one shape: (num_steps,).
distinct_shapes = {alphas.shape, alphas_prod.shape, alphas_prod_p.shape,
                   alphas_bar_sqrt.shape, one_minus_alphas_bar_log.shape,
                   one_minus_alphas_bar_sqrt.shape}
assert len(distinct_shapes) == 1
print("all the same shape",betas.shape)
all the same shape torch.Size([100])

3、确定扩散过程任意时刻的采样值

# Forward diffusion: sample x[t] directly from x[0] via reparameterization.
def q_x(x_0,t):
    """Return a sample of x[t] given x[0].

    Uses the closed form x_t = sqrt(alpha_bar_t) * x_0
    + sqrt(1 - alpha_bar_t) * eps, with eps ~ N(0, I).
    Reads the module-level schedule tensors alphas_bar_sqrt and
    one_minus_alphas_bar_sqrt.
    """
    eps = torch.randn_like(x_0)
    scale_x0 = alphas_bar_sqrt[t]
    scale_eps = one_minus_alphas_bar_sqrt[t]
    # Noise x[0] in a single step instead of iterating t times.
    return scale_x0 * x_0 + scale_eps * eps
    

4、演示原始数据分布加噪100步后的结果

num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')

# 10000 points, 2 coordinates each.
# Plot the noised data every num_steps//num_shows (= 5) diffusion steps.
for i in range(num_shows):
    j = i//10
    k = i%10
    step = i*num_steps//num_shows
    q_i = q_x(dataset,torch.tensor([step]))  # forward sample at time `step`
    axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')
    axs[j,k].set_axis_off()
    # Raw string: '\m' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on modern Python); the rendered LaTeX is unchanged.
    axs[j,k].set_title(r'$q(\mathbf{x}_{'+str(step)+'})$')

png

5、编写拟合逆扩散过程高斯分布的模型

import torch
import torch.nn as nn

class MLPDiffusion(nn.Module):
    """Noise-prediction MLP conditioned on the diffusion step.

    Input x has shape (batch, 2); t holds one integer step index per
    sample. The output is the predicted noise, same shape as x.
    """

    def __init__(self,n_steps,num_units=128):
        super(MLPDiffusion,self).__init__()

        # Three Linear+ReLU hidden stages followed by a final projection
        # back to 2-D.
        self.linears = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2),
            ]
        )
        # One step-embedding table per hidden stage.
        self.step_embeddings = nn.ModuleList(
            [nn.Embedding(n_steps, num_units) for _ in range(3)]
        )

    def forward(self,x,t):
        # Each hidden stage: linear, add the stage's step embedding, ReLU.
        for stage, embedding in enumerate(self.step_embeddings):
            x = self.linears[2 * stage](x) + embedding(t)
            x = self.linears[2 * stage + 1](x)
        return self.linears[-1](x)

6、编写训练的误差函数

def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):
    """Sample a random timestep per example and compute the denoising MSE loss.

    model: callable (x_t, t) -> predicted noise, same shape as x_0.
    x_0: clean data batch, shape (batch, 2).
    alphas_bar_sqrt / one_minus_alphas_bar_sqrt: schedule tensors, shape (n_steps,).
    Returns a scalar tensor: mean squared error between true and predicted noise.
    """
    batch_size = x_0.shape[0]

    # Antithetic sampling of t: draw half the timesteps and mirror them as
    # (t, n_steps-1-t) to cover the schedule more evenly. Draw
    # ceil(batch/2) and truncate so odd batch sizes also work — the
    # original batch_size//2 construction produced batch_size-1 timesteps
    # for an odd batch and crashed on the broadcast below.
    half = (batch_size + 1) // 2
    t = torch.randint(0,n_steps,size=(half,))
    t = torch.cat([t,n_steps-1-t],dim=0)[:batch_size]
    t = t.unsqueeze(-1)

    # Coefficient of x_0 (sqrt(alpha_bar_t)), shape (batch, 1).
    a = alphas_bar_sqrt[t]

    # Coefficient of the noise (sqrt(1 - alpha_bar_t)).
    aml = one_minus_alphas_bar_sqrt[t]

    # Random noise eps ~ N(0, I).
    e = torch.randn_like(x_0)

    # Closed-form forward sample x_t — the model input.
    x = x_0*a+e*aml

    # Predicted noise at step t.
    output = model(x,t.squeeze(-1))

    # MSE between the true and the predicted noise.
    return (e - output).square().mean()

7、编写逆扩散采样函数(inference)

def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
    """Run the full reverse chain x[T] -> x[T-1] -> ... -> x[0].

    Returns the trajectory [x_T, x_{T-1}, ..., x_0] (n_steps + 1 tensors).
    """
    x_t = torch.randn(shape)  # start from pure Gaussian noise
    trajectory = [x_t]
    # Walk the chain backwards from step n_steps-1 down to 0.
    for step in range(n_steps - 1, -1, -1):
        x_t = p_sample(model, x_t, step, betas, one_minus_alphas_bar_sqrt)
        trajectory.append(x_t)
    return trajectory

def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
    """One reverse (denoising) step: sample x[t-1] given x[t]."""
    t = torch.tensor([t])

    # Weight of the predicted noise inside the posterior mean:
    # beta_t / sqrt(1 - alpha_bar_t).
    eps_coeff = betas[t] / one_minus_alphas_bar_sqrt[t]

    eps_theta = model(x, t)

    # Posterior mean mu_theta = (x - eps_coeff * eps_theta) / sqrt(alpha_t),
    # where alpha_t = 1 - beta_t.
    mean = (1 / (1 - betas[t]).sqrt()) * (x - eps_coeff * eps_theta)

    # Add Gaussian noise scaled by sigma_t = sqrt(beta_t).
    noise = torch.randn_like(x)
    sigma_t = betas[t].sqrt()

    return mean + sigma_t * noise

8、开始训练模型,打印loss及中间重构效果

# NOTE(review): defined but never applied — torch.manual_seed(seed) is
# never called, so runs are not reproducible as written.
seed = 1234

class EMA():
    """Exponential moving average smoother for named tensors/parameters."""

    def __init__(self,mu=0.01):
        self.mu = mu        # weight given to the newest value
        self.shadow = {}    # name -> current smoothed value

    def register(self,name,val):
        # Seed the running average with a copy of the initial value.
        self.shadow[name] = val.clone()

    def __call__(self,name,x):
        # Blend the new value into the stored average and return it.
        assert name in self.shadow
        smoothed = self.mu * x + (1.0 - self.mu) * self.shadow[name]
        self.shadow[name] = smoothed.clone()
        return smoothed
    
# Training script — relies on the module-level num_steps, dataset, betas,
# alphas_bar_sqrt and one_minus_alphas_bar_sqrt defined above.
print('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')

model = MLPDiffusion(num_steps)  # predicts 2-D noise from (x, step)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

for t in range(num_epoch):
    for idx,batch_x in enumerate(dataloader):
        loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)
        optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm to 1.0 to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.)
        optimizer.step()

    # Every 100 epochs: print the last batch loss and plot the reverse
    # trajectory at every 10th step.
    if(t%100==0):
        print(loss)
        x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)

        # NOTE(review): a new figure is opened each time without closing it,
        # which triggers matplotlib's "more than 20 figures" warning;
        # consider plt.close(fig) after rendering if memory matters.
        fig,axs = plt.subplots(1,10,figsize=(28,3))
        for i in range(1,11):
            cur_x = x_seq[i*10].detach()
            axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');
            axs[i-1].set_axis_off();
            # Raw string: fixes the invalid '\m' escape (SyntaxWarning on
            # modern Python); the rendered LaTeX is unchanged.
            axs[i-1].set_title(r'$q(\mathbf{x}_{'+str(i*10)+'})$')
Training model...
tensor(0.5574, grad_fn=<MeanBackward0>)
tensor(0.3723, grad_fn=<MeanBackward0>)
tensor(0.2658, grad_fn=<MeanBackward0>)
tensor(0.2712, grad_fn=<MeanBackward0>)
tensor(0.3424, grad_fn=<MeanBackward0>)
tensor(0.1650, grad_fn=<MeanBackward0>)
tensor(0.1895, grad_fn=<MeanBackward0>)
tensor(0.5016, grad_fn=<MeanBackward0>)
tensor(0.2864, grad_fn=<MeanBackward0>)
tensor(0.3404, grad_fn=<MeanBackward0>)
tensor(0.2952, grad_fn=<MeanBackward0>)
tensor(0.2890, grad_fn=<MeanBackward0>)
tensor(0.4553, grad_fn=<MeanBackward0>)
tensor(0.4187, grad_fn=<MeanBackward0>)
tensor(0.3025, grad_fn=<MeanBackward0>)
tensor(0.2109, grad_fn=<MeanBackward0>)
tensor(0.2516, grad_fn=<MeanBackward0>)
tensor(0.3910, grad_fn=<MeanBackward0>)
tensor(0.1707, grad_fn=<MeanBackward0>)
tensor(0.2568, grad_fn=<MeanBackward0>)
tensor(0.4519, grad_fn=<MeanBackward0>)


C:\Users\ASUS\anaconda3\lib\site-packages\ipykernel_launcher.py:39: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).


tensor(0.2200, grad_fn=<MeanBackward0>)
tensor(0.3391, grad_fn=<MeanBackward0>)
tensor(0.3316, grad_fn=<MeanBackward0>)
tensor(0.3322, grad_fn=<MeanBackward0>)
tensor(0.4364, grad_fn=<MeanBackward0>)
tensor(0.1252, grad_fn=<MeanBackward0>)
tensor(0.3090, grad_fn=<MeanBackward0>)
tensor(0.2585, grad_fn=<MeanBackward0>)
tensor(0.4264, grad_fn=<MeanBackward0>)
tensor(0.3773, grad_fn=<MeanBackward0>)
tensor(0.2524, grad_fn=<MeanBackward0>)
tensor(0.5144, grad_fn=<MeanBackward0>)
tensor(0.3425, grad_fn=<MeanBackward0>)
tensor(0.6134, grad_fn=<MeanBackward0>)
tensor(0.3546, grad_fn=<MeanBackward0>)
tensor(0.1943, grad_fn=<MeanBackward0>)
tensor(0.4279, grad_fn=<MeanBackward0>)
tensor(0.4014, grad_fn=<MeanBackward0>)
tensor(0.2760, grad_fn=<MeanBackward0>)

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

9、动画演示扩散过程和逆扩散过程

import io
from PIL import Image

# Render one frame per forward-diffusion step and collect them as PIL images.
imgs = []
for step in range(100):
    plt.clf()
    noisy = q_x(dataset, torch.tensor([step]))
    plt.scatter(noisy[:, 0], noisy[:, 1], color='red', edgecolor='white', s=5)
    plt.axis('off')

    # Rasterize the current figure into an in-memory PNG.
    frame_buf = io.BytesIO()
    plt.savefig(frame_buf, format='png')
    imgs.append(Image.open(frame_buf))

png

# Render one frame per reverse-diffusion step from the sampled trajectory.
reverse = []
for step in range(100):
    plt.clf()
    frame = x_seq[step].detach()
    plt.scatter(frame[:, 0], frame[:, 1], color='red', edgecolor='white', s=5)
    plt.axis('off')

    # Rasterize the current figure into an in-memory PNG.
    frame_buf = io.BytesIO()
    plt.savefig(frame_buf, format='png')
    reverse.append(Image.open(frame_buf))

png

# Concatenate forward and reverse frames into one animation.
imgs = imgs + reverse
# append_images must exclude frame 0: imgs[0] is already the base frame of
# the save() call, so passing the full list duplicated the first frame.
imgs[0].save("diffusion.gif", format='GIF', append_images=imgs[1:],
             save_all=True, duration=100, loop=0)



标签:diffusion,alphas,tensor,代码,torch,num,model,grad,fn
From: https://www.cnblogs.com/zhangxianrong/p/18290865

相关文章

  • 最简 INA226 写寄存器的代码
    #include"hardware/i2c.h"#include"pico/binary_info.h"#defineI2C_SDA16#defineI2C_SCL17voidsetup(){//putyoursetupcodehere,torunonce:Serial.begin(115200);i2c_init(i2c_default,100*1000);gpio_set_functio......
  • 代码随想录(day1)二分法
    if语句的基本语法if要判断的条件:条件成立的时候,要做的事举例:ifnums[middle]<target:left=middle+1while语句的基本语法:while判断条件(condition):'''执行语句(statements)'''举例:whileleft<=right:middle=left+(right-left)//2题目:代码:class......
  • 代码随想录刷题day 6 | 哈希表理论基础 242.有效的字母异位词 349. 两个数组的交
    242.有效的字母异位词383.赎金信classSolution{//这里只给出了242的代码,赎金信的解法可以说是基本相同的publicbooleanisAnagram(Strings,Stringt){int[]map=newint[26];for(charc:s.toCharArray())map[c-'a']++;for(char......
  • C++基础入门语法--代码基础框架
    文章内容概括:了解学习导入头文件、使用usingnamespacestd简化代码、创建程序基础框架、学习使用return(如需要直接复制请到文章最末尾)正文:1.学习导入头文件:    在Dev-C++编辑器中新建文件,在文件的第一行中输入:#include<iostream>    以上代码为C++导入......
  • python 自动化神器 多平台纯代码RPA办公自动化python框架
    ​ Pyaibote是一款专注于纯代码RPA(机器人流程自动化)的强大工具,支持Android、Browser和Windows三大主流平台。无论您需要自动化安卓应用、浏览器操作还是Windows应用程序,Pyaibote都能轻松应对Pyaibote可以同时协作Windows、Web和Android平台机器人,满足您多样化的办公自动化需求......
  • 代码随想录算法训练营第26天 | 455.分发饼干 53. 最大子序和
    455.分发饼干假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。对每个孩子i,都有一个胃口值g[i],这是能让孩子们满足胃口的饼干的最小尺寸;并且每块饼干j,都有一个尺寸s[j]。如果s[j]>=g[i],我们可以将这个饼干j分配给孩子i,这个孩子......
  • IPython的宏功能:批量执行代码块功能
    IPython的宏功能:批量执行代码块功能项目概述本项目旨在利用IPython的宏功能,通过批量执行代码块来简化和自动化常见的重复任务。IPython提供了记录和执行宏的功能,可以极大地提高开发效率。我们将创建一个示例项目,展示如何使用IPython宏功能批量执行代码块。项目结构ipyth......
  • 快手开源中英双语文本生成图像模型Kolors;漫画翻译工具Comic Translate;支持谷歌搜索、
    ✨1:KolorsKolors是基于潜在扩散的大规模中英双语文本生成图像模型。Kolors是由快手的Kolors团队开发的一种基于潜在扩散的文本到图像生成模型。它经过了数十亿对文本和图像数据的训练,在视觉质量、复杂语义准确性以及中文和英文文本渲染方面都表现出显著的优势。Kolo......
  • Asun安全学习【漏洞复现】CVE-2023-38831 WinRAR代码执行漏洞
    (ps:本人是小白,复现漏洞来进行安全学习,也借鉴了许多大佬的研究内容,感谢各位大佬进行指导和点评。)[漏洞名称]:CVE漏洞复现-CVE-2023-38831WinRAR代码执行漏洞[漏洞描述]:WinRAR是一款功能强大的Windows文件压缩和解压缩工具,支持高效的压缩算法、密码保护、分卷压缩、恢复记录等......
  • git合并代码方法
    你合并代码用merge还是用rebase?macrozheng 2024年07月08日14:10 江苏 1人听过 以下文章来源于古时的风筝 ,作者风筝古时的风筝.写代码是一种爱好,写文章是一种情怀。mall学习教程官网:macrozheng.com你们平时合并代码的时候用merge还是rebase?我问了......