首页 > 其他分享 >stable diffusion中的UNet2DConditionModel代码解读

stable diffusion中的UNet2DConditionModel代码解读

时间:2024-07-28 12:25:55浏览次数:18  
标签:diffusion self UNet2DConditionModel states Transformer2DModel ResnetBlock2D stab

UNet2DConditionModel总体结构
在这里插入图片描述
图片来自于 https://zhuanlan.zhihu.com/p/635204519

stable diffusion 运行unet部分的代码。

noise_pred = self.unet(
    sample=latent_model_input,  #(2,4,64,64) 生成的latent
    timestep=t,  #时刻t
    encoder_hidden_states=prompt_embeds, #(2,77,768) #输入的prompt和negative prompt 生成的embedding
    timestep_cond=timestep_cond,#默认空
    cross_attention_kwargs=self.cross_attention_kwargs, #默认空
    added_cond_kwargs=added_cond_kwargs, #默认空
    return_dict=False,
)[0]

1.time

get_time_embed使用了sinusoidal timestep embeddings,
time_embedding 使用了两个线性层和激活层进行映射,将320转换到1280。
如果还有class_labels,added_cond_kwargs等参数,也转换为embedding,并且相加。

t_emb = self.get_time_embed(sample=sample, timestep=timestep)  #(2,320)
emb = self.time_embedding(t_emb, timestep_cond)  #(2,1280)

2.pre-process

卷积转换,输入latent从(2,4,64,64) 到(2,320,64,64)

sample = self.conv_in(sample)  #(2,320,64,64)
self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

3.down

down_block 由三个CrossAttnDownBlock2D和一个DownBlock2D组成。输入包括:
hidden_states:latent
temb:时刻t的embdedding
encoder_hidden_states:prompt和negative prompt的embedding

网络结构

CrossAttnDownBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    Downsample2D()  #(2,320,32,32)
)
CrossAttnDownBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    Downsample2D()  #(2,640,16,16)
)
CrossAttnDownBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    Downsample2D()  #(2,1280,8,8)
)
DownBlock2D(
    ResnetBlock2D()
    ResnetBlock2D()  #(2,1280,8,8)
)

4.mid

UNetMidBlock2DCrossAttn 包含 resnet,attn,resnet三个模块,输入输出维度不变。输入包括:
hidden_states:latent
temb,时刻t的embdedding
encoder_hidden_states:prompt和negative prompt的embedding

UNetMidBlock2DCrossAttn(
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
 )

5.up

up由一个UpBlock2D和三个CrossAttnUpBlock2D组成,输入包括:
hidden_states:latent
temb: 时刻t的embdedding
encoder_hidden_states:prompt和negative prompt的embedding
res_hidden_states_tupleL:下采样时记录的残差结果。

UpBlock2D(
    ResnetBlock2D()
    ResnetBlock2D()
    ResnetBlock2D()
    Upsample2D()  #(2,1280,16,16)
)
CrossAttnUpBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    ResnetBlock2D()
    Transformer2DModel() 
    Downsample2D()  #(2,1280,32,32)
)
CrossAttnUpBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    ResnetBlock2D()
    Transformer2DModel() 
    Downsample2D()  #(2,640,64,64)
)
CrossAttnUpBlock2D( 
    ResnetBlock2D() #(2,320,64,64)
    Transformer2DModel() 
    ResnetBlock2D()
    Transformer2DModel()   
    ResnetBlock2D()
    Transformer2DModel() 
)  

6.post-process

卷积变换通道数,得到最终结果

 if self.conv_norm_out:
     sample = self.conv_norm_out(sample)
     sample = self.conv_act(sample)
 sample = self.conv_out(sample) #(2,4,64,64)

时刻t,类别class等参数作用在resnet部分,都是和输入直接相加。
由prompt,negative prompt 计算得到的encoder_hidden_states,作用在attention部分,作为key和value,参与计算。

ResnetBlock2D

x在标准化、激活、卷积之后,和temb相加,再次标准化、激活、卷积之后作为残差,与x相加。

hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states) #激活函数
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
    if not self.skip_time_act:
        temb = self.nonlinearity(temb)
    temb = self.time_emb_proj(temb)[:, :, None, None] #(2,320,1,1)

if self.time_embedding_norm == "default":
    if temb is not None:
        hidden_states = hidden_states + temb  #与temb相加
    hidden_states = self.norm2(hidden_states)             
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

return output_tensor

Transformer2DModel attentions部分

每个attention 包括 Self-Attention 和Cross-Attention两部分。

#Self-Attention ,encoder_hidden_states=None
attn_output = self.attn1(
    norm_hidden_states,
    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
    attention_mask=attention_mask,
    **cross_attention_kwargs,
       )

