首页 > 编程语言 >在低显存GPU上运行PixArt-Σ/Flux.1图像生成:Python简短教程

在低显存GPU上运行PixArt-Σ/Flux.1图像生成:Python简短教程

时间:2024-08-28 11:22:52浏览次数:11  
标签:显存 Flux.1 Python pipe 图像 量化 PixArt

由PixArt-Σ在本地生成,所需显存不超过8Gb。

图像生成工具的热度从未如此高涨,而且它们也变得越来越强大。像PixArt Sigma和Flux.1这样的模型处于领先地位,这得益于它们的开源权重模型和宽松的许可协议。这种设置允许进行创造性的尝试,包括在不共享计算机外部数据的情况下训练LoRA模型。

然而,如果你使用的是较旧或显存较少的GPU,使用这些模型可能会有些挑战。通常在质量、速度和显存使用之间存在权衡。在这篇博文中,我们将重点优化速度和减少显存使用,同时尽量保持质量。这种方法在PixArt上效果尤其好,因为它模型较小,但在Flux.1上的效果可能有所不同。最后,我会分享一些针对Flux.1的替代解决方案。

PixArt Sigma和Flux.1都是基于Transformer的,这意味着它们可以利用大型语言模型(LLM)使用的量化技术。量化涉及将模型组件压缩,从而占用更少的内存。这允许你将所有模型组件同时保存在GPU显存中,生成速度会比在GPU和CPU之间移动权重的方法更快,因为后者会减慢处理速度。

让我们开始设置环境吧!

设置本地环境

首先,确保你已经安装了Nvidia驱动程序和Anaconda。

接下来,创建一个Python环境并安装所有主要需求:

conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia 然后安装Diffusers和Quanto库:

pip install pillow==10.3.0 loguru~=0.7.2 optimum-quanto==0.2.4 diffusers==0.30.0 transformers==4.44.2 accelerate==0.33.0 sentencepiece==0.2.0 量化代码


以下是一个让你入门的PixArt-Sigma简单脚本:

```
from optimum.quanto import qint8, qint4, quantize, freeze

from diffusers import PixArtSigmaPipeline

import torch

pipeline = PixArtSigmaPipeline.from_pretrained(

“PixArt-alpha/PixArt-Sigma-XL-2-1024-MS”, torch_dtype=torch.float16

)

quantize(pipeline.transformer, weights=qint8)

freeze(pipeline.transformer)

quantize(pipeline.text_encoder, weights=qint4, exclude=“proj_out”)

freeze(pipeline.text_encoder)

pipe = pipeline.to(“cuda”)

for i in range(2):

generator = torch.Generator(device=“cpu”).manual_seed(i)

prompt = "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed"

image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator).images[0]

image.save(f"Sigma_{i}.png")

```
理解脚本:这里是主要的实现步骤

  1. 导入必要的库:我们导入了量化、模型加载和GPU处理的库。
  2. 加载模型:我们将PixArt Sigma模型以半精度(float16)加载到CPU。
  3. 量化模型:对模型的Transformer和文本编码部分进行量化。这里使用了不同级别的量化:文本编码部分由于较大,使用qint4进行量化。视觉部分如果使用qint8进行量化,整个流水线将使用 7.5G显存,如果不进行量化,将使用约 8.5G显存
  4. 移动到GPU:将流水线移动到GPU .to("cuda") 进行更快的处理。
  5. 生成图像:使用 pipe 根据给定的提示生成图像并保存输出。

运行脚本

保存脚本并在相应环境中运行,您将看到基于提示“赛博朋克城市景观,小黑乌鸦,霓虹灯,黑暗的小巷,摩天大楼,未来主义,鲜艳的色彩,高对比度,高度细节”的图像生成,并保存为 sigma_1.png。在RTX 3080 GPU上生成图像需要 6秒钟

由 PixArt-Σ 本地生成

您可以使用Flux.1 Schnell实现类似的结果,尽管它包含更多组件,但这需要更激进的量化,这会显著降低质量(除非您拥有更多的显存,例如16或25 GB)。

