(12-4-03) A Text-to-Image System Based on the CPM Chinese-English Bilingual Multimodal Large Model: Implementing the Models (3) Stable Diffusion-Based Image Generation Model + Multimodal Models

12.5.5  A Stable Diffusion-Based Image Generation Model

The file stablediffusion.py implements a Stable Diffusion-based image generation model. It defines the class SDWrapper, which bundles a variational autoencoder, a noise scheduler, a conditional UNet, and a custom projection block, and supports encoding images into latents, adding noise, denoising, and decoding. The class also integrates an image safety checker that screens generated images for inappropriate content. In addition, SDWrapper accepts text embeddings to condition image generation and, at inference time, predicts noise and produces images according to a given guidance scale.

# Imports required by stablediffusion.py. The module paths below follow common
# diffusers/transformers releases and are assumptions about the original
# project's environment; recent diffusers versions move randn_tensor to
# diffusers.utils.torch_utils.
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPImageProcessor
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, utils
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker


class CPMBeeTransBlock(torch.nn.Module):
    def __init__(
        self,
        dim_model=4096,
        dim_ff=1024,
        dim_out=768,
        dtype=torch.float,
        eps=1e-6,
        dropout_p=0,
    ):
        super().__init__()
        if dropout_p is not None:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None
        self.w_out_res = torch.nn.Linear(dim_model, dim_out, bias=False)
        self.layernorm = torch.nn.LayerNorm(
            dim_out,
            dtype=dtype,
            eps=eps,
        )

    def forward(self, hidden_states: torch.Tensor):
        x_res = self.w_out_res(hidden_states)
        if self.dropout is not None:
            x_res = self.dropout(x_res)
        hidden_states = self.layernorm(x_res)
        return hidden_states

class SDWrapper(torch.nn.Module):
    def __init__(self, image_safety_checker=True):
        super().__init__()
        self.vae = AutoencoderKL.from_pretrained('stabilityai/stable-diffusion-2-1-base', subfolder='vae')
        self.noise_scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-2-1-base', subfolder='scheduler')
        self.unet = UNet2DConditionModel.from_config(UNet2DConditionModel.load_config(
            'stabilityai/stable-diffusion-2-1-base', subfolder='unet'))

        self.trans_block = CPMBeeTransBlock(4096, 4096 // 4, self.unet.config.cross_attention_dim)
        if image_safety_checker:
            self.image_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker")
            self.feature_extractor = CLIPImageProcessor.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
        else:
            self.image_safety_checker = None
            self.feature_extractor = None
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def forward(self, pixel_values, text_hidden_states):
        pixel_values = pixel_values.type(text_hidden_states.dtype)
        latents = self.vae.encode(pixel_values).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        text_hidden_states = text_hidden_states.type(noisy_latents.dtype)
        if self.trans_block is not None:
            text_hidden_states = self.trans_block(text_hidden_states)
        model_pred = self.unet(noisy_latents, timesteps, text_hidden_states).sample
        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        return loss, model_pred

    @torch.no_grad()
    def generate(self,
                 text_hidden_states,
                 uncond_text_hidden_states,
                 height=None,
                 width=None,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 num_images_per_prompt=1,
                 generator=None,
                 latents=None,
                 scheduler=None,
                 output_type='pil'
                 ):
        device = text_hidden_states.device
        batch_size = text_hidden_states.size(0)
        text_hidden_states = text_hidden_states.type(self.unet.conv_in.weight.dtype)
        uncond_text_hidden_states = uncond_text_hidden_states.type(self.unet.conv_in.weight.dtype)
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        if scheduler is not None:
            self.noise_scheduler = scheduler
        self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.noise_scheduler.timesteps
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_hidden_states.dtype,
            device,
            generator,
            latents,
        )
        if self.trans_block is not None:
            text_hidden_states = self.trans_block(text_hidden_states)
            uncond_text_hidden_states = self.trans_block(uncond_text_hidden_states)
        bs_embed, seq_len, _ = text_hidden_states.shape
        text_hidden_states = text_hidden_states.repeat(1, num_images_per_prompt, 1)
        text_hidden_states = text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
        bs_embed, seq_len, _ = uncond_text_hidden_states.shape
        uncond_text_hidden_states = uncond_text_hidden_states.repeat(1, num_images_per_prompt, 1)
        uncond_text_hidden_states = uncond_text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
        text_hidden_states = torch.cat([uncond_text_hidden_states, text_hidden_states])
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_hidden_states,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.noise_scheduler.step(noise_pred, t, latents, generator=generator).prev_sample
        image = self.decode_latents(latents)
        # Run safety checker
        image, has_nsfw_concept = self.run_image_safety_checker(image, device, self.unet.conv_in.weight.dtype)
        if output_type == 'pil':
            image = utils.numpy_to_pil(image)
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = utils.randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the noise_scheduler
        latents = latents * self.noise_scheduler.init_noise_sigma
        return latents

    def numpy_to_pil(self, images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def run_image_safety_checker(self, image, device, dtype):
        if self.image_safety_checker is not None:
            image_safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.image_safety_checker(
                images=image, clip_input=image_safety_checker_input.pixel_values.to(dtype)
            )
            if any(has_nsfw_concept):
                print(
                    "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                )
                for idx, _has_nsfw_concept in enumerate(has_nsfw_concept):
                    if _has_nsfw_concept:
                        image[idx] = np.zeros(image[idx].shape)  # black image
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

The code above is explained as follows:

  1. The class CPMBeeTransBlock defines a projection block consisting of a bias-free linear layer and a LayerNorm layer; it maps the LLM hidden states to the dimensionality expected by the UNet's cross-attention.
  2. The class SDWrapper wraps the Stable Diffusion components (the VAE autoencoder, the noise scheduler, and the conditional UNet); it supports image encoding, noise injection, decoding, and safety checking, and generates and controls images conditioned on text embeddings.
  3. The method forward takes pixel values and text hidden states, encodes the image into latents, adds noise, predicts the noise with the UNet, and returns the MSE loss together with the model prediction.
  4. The method generate produces images from the given text hidden states, running conditional and unconditional branches and refining the latents step by step with the noise scheduler via classifier-free guidance (a toy sketch of the guidance step follows this list).
  5. The method decode_latents decodes latent representations back into images with the VAE decoder and normalizes them to the [0, 1] range for output.
  6. The method prepare_latents prepares the initial random latent noise, determining its shape, generating or reusing the provided latents, and scaling them by the scheduler's initial noise sigma.
  7. The method numpy_to_pil converts images in NumPy array format into PIL images for further processing and display.
  8. The method run_image_safety_checker screens generated images for potentially inappropriate content and replaces any flagged image with a black image.
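
Below is a minimal, self-contained sketch of the classifier-free guidance step performed inside generate(); the tensor shapes and batch size are made up for illustration and are not taken from the project.

import torch

# Toy illustration of classifier-free guidance: the unconditional and
# conditional noise predictions come back stacked along the batch dimension
# and are recombined with guidance_scale, as in SDWrapper.generate().
guidance_scale = 7.5
batch = 3                                        # hypothetical batch size
noise_pred = torch.randn(2 * batch, 4, 64, 64)   # [uncond half; cond half]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([3, 4, 64, 64])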

12.5.6  Multimodal Models

The core functionality of this project is implemented by two complementary modules that handle the visual and the language side respectively. The first module focuses on visual content generation: given an input text description, it produces the corresponding image and supports diverse visual outputs. The second module focuses on understanding and interpreting image content: given an uploaded image, it produces a related textual description or answer. Working together, the two modules enable efficient multimodal interaction: a user can enter text or upload an image and receive the corresponding generated image or textual interpretation, providing a smooth text-to-image and image-understanding experience.

(1) The file vlg_cpmbee.py defines the class VLG_CPMBee, which couples the large language model (LLM) with the Stable Diffusion model. Its forward method processes the input batch and computes the loss and the model prediction, while its generate method produces images from the conditional and unconditional inputs, using the LLM hidden states to condition the generation process. A hedged usage sketch follows the class definition below.

# Import required by vlg_cpmbee.py; the LLM and SDWrapper instances are
# constructed elsewhere and passed in by the caller.
import torch


class VLG_CPMBee(torch.nn.Module):
    def __init__(self, llm, sd) -> None:
        super().__init__()
        self.sd = sd
        self.llm = llm

    def forward(self, data):
        device = data['input_ids'].device
        bs = data['input_ids'].size(0)

        llm_hidden_state = self.llm.input_embedding(data['input_ids'], data['input_id_subs'])

        _, hidden_states = self.llm(
            input=data['input_ids'],
            input_sub=data['input_id_subs'],
            length=data['length'],
            context=data['context'],
            sample_ids=data['sample_ids'],
            num_segments=data['num_segments'],
            segment=data['segment_ids'],
            segment_rel_offset=data['segment_rel_offset'],
            segment_rel=data['segment_rel'],
            span=data['span'],
            ext_table_ids=data['ext_table_ids'],
            ext_table_sub=data['ext_table_sub'],
            hidden_states=llm_hidden_state
        )
        loss, model_pred = self.sd(data['pixel_values'], hidden_states)
        return loss, model_pred

    @torch.no_grad()
    def generate(
        self,
        data,
        uncond_data,
        **generate_kwargs,
    ):
        device = data['input_ids'].device
        bs = data['input_ids'].size(0)
        with torch.no_grad():
            llm_hidden_state = self.llm.input_embedding(data['input_ids'], data['input_id_subs'])
            _, hidden_states = self.llm(
                input=data['input_ids'],
                input_sub=data['input_id_subs'],
                length=data['length'],
                context=data['context'],
                sample_ids=data['sample_ids'],
                num_segments=data['num_segments'],
                segment=data['segment_ids'],
                segment_rel_offset=data['segment_rel_offset'],
                segment_rel=data['segment_rel'],
                span=data['span'],
                ext_table_ids=data['ext_table_ids'],
                ext_table_sub=data['ext_table_sub'],
                hidden_states=llm_hidden_state
            )

        with torch.no_grad():
            uncond_llm_hidden_state = self.llm.input_embedding(uncond_data['input_ids'], uncond_data['input_id_subs'])

            _, uncond_hidden_states = self.llm(
                input=uncond_data['input_ids'],
                input_sub=uncond_data['input_id_subs'],
                length=uncond_data['length'],
                context=uncond_data['context'],
                sample_ids=uncond_data['sample_ids'],
                num_segments=uncond_data['num_segments'],
                segment=uncond_data['segment_ids'],
                segment_rel_offset=uncond_data['segment_rel_offset'],
                segment_rel=uncond_data['segment_rel'],
                span=uncond_data['span'],
                ext_table_ids=uncond_data['ext_table_ids'],
                ext_table_sub=uncond_data['ext_table_sub'],
                hidden_states=uncond_llm_hidden_state
            )
        image = self.sd.generate(
            hidden_states,
            uncond_hidden_states,
            **generate_kwargs
        )
        return image
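
The following is a hedged usage sketch of VLG_CPMBee. It assumes that llm is an already-initialised CPM-Bee model and that data and uncond_data are batches produced by the project's data pipeline containing the fields accessed in forward(); these three names are placeholders, not part of vlg_cpmbee.py.

# Hypothetical wiring sketch: llm, data and uncond_data are assumed to exist
# and are not constructed here.
sd = SDWrapper(image_safety_checker=False)  # SDWrapper from stablediffusion.py
model = VLG_CPMBee(llm=llm, sd=sd)

# Training-style step: returns the diffusion MSE loss and the UNet prediction.
loss, model_pred = model(data)

# Inference: classifier-free guidance against the unconditional batch.
output = model.generate(data, uncond_data, num_inference_steps=50, guidance_scale=7.5)
images = output.images  # list of PIL images when output_type='pil'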

(2) The file vlu_cpmbee.py implements VLU_CPMBee, a module that combines the large language model (LLM) with a visual perception model (VPM). The helper function construct_query_parameter builds the learnable query parameters, the method get_vllm_embedding extracts visual features and splices them into the embeddings of the language input, and the forward method runs the LLM on the combined embeddings, returning logits and hidden states to support multimodal tasks.

# Imports required by vlu_cpmbee.py. The paths below follow common library
# releases and are assumptions; CPMBeeTorch is the project's own CPM-Bee
# implementation, so its import is omitted here.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from timm.models.layers import trunc_normal_  # torch.nn.init.trunc_normal_ is an equivalent
from transformers.utils import ModelOutput


def construct_query_parameter(query_k, h_size, init_weights):
    query_data = torch.zeros(query_k, h_size)
    trunc_normal_(query_data, std=.02)
    for idx in range(query_k):
        if init_weights[idx] is not None:
            query_data[idx] = init_weights[idx]
    query = torch.nn.Parameter(query_data)
    return query

@dataclass
class CausalVLLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class VLU_CPMBee(torch.nn.Module):
    def __init__(self, llm: CPMBeeTorch, vpm, vision_dim, query_num, device=None) -> None:
        super().__init__()
        self.device = device
        self.vpm = vpm
        self.llm = llm
        self.vision_dim = vision_dim
        self.query_num = query_num
        self.query = None
        if query_num is not None:
            bos_weight = self.vpm.beit3.text_embed.weight.data[0]
            eos_weight = self.vpm.beit3.text_embed.weight.data[2]
            query_init_weight = [bos_weight] + [None] * (self.query_num - 2) + [eos_weight]
            self.query = construct_query_parameter(
                self.query_num, self.vision_dim, query_init_weight)
        self.mapping = torch.nn.Sequential(
            torch.nn.Linear(self.vpm.hidden_size, self.llm.config.dim_model),
            torch.nn.GELU(),
            torch.nn.Linear(self.llm.config.dim_model, self.llm.config.dim_model)
        )

    def get_vllm_embedding(self, data):
        if 'vision_hidden_states' not in data:
            pixel_values = data['pixel_values']
            vision_hidden_states = self.vpm(pixel_values=pixel_values, query_embed=self.query)
            vision_hidden_states = self.mapping(vision_hidden_states)  # (bs, query_num, llm_dim)
        else:
            vision_hidden_states = data['vision_hidden_states']

        vllm_embedding = self.llm.input_embedding(data['input_ids'], data['input_id_subs'])
        vision_hidden_states = vision_hidden_states.type(vllm_embedding.dtype)
        image_bound = data['image_bound']
        image_bound = image_bound.squeeze(1)
        image_indices = torch.stack([torch.arange(r[0], r[1], dtype=torch.long) for r in image_bound]).to(self.device)
        vllm_embedding.scatter_(1, image_indices.unsqueeze(-1).repeat(1, 1, vllm_embedding.shape[-1]), vision_hidden_states)
        return vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
        logits, hidden_states = self.llm(
            input=data['input_ids'],
            input_sub=data['input_id_subs'],
            length=data['length'],
            context=data['context'],
            sample_ids=data['sample_ids'],
            num_segments=data['num_segments'],
            segment=data['segment_ids'],
            segment_rel_offset=data['segment_rel_offset'],
            segment_rel=data['segment_rel'],
            span=data['span'],
            ext_table_ids=data['ext_table_ids'],
            ext_table_sub=data['ext_table_sub'],
            hidden_states=vllm_embedding
        )
        return CausalVLLMOutput(
            logits=logits,
            hidden_states=hidden_states,
            vision_hidden_states=vision_hidden_states
        )

The file vlg_cpmbee.py focuses on image generation, while vlu_cpmbee.py focuses on fusing and processing textual and visual information. Their main differences are as follows:

  1. vlg_cpmbee.py: focuses on the text-to-image direction, combining the large language model (LLM) with the image generation model (Stable Diffusion) to carry out image generation tasks. Its forward method processes the input batch and produces the diffusion loss and the predicted noise.
  2. vlu_cpmbee.py: focuses on processing visual information, combining the visual perception model (VPM) with the LLM for multimodal understanding tasks. It extracts visual features, splices them into the embeddings of the language input, and produces vision-conditioned logits; a toy sketch of this splicing step follows the list.
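
Below is a minimal, self-contained sketch of the splicing step performed by get_vllm_embedding: the mapped vision features are scattered into the text embedding at the positions given by image_bound. All shapes here are made up for illustration.

import torch

# Toy shapes (hypothetical): 2 samples, sequence length 10, embedding dim 8,
# and 4 visual query tokens per sample.
bs, seq_len, dim, query_num = 2, 10, 8, 4
vllm_embedding = torch.zeros(bs, seq_len, dim)         # text embeddings
vision_hidden_states = torch.ones(bs, query_num, dim)  # mapped vision features
image_bound = torch.tensor([[2, 6], [3, 7]])           # [start, end) image-token span per sample

# Build (bs, query_num) position indices and scatter the vision features into
# the sequence dimension, as VLU_CPMBee.get_vllm_embedding does.
image_indices = torch.stack(
    [torch.arange(r[0], r[1], dtype=torch.long) for r in image_bound]
)
vllm_embedding.scatter_(
    1, image_indices.unsqueeze(-1).repeat(1, 1, dim), vision_hidden_states
)
print(vllm_embedding[0, :, 0])  # positions 2..5 now hold the vision features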

From: https://blog.csdn.net/asd343442/article/details/144403662
