前言
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
image = torch.randn(1, 256, 256) # 示例,随机生成一个单通道图像
# 将图像张量保存为文件
save_image(image, "single_channel_image.png", normalize=True)
pytorch中通常如上使用torchvision.utils.save_image来保存图片,但是在保存单通道灰度图片时,该函数保存后的图片会是3通道的,虽然每个通道上的数据一样,视觉上也是灰度图片,但后续输入单通道的网络会报错。
方案
此时我们可以用将图片转换成PIL,用它提供的save()方法来保存
import torch
import torchvision.transforms as transforms
from PIL import Image
# 假设你有一个单通道的图像张量image,形状为 [H, W]
image = torch.randn(1, 256, 256) # 示例,随机生成一个单通道图像
# 创建一个转换函数来将图像张量转换为PIL图像
to_pil = transforms.ToPILImage()
# 将图像张量转换为PIL图像
pil_image = to_pil(image)
# 保存PIL图像
pil_image.save("single_channel_image.png")