import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.utils as vutils
from PIL import Image
from torchvision import transforms

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Read the two input images
img1 = Image.open('img/low/1.png')
img2 = Image.open('img/low/5.png')

# Convert to [N, C, H, W] tensors. If the two images differ in size, resize
# both to the smaller dimensions first so their spectra can be mixed.
# Note: PIL reports size as (width, height), while transforms.Resize expects
# (height, width), so the dimensions are taken per-axis and reordered here.
if img1.size != img2.size:
    new_size = (min(img1.height, img2.height), min(img1.width, img2.width))
    transform = transforms.Compose([
        transforms.Resize(new_size),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor()
    ])
else:
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor()
    ])
img1 = transform(img1).unsqueeze(0)  # add batch dimension (N=1)
img2 = transform(img2).unsqueeze(0)  # add batch dimension (N=1)

# Mixing coefficient: lam = 1 swaps the amplitude spectra completely;
# a random value such as np.random.uniform(0, 1.0) would interpolate them.
lam = 1

# 2-D FFT over the spatial dimensions, then split into amplitude and phase
img1_fft = torch.fft.fft2(img1, dim=[2, 3])
img2_fft = torch.fft.fft2(img2, dim=[2, 3])
img1_abs, img1_pha = torch.abs(img1_fft), torch.angle(img1_fft)
img2_abs, img2_pha = torch.abs(img2_fft), torch.angle(img2_fft)

# Shift the zero-frequency component to the center before mixing
img1_abs = torch.fft.fftshift(img1_abs, dim=[2, 3])
img2_abs = torch.fft.fftshift(img2_abs, dim=[2, 3])

# Cross-mix the amplitude spectra. (The original also sketched a variant that
# mixes only a centered h_crop x w_crop region of the shifted spectrum; with a
# crop ratio of sqrt(1.0) that region is the full spectrum, so the full-tensor
# mix below is equivalent.)
img1_abs_ = img1_abs.clone()
img2_abs_ = img2_abs.clone()
img1_abs = lam * img2_abs_ + (1 - lam) * img1_abs_
img2_abs = lam * img1_abs_ + (1 - lam) * img2_abs_

img1_abs = torch.fft.ifftshift(img1_abs, dim=[2, 3])
img2_abs = torch.fft.ifftshift(img2_abs, dim=[2, 3])

# Recombine each mixed amplitude with the original phase and invert the FFT
img21 = img1_abs * torch.exp(1j * img1_pha)  # img2's amplitude + img1's phase (for lam = 1)
img12 = img2_abs * torch.exp(1j * img2_pha)  # img1's amplitude + img2's phase (for lam = 1)
img21 = torch.real(torch.fft.ifft2(img21, dim=[2, 3]))
img12 = torch.real(torch.fft.ifft2(img12, dim=[2, 3]))

vutils.save_image(img21, 'img/hecheng.png')

# Convert to uint8 HWC arrays for display
img21 = torch.clamp(img21, 0, 1) * 255.0
img12 = torch.clamp(img12, 0, 1) * 255.0
print(img21.shape)
img21 = img21.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
img12 = img12.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)

# Display the original images alongside the reconstructed (mixed) images
plt.subplot(221), plt.imshow(img1[0].permute(1, 2, 0)), plt.title('Original Image 1')
plt.axis('off')
plt.subplot(222), plt.imshow(img2[0].permute(1, 2, 0)), plt.title('Original Image 2')
plt.axis('off')
plt.subplot(223), plt.imshow(img21), plt.title('Reconstruct Image 1')
plt.axis('off')
plt.subplot(224), plt.imshow(img12), plt.title('Reconstruct Image 2')
plt.axis('off')
plt.savefig('mix.png', bbox_inches='tight')
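The amplitude swap above can also be packaged as a small reusable function, e.g. for FDA-style data augmentation. Below is a minimal sketch under the same assumptions as the script; the name amplitude_mix and its signature are hypothetical, not from the original post. Note that fftshift/ifftshift can be dropped when the whole spectrum is mixed, since the shift only reorders frequencies identically in both tensors.

import torch

def amplitude_mix(x1, x2, lam=1.0):
    """Interpolate the FFT amplitude spectra of two image batches.

    Hypothetical helper sketching the technique above. x1 and x2 are
    [N, C, H, W] tensors of the same shape; lam = 1 fully swaps amplitudes.
    Returns the two reconstructions, each keeping its own phase.
    """
    fft1 = torch.fft.fft2(x1, dim=[2, 3])
    fft2 = torch.fft.fft2(x2, dim=[2, 3])
    abs1, pha1 = torch.abs(fft1), torch.angle(fft1)
    abs2, pha2 = torch.abs(fft2), torch.angle(fft2)
    # Interpolate the amplitude spectra (no fftshift needed for a full-spectrum mix)
    mixed1 = lam * abs2 + (1 - lam) * abs1
    mixed2 = lam * abs1 + (1 - lam) * abs2
    # Recombine with the original phases and invert the FFT
    out1 = torch.real(torch.fft.ifft2(mixed1 * torch.exp(1j * pha1), dim=[2, 3]))
    out2 = torch.real(torch.fft.ifft2(mixed2 * torch.exp(1j * pha2), dim=[2, 3]))
    return out1, out2

For example, out1, out2 = amplitude_mix(img1, img2, lam=0.5) would blend the two spectra halfway instead of swapping them outright.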
From: https://www.cnblogs.com/yyhappy/p/17497557.html