```
import torch

from optimum.quanto import qint2, qint4, quantize, freeze

from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

pipe = FluxPipeline.from_pretrained(“black-forest-labs/FLUX.1-schnell”, torch_dtype=torch.bfloat16)

quantize(pipe.text_encoder, weights=qint4, exclude=“proj_out”)

freeze(pipe.text_encoder)

quantize(pipe.text_encoder_2, weights=qint2, exclude=“proj_out”)

freeze(pipe.text_encoder_2)

quantize(pipe.transformer, weights=qint4, exclude=“proj_out”)

freeze(pipe.transformer)

pipe = pipe.to(“cuda”)

for i in range(10):

generator = torch.Generator(device=“cpu”).manual_seed(i)

prompt = “赛博朋克城市景观,小黑乌鸦,霓虹灯,黑暗的小巷,摩天大楼,未来主义,鲜艳的色彩,高对比度,高度细节”

image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator, num_inference_steps=4).images[0]

image.save(f"Schnell_{i}.png")

```

由 Flux.1 Schnell 本地生成: 由于过度量化导致质量较差和提示词的遵循度较低

我们可以看到,将文本编码器量化到qint2和视觉变压器量化到qint8可能过于激进,显著影响了Flux.1 Schnell的质量。

以下是运行Flux.1 Schnell的一些替代方案:

如果PixArt-Sigma不能满足您的需求,而您又没有足够的显存运行Flux.1以获得足够的质量,那么您有两个主要选项:

  • ComfyUI或Forge:这些是爱好者使用的GUI工具,它们主要牺牲速度来提高质量。
  • Replicate API:每次生成Schnell图像的成本为0.003美元。

部署

我在一台旧机器上部署PixArt Sigma时找到了些乐趣。以下是我的简要步骤总结:

首先是组件列表:

  1. HTMX和Tailwind:它们就像项目的面貌。HTMX帮助使网站在没有大量额外代码的情况下互动,而Tailwind则赋予它漂亮的外观。
  2. FastAPI:它接收来自网站的请求并决定如何处理这些请求。
  3. Celery Worker:这就像是勤劳的工人。它接收FastAPI的指令并实际创建图像。
  4. Redis Cache/Pub-Sub:这是通信中心。它帮助项目的不同部分互相交流并记住重要信息。
  5. GCS(谷歌云存储):这是我们存储完成的图像的地方。

现在,它们如何协同工作?下面是一个简单的概述:

  • 当你访问网站并发出请求时,HTMX和Tailwind确保页面看起来很好。
  • FastAPI接收请求,并通过Redis告诉Celery Worker要生成什么样的图像。
  • Celery Worker开始工作,创建图像。
  • 图像生成后会存储在GCS中,便于访问。

应用程序演示

总结

通过量化模型组件,我们可以显著减少VRAM的使用,同时保持良好的图像质量并提高生成速度。此方法对于类似PixArt Sigma的模型尤其有效。对于Flux.1,尽管结果可能有所不同,但量化的原理依然适用。

参考资料:

  • https://huggingface.co/blog/quanto-diffusers
  • https://lightning.ai/lightning-ai/studios/deploy-an-image-generation-api-with-flux

总结:

近期,图像生成工具如PixArt Sigma和Flux.1迅速走红,凭借其开源权重模型和宽松的许可协议,用户可以进行创造性的尝试,尤其是在本地环境中使用较少显存的情况下。对于使用较旧或显存较少的GPU,我们可以通过量化技术来优化显存使用和提升生成速度。本文介绍了如何使用量化技术来压缩模型组件,从而在不牺牲图像质量的情况下减少显存使用。具体操作包括安装相关环境、加载模型、量化模型以及在GPU上进行处理。

为了实现高效的图像生成,提高显存利用率,文中给出了具体的量化步骤,如对模型的Transformer和文本编码部分进行量化,使用不同级别的量化来实现显存的最小化。此外,还分享了Flux.1的替代方案,适用于显存较大的GPU,以保持图像质量。

通过这些优化方法,无论是选择PixArt Sigma还是Flux.1模型,用户都能在低显存环境中快速生成高质量图像。这种方法不仅适用图像生成领域,同样适用于其他需要高计算资源支持的AI应用。

