12.5.5 Image Generation Model Based on Stable Diffusion
The file stablediffusion.py implements an image generation model based on Stable Diffusion. It defines the SDWrapper class, which combines an autoencoder (VAE), a noise scheduler, a conditional UNet model, and a custom transformation block, and supports encoding images, adding noise, denoising, and decoding. The class also integrates an image safety checker that detects inappropriate content in generated images. In addition, SDWrapper accepts text embeddings to guide image generation and, at inference time, performs noise prediction and image synthesis under a given guidance scale (classifier-free guidance).
class CPMBeeTransBlock(torch.nn.Module):
    def __init__(
        self,
        dim_model=4096,
        dim_ff=1024,
        dim_out=768,
        dtype=torch.float,
        eps=1e-6,
        dropout_p=0,
    ):
        super().__init__()
        if dropout_p is not None:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None
        self.w_out_res = torch.nn.Linear(dim_model, dim_out, bias=False)
        self.layernorm = torch.nn.LayerNorm(
            dim_out,
            dtype=dtype,
            eps=eps,
        )

    def forward(self, hidden_states: torch.Tensor):
        x_res = self.w_out_res(hidden_states)
        if self.dropout is not None:
            x_res = self.dropout(x_res)
        hidden_states = self.layernorm(x_res)
        return hidden_states


class SDWrapper(torch.nn.Module):
    def __init__(self, image_safety_checker=True):
        super().__init__()
        self.vae = AutoencoderKL.from_pretrained('stabilityai/stable-diffusion-2-1-base', subfolder='vae')
        self.noise_scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-2-1-base', subfolder='scheduler')
        self.unet = UNet2DConditionModel.from_config(UNet2DConditionModel.load_config(
            'stabilityai/stable-diffusion-2-1-base', subfolder='unet'))
        self.trans_block = CPMBeeTransBlock(4096, 4096 // 4, self.unet.config.cross_attention_dim)
        if image_safety_checker:
            self.image_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker")
            self.feature_extractor = CLIPImageProcessor.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
        else:
            self.image_safety_checker = None
            self.feature_extractor = None
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def forward(self, pixel_values, text_hidden_states):
        pixel_values = pixel_values.type(text_hidden_states.dtype)
        latents = self.vae.encode(pixel_values).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        text_hidden_states = text_hidden_states.type(noisy_latents.dtype)
        if self.trans_block is not None:
            text_hidden_states = self.trans_block(text_hidden_states)
        model_pred = self.unet(noisy_latents, timesteps, text_hidden_states).sample
        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        return loss, model_pred

    @torch.no_grad()
    def generate(self,
                 text_hidden_states,
                 uncond_text_hidden_states,
                 height=None,
                 width=None,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 num_images_per_prompt=1,
                 generator=None,
                 latents=None,
                 scheduler=None,
                 output_type='pil'
                 ):
        device = text_hidden_states.device
        batch_size = text_hidden_states.size(0)
        text_hidden_states = text_hidden_states.type(self.unet.conv_in.weight.dtype)
        uncond_text_hidden_states = uncond_text_hidden_states.type(self.unet.conv_in.weight.dtype)
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        if scheduler is not None:
            self.noise_scheduler = scheduler
        self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.noise_scheduler.timesteps
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_hidden_states.dtype,
            device,
            generator,
            latents,
        )
        if self.trans_block is not None:
            text_hidden_states = self.trans_block(text_hidden_states)
            uncond_text_hidden_states = self.trans_block(uncond_text_hidden_states)
        bs_embed, seq_len, _ = text_hidden_states.shape
        text_hidden_states = text_hidden_states.repeat(1, num_images_per_prompt, 1)
        text_hidden_states = text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
        bs_embed, seq_len, _ = uncond_text_hidden_states.shape
        uncond_text_hidden_states = uncond_text_hidden_states.repeat(1, num_images_per_prompt, 1)
        uncond_text_hidden_states = uncond_text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
        text_hidden_states = torch.cat([uncond_text_hidden_states, text_hidden_states])
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_hidden_states,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.noise_scheduler.step(noise_pred, t, latents, generator=generator).prev_sample
        image = self.decode_latents(latents)
        # Run safety checker
        image, has_nsfw_concept = self.run_image_safety_checker(image, device, self.unet.conv_in.weight.dtype)
        if output_type == 'pil':
            image = utils.numpy_to_pil(image)
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = utils.randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the noise_scheduler
        latents = latents * self.noise_scheduler.init_noise_sigma
        return latents

    def numpy_to_pil(self, images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def run_image_safety_checker(self, image, device, dtype):
        if self.image_safety_checker is not None:
            image_safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.image_safety_checker(
                images=image, clip_input=image_safety_checker_input.pixel_values.to(dtype)
            )
            if any(has_nsfw_concept):
                print(
                    "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                )
            for idx, _has_nsfw_concept in enumerate(has_nsfw_concept):
                if _has_nsfw_concept:
                    image[idx] = np.zeros(image[idx].shape)  # black image
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept
The code above is explained as follows:
- The class CPMBeeTransBlock defines a transformation block consisting of a linear output projection and a LayerNorm layer; it processes the hidden states, projecting them from the LLM width (dim_model) down to the UNet's cross-attention dimension.
- The class SDWrapper wraps the Stable Diffusion model, integrating several components (the VAE autoencoder, the noise scheduler, and the conditional UNet). It supports image encoding, noise addition, decoding, and safety checking, and uses the provided text embeddings to condition and control image generation.
- The method forward takes the input pixel values and text hidden states, encodes the images into latents, adds noise, predicts the noise with the UNet, and returns the MSE loss together with the model prediction.
- The method generate produces images from the given text hidden states. It supports both conditional and unconditional inputs (for classifier-free guidance) and synthesizes the final image through iterative denoising driven by the noise scheduler; a usage sketch follows this list.
- The method decode_latents decodes latent representations back into images via the VAE decoder and normalizes the result for output.
- The method prepare_latents prepares the initial random noise in latent space: it determines the latent shape, generates or adapts the latent tensor from the inputs, and scales it by the scheduler's initial noise sigma.
- The method numpy_to_pil converts images stored as NumPy arrays into PIL images for further processing and display.
- The method run_image_safety_checker checks the generated images for potentially inappropriate content and replaces any flagged image with a black image.
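The following minimal sketch shows one way SDWrapper might be exercised. It is an illustration rather than the project's official entry point: the module name stablediffusion, the tensor shapes, and the random placeholder inputs are assumptions, and the 4096-dimensional hidden states stand in for what the CPM-Bee LLM would normally provide.
import torch
from stablediffusion import SDWrapper   # hypothetical module name for the file above

# Build the wrapper without the safety checker to keep the example light.
# The VAE and scheduler are fetched from the Hugging Face Hub; the UNet is
# built from its config with randomly initialized weights.
sd = SDWrapper(image_safety_checker=False)

# Placeholder inputs: a batch of 2 images in [-1, 1] and matching hidden
# states with dim_model = 4096, the width CPMBeeTransBlock expects.
pixel_values = torch.randn(2, 3, 512, 512)
text_hidden_states = torch.randn(2, 77, 4096)

# Training-style forward pass: encode -> add noise -> predict noise -> MSE loss.
loss, model_pred = sd(pixel_values, text_hidden_states)
print(loss.item(), model_pred.shape)     # scalar loss, prediction of shape (2, 4, 64, 64)

# Inference with classifier-free guidance: conditional and unconditional hidden states.
uncond_text_hidden_states = torch.zeros_like(text_hidden_states)
output = sd.generate(
    text_hidden_states,
    uncond_text_hidden_states,
    num_inference_steps=50,
    guidance_scale=7.5,
)
output.images[0].save("sample.png")
In the real project the conditional and unconditional hidden states are produced by the LLM rather than sampled at random, as shown by VLG_CPMBee in the next subsection.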
12.5.6 Multimodal Model
The core functionality of this project is implemented by two complementary modules that handle the visual and language tasks respectively. The first module focuses on visual content generation: given an input text description, it produces the corresponding image and supports diverse visual outputs. The second module focuses on understanding and parsing image content: given an uploaded image, it generates related text descriptions or answers. Working together, the two modules enable efficient multimodal interaction: a user can enter text or upload an image and receive the corresponding generated image or textual interpretation, providing a smooth text-to-image and image-understanding experience.
(1) The file vlg_cpmbee.py defines the class VLG_CPMBee, which couples the large language model (LLM) with the Stable Diffusion model. Its forward method processes the input data and computes the loss and the model prediction, while its generate method produces images from conditional and unconditional inputs, using the LLM hidden states to steer the generation process. A usage sketch follows the listing below.
class VLG_CPMBee(torch.nn.Module):
    def __init__(self, llm, sd) -> None:
        super().__init__()
        self.sd = sd
        self.llm = llm

    def forward(self, data):
        device = data['input_ids'].device
        bs = data['input_ids'].size(0)
        llm_hidden_state = self.llm.input_embedding(data['input_ids'], data['input_id_subs'])
        _, hidden_states = self.llm(
            input=data['input_ids'],
            input_sub=data['input_id_subs'],
            length=data['length'],
            context=data['context'],
            sample_ids=data['sample_ids'],
            num_segments=data['num_segments'],
            segment=data['segment_ids'],
            segment_rel_offset=data['segment_rel_offset'],
            segment_rel=data['segment_rel'],
            span=data['span'],
            ext_table_ids=data['ext_table_ids'],
            ext_table_sub=data['ext_table_sub'],
            hidden_states=llm_hidden_state
        )
        loss, model_pred = self.sd(data['pixel_values'], hidden_states)
        return loss, model_pred

    @torch.no_grad()
    def generate(
        self,
        data,
        uncond_data,
        **generate_kwargs,
    ):
        device = data['input_ids'].device
        bs = data['input_ids'].size(0)
        with torch.no_grad():
            llm_hidden_state = self.llm.input_embedding(data['input_ids'], data['input_id_subs'])
            _, hidden_states = self.llm(
                input=data['input_ids'],
                input_sub=data['input_id_subs'],
                length=data['length'],
                context=data['context'],
                sample_ids=data['sample_ids'],
                num_segments=data['num_segments'],
                segment=data['segment_ids'],
                segment_rel_offset=data['segment_rel_offset'],
                segment_rel=data['segment_rel'],
                span=data['span'],
                ext_table_ids=data['ext_table_ids'],
                ext_table_sub=data['ext_table_sub'],
                hidden_states=llm_hidden_state
            )
        with torch.no_grad():
            uncond_llm_hidden_state = self.llm.input_embedding(uncond_data['input_ids'], uncond_data['input_id_subs'])
            _, uncond_hidden_states = self.llm(
                input=uncond_data['input_ids'],
                input_sub=uncond_data['input_id_subs'],
                length=uncond_data['length'],
                context=uncond_data['context'],
                sample_ids=uncond_data['sample_ids'],
                num_segments=uncond_data['num_segments'],
                segment=uncond_data['segment_ids'],
                segment_rel_offset=uncond_data['segment_rel_offset'],
                segment_rel=uncond_data['segment_rel'],
                span=uncond_data['span'],
                ext_table_ids=uncond_data['ext_table_ids'],
                ext_table_sub=uncond_data['ext_table_sub'],
                hidden_states=uncond_llm_hidden_state
            )
        image = self.sd.generate(
            hidden_states,
            uncond_hidden_states,
            **generate_kwargs
        )
        return image
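As a rough illustration of how VLG_CPMBee wires the two components together, the sketch below replaces the CPM-Bee LLM with a toy stand-in and fills the data dict with dummy tensors of plausible shape. Everything named ToyLLM or make_batch is hypothetical and exists only to show the call pattern; in the project the LLM instance and the batches come from CPMBeeTorch and its preprocessing pipeline.
import torch
from stablediffusion import SDWrapper    # hypothetical module paths
from vlg_cpmbee import VLG_CPMBee

class ToyLLM(torch.nn.Module):
    """Stand-in for CPMBeeTorch that only mimics the interface VLG_CPMBee uses:
    input_embedding(ids, subs) and a forward call returning (logits, hidden_states)."""
    def __init__(self, vocab_size=100, dim=4096):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)

    def input_embedding(self, input_ids, input_id_subs):
        return self.embed(input_ids)

    def forward(self, input=None, hidden_states=None, **kwargs):
        return None, hidden_states       # logits are unused by VLG_CPMBee

def make_batch(bs=1, seq_len=8):
    # Dummy stand-in for the dict built by the CPM-Bee preprocessing pipeline.
    ids = torch.randint(0, 100, (bs, seq_len))
    zeros = torch.zeros(bs, seq_len, dtype=torch.long)
    return {
        'input_ids': ids, 'input_id_subs': zeros,
        'length': torch.full((bs,), seq_len, dtype=torch.long),
        'context': torch.ones(bs, seq_len, dtype=torch.bool),
        'sample_ids': zeros, 'num_segments': torch.ones_like(zeros),
        'segment_ids': zeros, 'segment_rel_offset': zeros, 'segment_rel': zeros,
        'span': zeros, 'ext_table_ids': torch.zeros(0, dtype=torch.long),
        'ext_table_sub': torch.zeros(0, dtype=torch.long),
        'pixel_values': torch.randn(bs, 3, 512, 512),
    }

model = VLG_CPMBee(llm=ToyLLM(), sd=SDWrapper(image_safety_checker=False))
loss, _ = model(make_batch())                        # training-style forward pass
output = model.generate(make_batch(), make_batch(),  # second batch acts as the unconditional input
                        num_inference_steps=2, guidance_scale=7.5)
output.images[0].save("vlg_sample.png")
The important point is that forward returns the diffusion loss for training, while generate feeds a conditional and an unconditional batch into SDWrapper.generate for classifier-free guidance.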
(2) The file vlu_cpmbee.py implements VLU_CPMBee, a module that combines a large language model (LLM) with a visual model (VPM). The function construct_query_parameter builds the learnable visual query parameters, and the method get_vllm_embedding extracts visual features and injects them into the embeddings of the language input. The forward method then runs the LLM on the combined embeddings to produce the logits and hidden states, supporting multimodal tasks. A small self-contained sketch after the listing illustrates the embedding-injection step.
def construct_query_parameter(query_k, h_size, init_weights):
    query_data = torch.zeros(query_k, h_size)
    trunc_normal_(query_data, std=.02)
    for idx in range(query_k):
        if init_weights[idx] is not None:
            query_data[idx] = init_weights[idx]
    query = torch.nn.Parameter(query_data)
    return query


@dataclass
class CausalVLLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class VLU_CPMBee(torch.nn.Module):
    def __init__(self, llm: CPMBeeTorch, vpm, vision_dim, query_num, device=None) -> None:
        super().__init__()
        self.device = device
        self.vpm = vpm
        self.llm = llm
        self.vision_dim = vision_dim
        self.query_num = query_num
        self.query = None
        if query_num is not None:
            bos_weight = self.vpm.beit3.text_embed.weight.data[0]
            eos_weight = self.vpm.beit3.text_embed.weight.data[2]
            query_init_weight = [bos_weight] + [None] * (self.query_num - 2) + [eos_weight]
            self.query = construct_query_parameter(
                self.query_num, self.vision_dim, query_init_weight)
        self.mapping = torch.nn.Sequential(
            torch.nn.Linear(self.vpm.hidden_size, self.llm.config.dim_model),
            torch.nn.GELU(),
            torch.nn.Linear(self.llm.config.dim_model, self.llm.config.dim_model)
        )

    def get_vllm_embedding(self, data):
        if 'vision_hidden_states' not in data:
            pixel_values = data['pixel_values']
            vision_hidden_states = self.vpm(pixel_values=pixel_values, query_embed=self.query)
            vision_hidden_states = self.mapping(vision_hidden_states)  # (query_num, llm_dim)
        else:
            vision_hidden_states = data['vision_hidden_states']
        vllm_embedding = self.llm.input_embedding(data['input_ids'], data['input_id_subs'])
        vision_hidden_states = vision_hidden_states.type(vllm_embedding.dtype)
        image_bound = data['image_bound']
        image_bound = image_bound.squeeze(1)
        image_indices = torch.stack([torch.arange(r[0], r[1], dtype=torch.long) for r in image_bound]).to(self.device)
        vllm_embedding.scatter_(1, image_indices.unsqueeze(-1).repeat(1, 1, vllm_embedding.shape[-1]), vision_hidden_states)
        return vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
        logits, hidden_states = self.llm(
            input=data['input_ids'],
            input_sub=data['input_id_subs'],
            length=data['length'],
            context=data['context'],
            sample_ids=data['sample_ids'],
            num_segments=data['num_segments'],
            segment=data['segment_ids'],
            segment_rel_offset=data['segment_rel_offset'],
            segment_rel=data['segment_rel'],
            span=data['span'],
            ext_table_ids=data['ext_table_ids'],
            ext_table_sub=data['ext_table_sub'],
            hidden_states=vllm_embedding
        )
        return CausalVLLMOutput(
            logits=logits,
            hidden_states=hidden_states,
            vision_hidden_states=vision_hidden_states
        )
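The key step in get_vllm_embedding is scattering the mapped visual features into the token positions reserved for the image, which are delimited by image_bound. The following self-contained sketch reproduces just that injection on tiny made-up tensors (batch size 1, sequence length 16, hidden size 8, four visual query slots) so the indexing is easy to follow:
import torch

# Stand-ins for the LLM input embeddings and the mapped visual features;
# all shapes are made up for demonstration.
llm_dim, seq_len, query_num = 8, 16, 4
vllm_embedding = torch.zeros(1, seq_len, llm_dim)          # text-side embeddings
vision_hidden_states = torch.ones(1, query_num, llm_dim)   # mapped visual features

# image_bound marks where the visual placeholder tokens sit in the input sequence.
image_bound = torch.tensor([[3, 7]])                       # tokens 3..6 hold the image
image_indices = torch.stack(
    [torch.arange(r[0], r[1], dtype=torch.long) for r in image_bound]
)

# scatter_ writes each visual query vector into its corresponding token position.
vllm_embedding.scatter_(
    1,
    image_indices.unsqueeze(-1).repeat(1, 1, llm_dim),
    vision_hidden_states,
)
print(vllm_embedding[0, 2:8, 0])   # tensor([0., 1., 1., 1., 1., 0.]) -- slots 3..6 filled
In the model itself, the same scatter_ call writes the query_num mapped features from the BEiT-3-based visual module into the LLM's input embeddings before the forward pass runs.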
The file vlg_cpmbee.py focuses on image generation, whereas vlu_cpmbee.py focuses on fusing and processing textual and visual information. Their main differences are as follows:
- vlg_cpmbee.py: concentrates on the interaction between text and images, combining the large language model (LLM) with an image generation model (Stable Diffusion) for image generation tasks. Its forward method processes the input data and returns the training loss and the model prediction.
- vlu_cpmbee.py: focuses on processing visual information, combining a visual model (VPM) with the large language model (LLM) for multimodal understanding tasks. It extracts visual features, injects them into the language input embeddings, and produces the logits and hidden states used for vision-grounded language output.