
Diffusers in Practice

Date: 2024-02-20 20:33:37
Tags: hands-on noise torch Diffusers device images import size

Smiling & Weeping

 

                ---- To have freedom and love all my life is the whole of my ambition

 

1. Environment setup

 

%pip install diffusers

 

from huggingface_hub import notebook_login

# Log in to Hugging Face
notebook_login()
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torchvision
from PIL import Image

def show_images(x):
    """Given a batch of images, build a grid and convert it to a PIL image."""
    x = x*0.5 + 0.5  # map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1)*255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

def make_grid(images, size=64):
    """Given a list of PIL images, paste them side by side in one row for viewing."""
    output_im = Image.new("RGB", (size*len(images), size))
    for i, im in enumerate(images):
        output_im.paste(im.resize((size, size)), (i*size, 0))
    return output_im

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
from diffusers import DDPMPipeline, StableDiffusionPipeline

model_id = "sd-dreambooth-library/mr-potato-head"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
prompt = "a cute anime characters using 8K resolution"
image = pipe(prompt, num_inference_steps=50, guidance_scale=5.5).images[0]
image

 

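Core Diffusers APIs:

  • Pipelines: high-level classes, designed for easy deployment, that let you quickly generate samples from popular pretrained diffusion models.
  • Models: the network architectures needed when training a new diffusion model.
  • Schedulers: the various techniques used to generate images from noise at inference time; they also produce the "noisy" images needed during training.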
import torchvision
from datasets import load_dataset
from torchvision import transforms
from diffusers import DDPMScheduler
from diffusers import DDPMPipeline, StableDiffusionPipeline

dataset = load_dataset('lowres/anime', split="train")

image_size = 256
batch_size = 8

preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize([0.5], [0.5]),
])

def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
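
The Normalize([0.5], [0.5]) step computes (x - 0.5) / 0.5 per channel, which maps the [0, 1] range produced by ToTensor to [-1, 1]; show_images reverses it with x*0.5 + 0.5. A minimal plain-Python sketch of that round trip (the helper names are hypothetical and used only for illustration):

```python
def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize for a single value: (x - mean) / std."""
    return (x - mean) / std

def denormalize(x):
    """Inverse mapping used in show_images: from [-1, 1] back to [0, 1]."""
    return x * 0.5 + 0.5

pixels = [0.0, 0.25, 0.5, 1.0]
normalized = [normalize(p) for p in pixels]
restored = [denormalize(n) for n in normalized]

print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
print(restored)    # [0.0, 0.25, 0.5, 1.0]
```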

  

xb = next(iter(train_dataloader))['images'].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8*256, 256), resample=Image.NEAREST)

 

# Define the scheduler
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)  # Gaussian noise (randn), not uniform noise (rand)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noisy X shape:", noisy_xb.shape)
show_images(noisy_xb).resize((8*64, 64), resample=Image.NEAREST)
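
For reference, the DDPM forward process has a closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t is the cumulative product of (1 - beta). The sketch below recomputes the two coefficients in plain Python for the schedule configured above (beta from 0.001 to 0.004 over 1000 steps). It assumes the scheduler's default linear beta schedule and is an illustration, not a copy of the library's code; the function names are hypothetical.

```python
import math

def linear_betas(num_steps=1000, beta_start=0.001, beta_end=0.004):
    """Evenly spaced betas, as in a linear beta schedule."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]

def cumulative_alphas(betas):
    """alpha_bar_t = product of (1 - beta_s) for all s <= t."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars

alpha_bars = cumulative_alphas(linear_betas())

for t in (0, 499, 999):
    signal = math.sqrt(alpha_bars[t])             # weight on the clean image
    noise_scale = math.sqrt(1.0 - alpha_bars[t])  # weight on the Gaussian noise
    print(f"t={t:4d}  signal={signal:.3f}  noise={noise_scale:.3f}")
```

With these beta values the signal weight at the last timestep is still roughly 0.29, so even the most heavily noised image in the row keeps a faint trace of the original.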

  

from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=image_size,  # resolution of the target images
    in_channels=3,
    out_channels=3,
    layers_per_block=2,      # number of ResNet layers in each UNet block
    block_out_channels=(64, 128, 128, 256),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",   # ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",     # ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
    ),
)

model = model.to(device)
with torch.no_grad():
    model_pred = model(noisy_xb, timesteps).sample

model_pred.shape
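
To see what this configuration implies, the following sketch traces the feature-map shape through the four down blocks. It assumes that every down block except the last one ends with a 2x downsample, which is how I understand UNet2DModel to build its encoder; the helper name is hypothetical.

```python
def trace_down_blocks(sample_size, block_out_channels):
    """Return (channels, height, width) after each down block.

    Assumption: each block except the final one ends with a 2x downsample.
    """
    shapes = []
    size = sample_size
    last = len(block_out_channels) - 1
    for i, channels in enumerate(block_out_channels):
        if i != last:
            size //= 2
        shapes.append((channels, size, size))
    return shapes

shapes = trace_down_blocks(256, (64, 128, 128, 256))
for shape in shapes:
    print(shape)
# (64, 128, 128)
# (128, 64, 64)
# (128, 32, 32)
# (256, 32, 32)
```

The up blocks mirror this path, so the prediction comes back at the input resolution, which is why model_pred has the same shape as noisy_xb.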

Training



# Set up the noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
losses = []

# Define the loss function
loss_fn = torch.nn.MSELoss()

for epoch in range(45):
    for step, batch in enumerate(train_dataloader):
        # Clean (noise-free) data
        clean_data = batch['images'].to(device)
        # Sample the noise
        noise = torch.randn(clean_data.shape).to(device)
        bs = clean_data.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs, ), device=device).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_data = noise_scheduler.add_noise(clean_data, noise, timesteps)

        # Get the model's prediction
        noise_pred = model(noisy_data, timesteps, return_dict=False)[0]

        # Compute the loss. With DDPMScheduler's default prediction type the model
        # predicts the added noise, so the target is `noise`, not the clean data.
        loss = loss_fn(noise_pred, noise)
        loss.backward()
        losses.append(loss.item())

        # Update the model parameters
        optimizer.step()
        optimizer.zero_grad()

    if (epoch+1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
        print(f"Epoch: {epoch+1}, loss: {loss_last_epoch}")
torch.save(model.state_dict(), 'save.pt')
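
The "squaredcos_cap_v2" option used for training selects a cosine noise schedule instead of a linear one. A plain-Python sketch of that schedule is shown below, using alpha_bar(t) = cos^2((t + 0.008) / 1.008 * pi / 2) with each beta capped at 0.999. Treat it as an approximation of what the scheduler computes rather than its exact source; the function name is hypothetical.

```python
import math

def cosine_betas(num_steps=1000, max_beta=0.999, offset=0.008):
    """Betas derived from alpha_bar(t) = cos^2((t + offset) / (1 + offset) * pi / 2)."""
    def alpha_bar(t):
        return math.cos((t + offset) / (1.0 + offset) * math.pi / 2) ** 2

    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        # beta is the fraction of remaining signal removed between t1 and t2, capped at max_beta
        betas.append(min(1.0 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas

betas = cosine_betas()
print(f"first beta: {betas[0]:.6f}, last beta: {betas[-1]:.3f}")
```

Compared with a linear schedule, the cosine schedule adds noise very gently in the first steps and only reaches the cap near the end.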

Plotting the loss curve

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()
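
Per-step losses are noisy, so a moving average can make the trend easier to read. A minimal sketch follows; the helper name and window size are illustrative choices, and the sample values are made up for demonstration rather than taken from a real training run.

```python
def moving_average(values, window=3):
    """Average each run of `window` consecutive values."""
    if window <= 0:
        raise ValueError("window must be positive")
    return [
        sum(values[i:i + window]) / window
        for i in range(len(values) - window + 1)
    ]

example_losses = [1.0, 0.5, 0.75, 0.25, 0.5]  # made-up values, not real training results
smoothed = moving_average(example_losses)
print(smoothed)  # [0.75, 0.5, 0.5]
```

With the real loss list, something like axs[0].plot(moving_average(losses, window=100)) would plot the smoothed curve in place of the raw one.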

 

From: https://www.cnblogs.com/smiling-weeping-zhr/p/18023991
