
Supporting 4K High Resolution: Deployment Notes for the Latest PixArt-Sigma Text-to-Image Model


PixArt-Sigma is an advanced text-to-image (T2I) generation model jointly developed by researchers from Huawei Noah's Ark Lab, Dalian University of Technology, and the University of Hong Kong.

PixArt-Sigma builds on PixArt-alpha and is designed to generate high-quality images directly at 4K resolution.

PixArt-Sigma incorporates higher-quality training data and adopts a weak-to-strong training strategy; this lets the model progressively learn and refine image detail, improving both the fidelity of the generated images and their alignment with the text prompt.

In aesthetic quality, PixArt-Sigma is on par with leading text-to-image products such as DALL·E 3 and Midjourney V6, and it follows text prompts closely.

Its generation capability covers high-resolution posters and wallpapers, supporting the production of high-quality visual content for industries such as film and gaming.

GitHub repository: https://github.com/PixArt-alpha/PixArt-sigma

I. Environment Setup

1. Python environment

Python 3.10 or later is recommended.
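A clean virtual environment keeps the CUDA-specific wheels isolated; for example, with conda (the environment name here is arbitrary):

conda create -n pixart python=3.10 -y

conda activate pixart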

2. Installing pip dependencies

Clone the GitHub repository linked above and run the following from the repository root (this is where requirements.txt lives):

pip install torch==2.3.0+cu118 torchvision==0.18.0+cu118 torchaudio==2.3.0 --extra-index-url https://download.pytorch.org/whl/cu118

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
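To confirm that the CUDA 11.8 build of PyTorch installed correctly, a quick sanity check such as the following can be run (it only prints the version and whether a GPU is visible):

import torch

print(torch.__version__)          # expected: 2.3.0+cu118
print(torch.cuda.is_available())  # expected: True on a machine with a compatible NVIDIA driver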

3. Downloading the SDXL-VAE and T5 text-encoder components

git lfs install

git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers
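To verify the clone, the bundled VAE can be loaded with diffusers. This is a minimal sketch; the path is simply the directory created by the git clone above, while the test script later expects it under PixArt-sigma-model/, so adjust the location to match your layout:

from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae")
print(vae.config.scaling_factor)  # scaling factor applied when decoding latents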

4. Downloading the PixArt-Sigma model

python tools/download.py
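tools/download.py comes from the PixArt-sigma repository. If it is inconvenient to use, the checkpoint can also be fetched with huggingface_hub; note that the repo id and file name below are assumptions inferred from the default --model_path in the test script, so swap in a different resolution variant if that is what you need:

from huggingface_hub import hf_hub_download

# Assumed repo id / file name; adjust to the checkpoint variant you actually want.
ckpt_path = hf_hub_download(
    repo_id="PixArt-alpha/PixArt-Sigma",
    filename="PixArt-Sigma-XL-2-1024-MS.pth",
    local_dir="PixArt-sigma-model",
)
print(ckpt_path)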

II. Functional Testing

1. Command-line test

(1) Invoking the model from a Python script

The script below wraps model loading and sampling in an ImageGenerator class: it loads the PixArt-Sigma checkpoint, the SDXL VAE, and the T5 text encoder, reads prompts from a text file, and saves the generated images under output/vis/.

import os
import re
import sys
import argparse
from datetime import datetime
from pathlib import Path

import torch
from torch import nn
from torchvision.utils import save_image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from transformers import T5EncoderModel, T5Tokenizer

from diffusion.model.utils import prepare_prompt_ar
from diffusion import IDDPM, DPMS, SASolverSampler
from diffusion.model.nets import PixArtMS_XL_2, PixArt_XL_2
from diffusion.data.datasets import get_chunks
import diffusion.data.datasets.utils as ds_utils
from tools.download import find_model


class ImageGenerator:
    def __init__(self, args):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.seed = args.seed
        self._set_env()
        self._load_model_components()

    def _set_env(self):
        torch.manual_seed(self.seed)
        torch.set_grad_enabled(False)
        # Draw and discard a few random tensors to advance the RNG state,
        # mirroring the warm-up in the reference inference script.
        for _ in range(30):
            torch.randn(1, 4, self.args.image_size, self.args.image_size)

    def _load_model_components(self):
        self.latent_size = self.args.image_size // 8
        self.max_sequence_length = {"alpha": 120, "sigma": 300}[self.args.version]
        self.pe_interpolation = self.args.image_size / 512
        self.micro_condition = self.args.version == 'alpha' and self.args.image_size == 1024
        self.sample_steps_dict = {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}
        self.sample_steps = self.args.step if self.args.step != -1 else self.sample_steps_dict[self.args.sampling_algo]
        self.weight_dtype = torch.float16
        
        self._load_main_model()
        self._load_vae()
        self._load_text_components()

    def _load_main_model(self):
        if self.args.image_size in [512, 1024, 2048] or self.args.version == 'sigma':
            self.model = PixArtMS_XL_2(
                input_size=self.latent_size,
                pe_interpolation=self.pe_interpolation,
                micro_condition=self.micro_condition,
                model_max_length=self.max_sequence_length,
            ).to(self.device)
        else:
            self.model = PixArt_XL_2(
                input_size=self.latent_size,
                pe_interpolation=self.pe_interpolation,
                model_max_length=self.max_sequence_length,
            ).to(self.device)

        print("Generating sample from ckpt: %s" % self.args.model_path)
        state_dict = find_model(self.args.model_path)
        # The positional embedding is rebuilt for the requested resolution, so drop it from the checkpoint.
        state_dict['state_dict'].pop('pos_embed', None)
        missing, unexpected = self.model.load_state_dict(state_dict['state_dict'], strict=False)
        print('Missing keys: ', missing)
        print('Unexpected keys: ', unexpected)
        self.model.eval()
        self.model.to(self.weight_dtype)

        self.base_ratios = getattr(ds_utils, f'ASPECT_RATIO_{self.args.image_size}', ds_utils.ASPECT_RATIO_1024)

    def _load_vae(self):
        vae_path = "output/pretrained_models/sd-vae-ft-ema" if self.args.sdvae else f"{self.args.pipeline_load_from}/vae"
        self.vae = AutoencoderKL.from_pretrained(vae_path).to(self.device).to(self.weight_dtype)

    def _load_text_components(self):
        self.tokenizer = T5Tokenizer.from_pretrained(self.args.pipeline_load_from, subfolder="tokenizer")
        self.text_encoder = T5EncoderModel.from_pretrained(self.args.pipeline_load_from, subfolder="text_encoder").to(self.device)
        null_caption_token = self.tokenizer("", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0]

    def generate_images(self, items: list):
        save_root = self._prepare_save_directory()
        self._visualize(items, save_root)

    def _prepare_save_directory(self):
        work_dir = 'output'
        try:
            epoch_name = re.search(r'.*epoch_(\d+).*', self.args.model_path).group(1)
            step_name = re.search(r'.*step_(\d+).*', self.args.model_path).group(1)
        except AttributeError:
            # The released checkpoints do not encode epoch/step in their file names.
            epoch_name = 'unknown'
            step_name = 'unknown'
        
        img_save_dir = os.path.join(work_dir, 'vis')
        os.umask(0o000)  # file permission: 666; dir permission: 777
        os.makedirs(img_save_dir, exist_ok=True)

        save_root = os.path.join(img_save_dir, f"{datetime.now().date()}_{self.args.dataset}_epoch{epoch_name}_step{step_name}_scale{self.args.cfg_scale}_step{self.sample_steps}_size{self.args.image_size}_bs{self.args.bs}_samp{self.args.sampling_algo}_seed{self.seed}")
        print("save_root: ", save_root)
        os.makedirs(save_root, exist_ok=True)
        
        return save_root

    @torch.inference_mode()
    def _visualize(self, items, save_root):
        for chunk in tqdm(list(get_chunks(items, self.args.bs)), unit='batch'):
            # _prepare_prompts_and_configs also returns the latent H/W so that
            # _run_sampling can build noise tensors of the matching shape.
            prompts, hw, ar, latent_size_h, latent_size_w = self._prepare_prompts_and_configs(chunk)
            caption_embs, emb_masks, null_y = self._get_text_embeddings(prompts)

            samples = self._run_sampling(prompts, hw, ar, latent_size_h, latent_size_w,
                                         caption_embs, emb_masks, null_y)
            self._save_images(samples, save_root)

    def _prepare_prompts_and_configs(self, chunk):
        prompts = []
        if self.args.bs == 1:
            # Single-prompt mode: the target resolution and aspect ratio are derived from the prompt itself.
            prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(chunk[0], self.base_ratios, device=self.device, show=False)
            latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8)
            prompts.append(prompt_clean.strip())
        else:
            # Batched mode: every image in the batch uses the square default resolution.
            hw = torch.tensor([[self.args.image_size, self.args.image_size]], dtype=torch.float, device=self.device).repeat(self.args.bs, 1)
            ar = torch.tensor([[1.]], device=self.device).repeat(self.args.bs, 1)
            for prompt in chunk:
                prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
            latent_size_h, latent_size_w = self.latent_size, self.latent_size

        return prompts, hw, ar, latent_size_h, latent_size_w

    def _get_text_embeddings(self, prompts):
        caption_token = self.tokenizer(prompts, max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        caption_embs = self.text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0]
        emb_masks = caption_token.attention_mask

        caption_embs = caption_embs[:, None]
        null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None]
        print('finished text embedding')

        return caption_embs, emb_masks, null_y

    def _run_sampling(self, prompts, hw, ar, latent_size_h, latent_size_w, caption_embs, emb_masks, null_y):
        model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)
        if self.args.sampling_algo == 'iddpm':
            z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device).repeat(2, 1, 1, 1)
            model_kwargs['y'] = torch.cat([caption_embs, null_y])
            model_kwargs['cfg_scale'] = self.args.cfg_scale
            diffusion = IDDPM(str(self.sample_steps))
            samples = diffusion.p_sample_loop(
                self.model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True,
                device=self.device
            )
            samples, _ = samples.chunk(2, dim=0)
        elif self.args.sampling_algo == 'dpm-solver':
            z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device)
            dpm_solver = DPMS(self.model.forward_with_dpmsolver,
                              condition=caption_embs,
                              uncondition=null_y,
                              cfg_scale=self.args.cfg_scale,
                              model_kwargs=model_kwargs)
            samples = dpm_solver.sample(
                z,
                steps=self.sample_steps,
                order=2,
                skip_type="time_uniform",
                method="multistep",
            )
        elif self.args.sampling_algo == 'sa-solver':
            sa_solver = SASolverSampler(self.model.forward_with_dpmsolver, device=self.device)
            samples = sa_solver.sample(
                S=25,
                batch_size=len(prompts),
                shape=(4, latent_size_h, latent_size_w),
                eta=1,
                conditioning=caption_embs,
                unconditional_conditioning=null_y,
                unconditional_guidance_scale=self.args.cfg_scale,
                model_kwargs=model_kwargs,
            )[0]

        samples = samples.to(self.weight_dtype)
        samples = self.vae.decode(samples / self.vae.config.scaling_factor).sample
        torch.cuda.empty_cache()

        return samples

    def _save_images(self, samples, save_root):
        os.umask(0o000)  # file permission: 666; dir permission: 777
        for i, sample in enumerate(samples):
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            # Append the batch index so images saved within the same second do not overwrite each other.
            save_path = os.path.join(save_root, f"{timestamp}_{i}.jpg")
            print("Saving path: ", save_path)
            save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', default=1024, type=int)
    parser.add_argument('--version', default='sigma', type=str)
    parser.add_argument(
        "--pipeline_load_from", default='PixArt-sigma-model/pixart_sigma_sdxlvae_T5_diffusers',
        type=str, help="Download for loading text_encoder, "
                       "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
    )
    parser.add_argument('--txt_file', default='asset/test.txt', type=str)
    parser.add_argument('--model_path', default='PixArt-sigma-model/PixArt-Sigma-XL-2-1024-MS.pth', type=str)
    parser.add_argument('--sdvae', action='store_true', help='sd vae')
    parser.add_argument('--bs', default=1, type=int)
    parser.add_argument('--cfg_scale', default=4.5, type=float)
    parser.add_argument('--sampling_algo', default='dpm-solver', type=str, choices=['iddpm', 'dpm-solver', 'sa-solver'])
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--dataset', default='custom', type=str)
    parser.add_argument('--step', default=-1, type=int)
    parser.add_argument('--save_name', default='test_sample', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    args = get_args()
    generator = ImageGenerator(args)
    with open(args.txt_file, 'r') as f:
        items = [item.strip() for item in f.readlines()]
    generator.generate_images(items)
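
Assuming the script above is saved as generate.py (the file name is arbitrary) and asset/test.txt contains one prompt per line, it can be run with the defaults from get_args():

python generate.py --image_size 1024 --sampling_algo dpm-solver --txt_file asset/test.txt

Generated images are written as timestamped .jpg files under output/vis/, in a sub-directory whose name records the date, cfg scale, sampling steps, image size, batch size, sampler, and seed.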

(2) Web UI test

To be continued...


From: https://blog.csdn.net/m0_71062934/article/details/140593574