用光年AI,轻松提升私域流量转化率,赢得市场先机!光年AI通过整合主流AI平台及自研技术,提供高效、智能的流量增长解决方案,无论是图像生成还是营销管理,均能显著提升工作效率和客户满意度。选择光年AI,让您的私域流量增长无忧,开启AI时代的私域流量革命!

标签:显存,Flux.1,Python,pipe,图像,量化,PixArt
From: https://blog.csdn.net/2401_86793402/article/details/141636433

相关文章

  • Python系列(10)- Python 多线程
    多线程(Multithreading),是指从软件或者硬件上实现多个线程并发执行的技术。具有多线程能力的系统包括对称多处理机、多核心处理器、芯片级多处理或同时多线程处理器。在一个程序中,这些独立运行的程序片段叫作“线程”(Thread),利用它编程的概念就叫作“多线程处理”。多线程是并行化......
  • Python酷库之旅-第三方库Pandas(104)
    目录一、用法精讲451、pandas.DataFrame.pow方法451-1、语法451-2、参数451-3、功能451-4、返回值451-5、说明451-6、用法451-6-1、数据准备451-6-2、代码示例451-6-3、结果输出452、pandas.DataFrame.dot方法452-1、语法452-2、参数452-3、功能452-4、返回值......
  • Python画笔案例-017 绘制画H图
    1、绘制画H图通过python的turtle库绘制一个画H图的图案,如下图:2、实现代码 绘制一个画H图图案,以下实现的代码直接按移动,左转,右转的方式实现,大家可以尝试把本程序改成递归图,要点为在下面的dot命令修改。相信你一定能完成。:"""画H图.py"""importturtle......
  • yum依赖python2环境-"No module named urlgrabber"
    1.python3安装perl环境以及IPC/cmd.pm模块,由于环境中安装了pyhon2和python3导致模块引入冲突。makepython3时一直报错没有Module_tktinter,重新安装tk后python3还是import失败 2.检查发现python2可以引入,并且再进行安装模块时,使用的是python,而系统python指向python2 3.修改......
  • 简答登陆采集python
    importparamikoimportos创建SSH对象ssh=paramiko.SSHClient()允许连接不在know_hosts文件中的主机ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())连接服务器ssh.connect(hostname='10.10.10.10',port=22,username='root',password='root123�......
  • 使用Python进行Mock测试详解(含Web API接口Mock)
    使用Python进行Mock测试详解(含WebAPI接口Mock)在软件开发过程中,单元测试是非常重要的一部分。为了确保代码的质量和可靠性,开发者需要编写测试用例来检查代码的行为是否符合预期。然而,在测试中有时会遇到一些难以直接测试的情况,例如依赖外部系统、数据库或网络服务等。在这......
  • Python数据采集与网络爬虫技术实训室解决方案
    在大数据与人工智能时代,数据采集与分析已成为企业决策、市场洞察、产品创新等领域不可或缺的一环。而Python,作为一门高效、易学的编程语言,凭借其强大的库支持和广泛的应用场景,在数据采集与网络爬虫领域展现出了非凡的潜力。唯众特此推出《Python数据采集与网络爬虫技术实训......
  • Python的继承
    #1.继承#就是让类和类之间转变为父子关系,子类默认继承父类的属性和方法#1.1语法#class类名(父类名):# 代码块#1.2单继承#classPerson:#  defeat(self):#    print("吃")#  defdrink(self):#    print('喝')#  def......
  • 【python】基础之生成器
    1.什么是生成器?是Python中一种特殊的迭代器,它是一个能按需生成值的轻量级对象。与一次性创建所有元素的数据结构(如列表或元组)不同,生成器在每次迭代时只生成下一个值,从而节省内存并支持无限序列或其他大量数据流的操作。#iter中简单是4行代码,可以代替MyRangeIterator一样的......
  • python读取串口 数据
    读取10s数据,然后关闭串口#读取10s串口数据后关闭这个串口importtimedefread_serial(port,baudrate,duration):try:#初始化串口ser=serial.Serial(port,baudrate)print(f"Openedserialport{port}at{baudrate}baud.")......