首页 > 其他分享 >Open-Sora代码详细解读(2):时空3D VAE

Open-Sora代码详细解读(2):时空3D VAE

时间:2024-09-14 22:26:28浏览次数:3  
标签:downsample self VAE pad Sora time Open size

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(
        type="VideoAutoencoderKL",
        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        micro_batch_size=micro_batch_size,
        local_files_only=local_files_only,
    )

时间VAE

    vae_temporal = dict(
        type="VAE_Temporal_SD",
        from_pretrained=None,
    )
@MODELS.register_module()
class VAE_Temporal(nn.Module):
    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=4,
        embed_dim=4,
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(True, True, False),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()

        self.time_downsample_factor = 2 ** sum(temporal_downsample)
        # self.time_padding = self.time_downsample_factor - 1
        self.patch_size = (self.time_downsample_factor, 1, 1)
        self.out_channels = in_out_channels

        # NOTE: following MAGVIT, conv in bias=False in encoder first conv
        self.encoder = Encoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim * 2,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )
        self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)

        self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
        self.decoder = Decoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )

    def get_latent_size(self, input_size):
        latent_size = []
        for i in range(3):
            if input_size[i] is None:
                lsize = None
            elif i == 0:
                time_padding = (
                    0
                    if (input_size[i] % self.time_downsample_factor == 0)
                    else self.time_downsample_factor - input_size[i] % self.time_downsample_factor
                )
                lsize = (input_size[i] + time_padding) // self.patch_size[i]
            else:
                lsize = input_size[i] // self.patch_size[i]
            latent_size.append(lsize)
        return latent_size

    def encode(self, x):
        time_padding = (
            0
            if (x.shape[2] % self.time_downsample_factor == 0)
            else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
        )
        x = pad_at_dim(x, (time_padding, 0), dim=2)
        encoded_feature = self.encoder(x)
        moments = self.quant_conv(encoded_feature).to(x.dtype)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, num_frames=None):
        time_padding = (
            0
            if (num_frames % self.time_downsample_factor == 0)
            else self.time_downsample_factor - num_frames % self.time_downsample_factor
        )
        z = self.post_quant_conv(z)
        x = self.decoder(z)
        x = x[:, :, time_padding:]
        return x

    def forward(self, x, sample_posterior=True):
        posterior = self.encode(x)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        recon_video = self.decode(z, num_frames=x.shape[2])
        return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        strides=None,  # allow custom stride
        **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop("dilation", 1)
        stride = strides[0] if strides is not None else kwargs.pop("stride", 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = strides if strides is not None else (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        x = self.conv(x)
        return x

标签:downsample,self,VAE,pad,Sora,time,Open,size
From: https://blog.csdn.net/qq_41895747/article/details/141791528

相关文章

  • 新电脑安装和配置pytorch、anaconda、CUDA、cuDNN、pycharm、OpenCV的过程记录
    显卡驱动和CUDA一、升级显卡驱动到官方最新版    1、打开英伟达官网,输入显卡芯片型号,手动搜索并下载显卡驱动。 NVIDIA官方驱动 ​    2、下载完成后安装驱动。 二、确认显卡支持的最高CUDA版本    1、键盘"win+R",调出运行输入cmd后点”......
  • Openwrt挂载大容量NTFS磁盘,解决默认挂载不能写入问题,实现samba共享
    1.目的在OpenWRT上挂载上大容量磁盘(NTFS文件系统),支持读写操作。解决默认挂载仅能读取,不能写入问题。配置Samba36,实现局域网文件夹共享。2.环境架构ARMv7Processorrev1(v7l)固件版本OpenWrt19.07.7内核版本3.10.33安装插件smbd-v#检查是否安装了smbopk......
  • OpenAI的ChatGPT各个模型有什么区别?
    ChatGPT版本历史/区别特点对比以下是OpenAI 公司ChatGPT 各主要模型版本的详细描述,说明了每个版本中的显著变化:GPT-3.5发布日期:2022年11月描述:GPT-3.5是第一个用于ChatGPT的版本,基于GPT-3.5模型。此版本在准确性和理解能力上有所提升,但仍在GPT-3的基础......
  • openEuler22.03关闭交换分区swap失败处理
    在架设很多上层应用系统时会遇到很多需要关闭swap的操作,例如安装Kubernetes节点。通常的做法是在/etc/fstab文件中注销swap分区的挂载,但是没有起作用,运行free-h还是能看见挂载的swap,而通过命令sudoswapoff-a&&sudosystemctlrestartkubelet.service是能够关闭并成功启......
  • OpenAI o1模型:偏科的理科生
    LLM需要增强的地方大模型的三大基础能力:• 语言理解和表达能力:GPT-3已解决• 世界知识存储:GPT-4已经解决了不少• 逻辑推理能力:是最薄弱的环节,o1模型在这方面有明显的进步。原理o1模型增强逻辑推理能力的思路是:收到问题后,自动生成CoT,再生成答案。避免人类写基于于CoT的Prompt。......
  • OpenCV(cv::split())
    目录1.函数定义2.工作原理3.示例4.使用场景5.注意事项cv::split()是OpenCV提供的一个函数,用于将多通道图像分割成其各个单通道。该函数主要用于处理彩色图像和多通道矩阵,通常用于对图像中的每个颜色通道单独进行处理。1.函数定义voidcv::split(constMat&src,s......
  • OpenAI 的 o1 与 GPT-4o:深入探究 AI 的推理革命
    简介在不断发展的人工智能领域,OpenAI再次凭借其最新产品突破界限:o1模型和GPT-4o。作为一名几十年来一直报道科技的人,我见过不少伪装成革命的增量更新。但这个?这不一样。让我们拨开炒作的迷雾,看看这些新模型到底带来了什么。推荐文章《AI交通管理系列之使用Python......
  • OpenAI 的 GPT-o1(GPT5)详细评论 OpenAI 的 Strawberry 项目具有博士级智能
    简介OpenAI的GPT-5又名Strawberry项目,又名GPT-o1,又名博士级LLM现已推出。几个月来一直备受关注,从结果来看,它不负众望。OpenAI-o1是一系列模型,旨在增强科学、编码和数学等复杂领域的问题解决能力。推荐文章《AI交通管理系列之使用Python进行现代路线优化最......
  • WebGIS开发必学开源框架Openlayers(附赠视频教程+电子书)
    WebGIS开发之Openlayers当前,WebGIS开发领域的流行度不断攀升,导致市场上对该技能的需求与供应之间存在一定的紧张关系。在众多WebGIS开源框架中,如OpenLayers、Leaflet、MapBox、MapFish、GeoServer、GeoEXT和MapInfo等,企业通常期望应聘者能够掌握至少一种框架的开发能力,例如Op......
  • OpenSSL证书通过Subject Alternative Name扩展字段扩展证书支持的域名
    1、概述1.1什么是SubjectAlternativeName(证书主体别名)SAN(SubjectAlternativeName)是SSL标准x509中定义的一个扩展。它允许一个证书支持多个不同的域名。通过使用SAN字段,可以在一个证书中指定多个DNS名称(域名)、IP地址或其他类型的标识符,这样证书就可以同时用于多......