首页 > 其他分享 >Scalable Diffusion Models with Transformers(DIT)代码笔记

Scalable Diffusion Models with Transformers(DIT)代码笔记

时间:2024-08-04 19:52:33浏览次数:14  
标签:Diffusion Transformers nn Models self mlp hidden adaLN size

完整代码来源:DiT
DiT模型主要是在diffusion中,使用transformer模型替换了UNet模型,使用class来控制图像生成。
根据论文,模型越大,patch size 越小,FID越小。
模型越大,参数越多,patch size越小,参与计算的信息就越多,模型效果越好。

在这里插入图片描述

模型使用了Imagenet 训练,有1000个分类,class_labels是0到999的整数,无条件类则必须是1000,在class embedding的时候定义了 nn.Embedding(1000+1,1152)。

在模型初始化的时候,DiTBlock和FinalLayer的参数中weight和bias被置为0,也就是adaLN_zero。
在DiTBlock 中,参数 c = t + y ,也就是 class embedding和 t embedding的和,然后通过线性映射生成shift,scale,和gate。


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


 # Zero-out adaLN modulation layers in DiT blocks:
 for block in self.blocks:
     nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
     nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
     
 # Zero-out output layers:
 nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
 nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
 nn.init.constant_(self.final_layer.linear.weight, 0)
 nn.init.constant_(self.final_layer.linear.bias, 0)

标签:Diffusion,Transformers,nn,Models,self,mlp,hidden,adaLN,size
From: https://blog.csdn.net/zhilaizhiwang/article/details/140889102

相关文章

  • 【人工智能】如何选择AI绘画工具?Midjourney VS Stable Diffusion
    文章目录......
  • How to pass multimodal data directly to models
    Howtopassmultimodaldatadirectlytomodelshttps://python.langchain.com/v0.2/docs/how_to/multimodal_inputs/Herewedemonstratehowtopassmultimodalinputdirectlytomodels.WecurrentlyexpectallinputtobepassedinthesameformatasOpenAIe......
  • Enhancing Question Answering for Enterprise Knowledge Bases using Large Language
    本文是LLM系列文章,针对《EnhancingQuestionAnsweringforEnterpriseKnowledgeBasesusingLargeLanguageModels》的翻译。使用大型语言模型增强企业知识库的问答能力摘要1引言2相关工作3前言4方法5实验6结论摘要高效的知识管理在提高企业和组......
  • Large Language Models meet Collaborative Filtering
    本文是LLM系列文章,针对《LargeLanguageModelsmeetCollaborativeFiltering:AnEfficientAll大型语言模型与协同过滤:一个高效的基于LLM的全方位推荐系统摘要1引言2相关工作3问题定义4提出的方法5实验6结论摘要协同过滤推荐系统(CFRecSys)在增强社......
  • GitHub Models服务允许开发人员免费查找和试用AI模型
    今天,GitHub宣布推出一项新服务–GitHubModels,允许开发人员免费查找和试用人工智能模型。它将领先的大型和小型语言模型的强大功能直接带给GitHub的1亿多用户。GitHub模型将提供对领先模型的访问,包括OpenAI的GPT-4o和GPT-4omini、微软的Phi3、Meta的Llama3.......
  • Pixel Aligned Language Models论文阅读笔记
    Motivation&Abs近年来,大语言模型在视觉方面取得了极大的进步,但其如何完成定位任务(如wordgrounding等)仍然不清楚。本文旨在设计一种模型能够将一系列点/边界框作为输入或者输出。当模型接受定位信息作为输入时,可以进行以定位为condition的captioning。当生成位置作为输出时,模型......
  • Modelsim仿真实现Verilog HDL序列检测器
    检测接收到的数字序列中出现“10011”的次数。例如输入序列为40位:1100_1001_1100_1001_0100_1100_1011_0010_1100_1011从最高位开始检测,出现了2次:1100_1001_1100_1001_0100_1100_1011_0010_1100_1011所以,序列检测器的计数结果应该是2。状态机如下:当前状态current_stat......
  • 劝你先别更新!!最新Stable Diffusion WebUI 1.10已来!WebUI终于支持SD3大模型了!你跑起来
    你的SD3大模型在SDWebUI1.10.0中跑起来了么?今天发现StableDiffusionWebUI于昨日推出了最新SDWebUI1.10.0版本。令人比较兴奋的是该版本支持了SD3大模型,同时也新增了DDIMCFG++采样器。主要更新内容如下:最新版本地址:更新后重启,可在WebUI设置中开启对T5文本的支持,......
  • 3.校验,格式化,ModelSerializer使用
    【一】反序列化校验1)三层校验字段自己校验直接写在字段类的属性上局部钩子在序列化中写validata_字段名全局钩子#serializers.pyclassBookSerializer(serializers.Serializer):#1)name字段的要大于1小于10name=serializers.CharField(min_length=......
  • 一文详解Denoising Diffusion Implicit Models(DDIM)
    目录0前言1DDIM2总结0前言  上一篇博文我们介绍了目前流行的扩散模型基石DDPM,并且给出了代码讲解,有不了解的小伙伴可以跳转到前面先学习一下。今天我们再来介绍下DDPM的改进版本。DDPM虽然对生成任务带来了新得启发,但是他有一个致命的缺点,就是推理速度比较慢,......