#Cross-Attention,encoder_hidden_states由prompt计算得来,在这里和latent交互。
attn_output = self.attn2(
    norm_hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=encoder_attention_mask,
    **cross_attention_kwargs,
)

#query由norm_hidden_states计算而来,
#key、value由encoder_hidden_states计算而来。
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  #(2,8,4096,40)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  #(2,8,77,40)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) #(2,8,77,40)
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False  #(2,8,4096,40)
)

参考:stable diffusion 中使用的 UNet 2D Condition Model 结构解析(diffusers库)

标签:diffusion,self,UNet2DConditionModel,states,Transformer2DModel,ResnetBlock2D,stab
From: https://blog.csdn.net/zhilaizhiwang/article/details/140683649

相关文章

  • Stable Diffusion 改变光线的能力简直太强大了!
    在没有StableDiffusion的年代,对照片的光线进行后期处理,基本要依靠Photoshop。比如添加一个曝光图层。这个技术对于形状简单的物体来说很方便,因为光线效果很好模拟。但对于形状复杂的主体,比如人来说,要想实现自然的光线效果,你最好得有美术功底,并配备一个数位板。Stable......
  • 万字长文,带你从0-1入门Stable Diffusion
    一、本地部署StableDiffusion前言目前市面上比较权威,并能用于工作中的AI绘画软件其实就两款。一个叫Midjourney(简称MJ),另一个叫Stable-Diffusion(简称SD)。MJ需要付费使用,而SD开源免费,但是上手难度和学习成本略大,并且非常吃电脑配置(显卡、内存)。和Midjourney相比,StableD......
  • Diffusion|DDPM 理解、数学、代码
    Diffusion论文:DenoisingDiffusionProbabilisticModels参考博客openinnewwindow;参考paddle版本代码:aistudio实践链接openinnewwindow该文章主要对DDPM论文中的公式进行小白推导,并根据ppdiffuser进行DDPM探索。读者能够对论文中的大部分公式如何得来,用在了什么......
  • Diffusion|DDIM 理解、数学、代码
    DIFFUSION系列笔记|DDIM数学、思考与ppdiffuser代码探索论文:DENOISINGDIFFUSIONIMPLICITMODELS参考博客openinnewwindow;参考aistudionotebook链接,其中包含详细的公式与代码探索:linkopeninnewwindow该文章主要对DDIM论文中的公式进行小白推导,同时笔者将使用......
  • Stable Diffusion整合包安装教程你值得拥有!!!(附安装包)
    Stabledifusion是一个开源的模型,开源=公开=免费,意味着你可以把这个模型下载到你自己的电脑上或者服务器上面畅玩没有审核人员卡你图片是否有问题,随意出图。01、电脑配置相关知识我们先来看看安装StableDiffusion整合包的需要的电脑配置:电脑配置需求:操作系统:windows......
  • Stable Diffusion(AI绘画)软件安装包下载及安装教程!
    软件介绍StableDiffusion简称(SD)是一款开源的AI绘画软件,基于LatentDiffusionModel(文转图合成技术),能够根据文本描述或图像提示生成生成高质量、高分辨率、高逼真的图像。StableDiffusion由于开源属性,有很多免费高质量的外接预训练模型(fine-tune)和插件。软件:StableDiffu......
  • 昇思25天学习打卡营第22天|Diffusion扩散模型
    ☀️第22天学习应用实践/生成式/Diffusion扩散模型1.DiffusionModel简介如果将Diffusion与其他生成模型(如NormalizingFlows、GAN或VAE)进行比较,它并没有那么复杂,它们都将噪声从一些简单分布转换为数据样本,Diffusion也是从纯噪声开始通过一个神经网络学习逐步去噪,最......
  • 刚刚!Stable diffusion 4.8+ComfyUI升级版终于来了!(一键安装包,感谢大佬)
    如果这个世界有上帝,那么祂一定是程序员。国内SD绘画启动器第一人是我认为是B站的秋葉aaaki因为制作了这款StableDiffusion启动器,降低了国内使用SD的门槛且分文不收,秋叶被粉丝戏称赛博菩萨。1背景信息▍****StableDiffusion是什么?StableDiffusion(简称SD)是一种生......
  • stable diffusion文生图代码解读
    使用diffusers运行stablediffusion,文生图过程代码解读。只按照下面这种最简单的运行代码,省略了一些参数的处理步骤。fromdiffusersimportDiffusionPipelinepipeline=DiffusionPipeline.from_pretrained(MODEL_PATH,torch_dtype=torch.float16)pipeline.to("cuda......
  • Enhancing Diffusion Models with Reinforcement Learning
    EnhancingDiffusionModelswithReinforcementLearningSep27,2023 | UncategorizedTL;DRTodaywe'regoingtotellyouallabout DRLX -ourlibraryforDiffusionReinforcementLearning!Releasedafewweeksago,DRLXisalibraryforscalabledist......