PixArt-Sigma is an advanced text-to-image (T2I) generation model developed jointly by researchers from Huawei Noah's Ark Lab, Dalian University of Technology, and the University of Hong Kong.
Building on PixArt-alpha, PixArt-Sigma is designed to generate high-quality images at up to 4K resolution.
By integrating higher-quality components and adopting a weak-to-strong training strategy, the model progressively learns to refine image details, improving both the fidelity of generated images and their alignment with text prompts.
In aesthetic quality, PixArt-Sigma is on par with leading text-to-image products such as DALL·E 3 and Midjourney V6, and it is notably faithful to text prompts.
Its ability to produce high-resolution posters and wallpapers makes it well suited to creating high-quality visual content for industries such as film and gaming.
GitHub project: https://github.com/PixArt-alpha/PixArt-sigma
一、Environment Setup
1、Python environment
Python 3.10 or later is recommended.
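For isolation, you can first create a dedicated environment (a typical setup; the environment name pixart is arbitrary):
conda create -n pixart python=3.10 -y
conda activate pixart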
2、Installing dependencies with pip
pip install torch==2.3.0+cu118 torchvision==0.18.0+cu118 torchaudio==2.3.0 --extra-index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
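After installation, a quick sanity check confirms that the CUDA build of PyTorch is active (it prints the version and whether a GPU is visible):
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"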
3、Downloading the SDXL-VAE and T5 text components:
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers
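To verify the download, the VAE and tokenizer can be loaded from the local clone with diffusers/transformers (a minimal check; note that the inference script below expects this directory under PixArt-sigma-model/ by default, see its --pipeline_load_from argument):

from diffusers.models import AutoencoderKL
from transformers import T5Tokenizer

root = "pixart_sigma_sdxlvae_T5_diffusers"  # local clone from the step above
vae = AutoencoderKL.from_pretrained(root, subfolder="vae")
tokenizer = T5Tokenizer.from_pretrained(root, subfolder="tokenizer")
print(vae.config.scaling_factor, tokenizer.model_max_length)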
4、Downloading the PixArt-Sigma checkpoint:
python tools/download.py
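Alternatively, a single checkpoint can be fetched directly with huggingface_hub (a sketch, assuming the .pth weights are hosted in the PixArt-alpha/PixArt-Sigma repository on Hugging Face; adjust repo_id/filename to the checkpoint you need):

from huggingface_hub import hf_hub_download

# Assumed repo layout; verify the file name on the Hugging Face page first.
ckpt = hf_hub_download(
    repo_id="PixArt-alpha/PixArt-Sigma",
    filename="PixArt-Sigma-XL-2-1024-MS.pth",
    local_dir="PixArt-sigma-model",
)
print(ckpt)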
二、Functional Testing
1、Command-line test:
(1) Calling the model from Python
import os
import re
import argparse
from datetime import datetime

import torch
from torchvision.utils import save_image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from transformers import T5EncoderModel, T5Tokenizer

from diffusion.model.utils import prepare_prompt_ar
from diffusion import IDDPM, DPMS, SASolverSampler
from diffusion.model.nets import PixArtMS_XL_2, PixArt_XL_2
from diffusion.data.datasets import get_chunks
import diffusion.data.datasets.utils as ds_utils
from tools.download import find_model

class ImageGenerator:
    def __init__(self, args):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.seed = args.seed
        self._set_env()
        self._load_model_components()

    def _set_env(self):
        torch.manual_seed(self.seed)
        torch.set_grad_enabled(False)
        # Burn a fixed number of RNG draws (kept from the reference script)
        # so that seeded results stay reproducible.
        for _ in range(30):
            torch.randn(1, 4, self.args.image_size, self.args.image_size)
    def _load_model_components(self):
        self.latent_size = self.args.image_size // 8
        self.max_sequence_length = {"alpha": 120, "sigma": 300}[self.args.version]
        self.pe_interpolation = self.args.image_size / 512
        self.micro_condition = self.args.version == 'alpha' and self.args.image_size == 1024
        self.sample_steps_dict = {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}
        self.sample_steps = self.args.step if self.args.step != -1 else self.sample_steps_dict[self.args.sampling_algo]
        self.weight_dtype = torch.float16
        self._load_main_model()
        self._load_vae()
        self._load_text_components()
    def _load_main_model(self):
        if self.args.image_size in [512, 1024, 2048] or self.args.version == 'sigma':
            self.model = PixArtMS_XL_2(
                input_size=self.latent_size,
                pe_interpolation=self.pe_interpolation,
                micro_condition=self.micro_condition,
                model_max_length=self.max_sequence_length,
            ).to(self.device)
        else:
            self.model = PixArt_XL_2(
                input_size=self.latent_size,
                pe_interpolation=self.pe_interpolation,
                model_max_length=self.max_sequence_length,
            ).to(self.device)
        print("Generating sample from ckpt: %s" % self.args.model_path)
        state_dict = find_model(self.args.model_path)
        state_dict['state_dict'].pop('pos_embed', None)
        missing, unexpected = self.model.load_state_dict(state_dict['state_dict'], strict=False)
        print('Missing keys: ', missing)
        print('Unexpected keys: ', unexpected)
        self.model.eval()
        self.model.to(self.weight_dtype)
        self.base_ratios = getattr(ds_utils, f'ASPECT_RATIO_{self.args.image_size}', ds_utils.ASPECT_RATIO_1024)
    def _load_vae(self):
        vae_path = "output/pretrained_models/sd-vae-ft-ema" if self.args.sdvae else f"{self.args.pipeline_load_from}/vae"
        self.vae = AutoencoderKL.from_pretrained(vae_path).to(self.device).to(self.weight_dtype)

    def _load_text_components(self):
        self.tokenizer = T5Tokenizer.from_pretrained(self.args.pipeline_load_from, subfolder="tokenizer")
        self.text_encoder = T5EncoderModel.from_pretrained(self.args.pipeline_load_from, subfolder="text_encoder").to(self.device)
        # Pre-compute the empty-prompt embedding once; it is reused as the
        # unconditional branch for classifier-free guidance.
        null_caption_token = self.tokenizer("", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0]
    def generate_images(self, items: list):
        save_root = self._prepare_save_directory()
        self._visualize(items, save_root)

    def _prepare_save_directory(self):
        work_dir = 'output'
        try:
            # Pull epoch/step out of checkpoint names like "...epoch_10_step_2000...".
            epoch_name = re.search(r'.*epoch_(\d+).*', self.args.model_path).group(1)
            step_name = re.search(r'.*step_(\d+).*', self.args.model_path).group(1)
        except AttributeError:
            epoch_name = 'unknown'
            step_name = 'unknown'
        img_save_dir = os.path.join(work_dir, 'vis')
        os.umask(0o000)  # file permission: 666; dir permission: 777
        os.makedirs(img_save_dir, exist_ok=True)
        save_root = os.path.join(img_save_dir, f"{datetime.now().date()}_{self.args.dataset}_epoch{epoch_name}_step{step_name}_scale{self.args.cfg_scale}_step{self.sample_steps}_size{self.args.image_size}_bs{self.args.bs}_samp{self.args.sampling_algo}_seed{self.seed}")
        print("save_root: ", save_root)
        os.makedirs(save_root, exist_ok=True)
        return save_root
    @torch.inference_mode()
    def _visualize(self, items, save_root):
        for chunk in tqdm(list(get_chunks(items, self.args.bs)), unit='batch'):
            prompts, hw, ar, latent_size_h, latent_size_w = self._prepare_prompts_and_configs(chunk)
            caption_embs, emb_masks, null_y = self._get_text_embeddings(prompts)
            samples = self._run_sampling(prompts, hw, ar, latent_size_h, latent_size_w, caption_embs, emb_masks, null_y)
            self._save_images(samples, save_root)

    def _prepare_prompts_and_configs(self, chunk):
        prompts = []
        if self.args.bs == 1:
            # Single-prompt mode: height/width come from the aspect-ratio hint in the prompt.
            prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(chunk[0], self.base_ratios, device=self.device, show=False)
            latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8)
            prompts.append(prompt_clean.strip())
        else:
            # Batch mode: fall back to square images at the configured size.
            hw = torch.tensor([[self.args.image_size, self.args.image_size]], dtype=torch.float, device=self.device).repeat(self.args.bs, 1)
            ar = torch.tensor([[1.]], device=self.device).repeat(self.args.bs, 1)
            for prompt in chunk:
                prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
            latent_size_h, latent_size_w = self.latent_size, self.latent_size
        return prompts, hw, ar, latent_size_h, latent_size_w
    def _get_text_embeddings(self, prompts):
        caption_token = self.tokenizer(prompts, max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        caption_embs = self.text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0]
        emb_masks = caption_token.attention_mask
        caption_embs = caption_embs[:, None]
        null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None]
        print('finished text embedding')
        return caption_embs, emb_masks, null_y
    def _run_sampling(self, prompts, hw, ar, latent_size_h, latent_size_w, caption_embs, emb_masks, null_y):
        model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)
        if self.args.sampling_algo == 'iddpm':
            # Duplicate the batch so conditional and unconditional passes run together for CFG.
            z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device).repeat(2, 1, 1, 1)
            model_kwargs['y'] = torch.cat([caption_embs, null_y])
            model_kwargs['cfg_scale'] = self.args.cfg_scale
            diffusion = IDDPM(str(self.sample_steps))
            samples = diffusion.p_sample_loop(
                self.model.forward_with_cfg, z.shape, z, clip_denoised=False,
                model_kwargs=model_kwargs, progress=True, device=self.device
            )
            samples, _ = samples.chunk(2, dim=0)
        elif self.args.sampling_algo == 'dpm-solver':
            z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device)
            dpm_solver = DPMS(self.model.forward_with_dpmsolver,
                              condition=caption_embs,
                              uncondition=null_y,
                              cfg_scale=self.args.cfg_scale,
                              model_kwargs=model_kwargs)
            samples = dpm_solver.sample(
                z,
                steps=self.sample_steps,
                order=2,
                skip_type="time_uniform",
                method="multistep",
            )
        elif self.args.sampling_algo == 'sa-solver':
            sa_solver = SASolverSampler(self.model.forward_with_dpmsolver, device=self.device)
            samples = sa_solver.sample(
                S=self.sample_steps,
                batch_size=len(prompts),
                shape=(4, latent_size_h, latent_size_w),
                eta=1,
                conditioning=caption_embs,
                unconditional_conditioning=null_y,
                unconditional_guidance_scale=self.args.cfg_scale,
                model_kwargs=model_kwargs,
            )[0]
        samples = samples.to(self.weight_dtype)
        # Decode latents back to pixel space with the VAE.
        samples = self.vae.decode(samples / self.vae.config.scaling_factor).sample
        torch.cuda.empty_cache()
        return samples
    def _save_images(self, samples, save_root):
        os.umask(0o000)  # file permission: 666; dir permission: 777
        for i, sample in enumerate(samples):
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            # Include the batch index so images generated within the same second are not overwritten.
            save_path = os.path.join(save_root, f"{timestamp}_{i}.jpg")
            print("Saving path: ", save_path)
            save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', default=1024, type=int)
    parser.add_argument('--version', default='sigma', type=str)
    parser.add_argument(
        "--pipeline_load_from", default='PixArt-sigma-model/pixart_sigma_sdxlvae_T5_diffusers',
        type=str, help="Download for loading text_encoder, "
                       "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
    )
    parser.add_argument('--txt_file', default='asset/test.txt', type=str)
    parser.add_argument('--model_path', default='PixArt-sigma-model/PixArt-Sigma-XL-2-1024-MS.pth', type=str)
    parser.add_argument('--sdvae', action='store_true', help='sd vae')
    parser.add_argument('--bs', default=1, type=int)
    parser.add_argument('--cfg_scale', default=4.5, type=float)
    parser.add_argument('--sampling_algo', default='dpm-solver', type=str, choices=['iddpm', 'dpm-solver', 'sa-solver'])
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--dataset', default='custom', type=str)
    parser.add_argument('--step', default=-1, type=int)
    parser.add_argument('--save_name', default='test_sample', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    args = get_args()
    generator = ImageGenerator(args)
    with open(args.txt_file, 'r') as f:
        items = [item.strip() for item in f.readlines()]
    generator.generate_images(items)
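With the models in place, the script can be run from the repository root, for example (inference_demo.py is a hypothetical filename for the code above; all flags come from its get_args):

python inference_demo.py --model_path PixArt-sigma-model/PixArt-Sigma-XL-2-1024-MS.pth --pipeline_load_from PixArt-sigma-model/pixart_sigma_sdxlvae_T5_diffusers --image_size 1024 --sampling_algo dpm-solver --txt_file asset/test.txt

Each line of asset/test.txt is treated as one prompt; results are written under output/vis/.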
(2) Web UI test
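A walkthrough of the repository's own web demo is still to come; as a stopgap, here is a minimal Gradio sketch built on the diffusers PixArtSigmaPipeline (an assumption: it loads the diffusers-format checkpoint PixArt-alpha/PixArt-Sigma-XL-2-1024-MS from Hugging Face rather than the raw .pth weights used above):

import torch
import gradio as gr
from diffusers import PixArtSigmaPipeline

# Assumed diffusers-format repo id; swap in a local path if you converted the weights yourself.
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

def generate(prompt):
    # Run the full text-to-image pipeline and return the first image.
    return pipe(prompt=prompt).images[0]

gr.Interface(fn=generate, inputs="text", outputs="image").launch()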
To be continued......