def getHighLowFre(image): f = torch.fft.fft2(image) # 计算频率 freqs = torch.fft.fftfreq(image.shape[-1]) # print(freqs) # 设定阈值,用于分离高频和低频信息 threshold = 0.1 # 创建掩码,用于分离高频和低频信息 mask = (freqs.abs() < threshold).float().to(args.device) print(mask) # 应用掩码,分离高频和低频信息 low_freq = torch.fft.ifft2(f * mask) print(low_freq) high_freq = image - low_freq print(high_freq) return high_freq, low_freq
标签:11,image,fft,mask,low,print,freq From: https://www.cnblogs.com/yyhappy/p/17581741.html