Image denoising with Python does not produce the expected reconstructed image

Posted: 2024-07-22 07:39:05
Tags: python machine-learning pytorch neural-network gradio

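I'm a beginner in machine learning with Python. I'm writing a program that adds noise to an image and then outputs a reconstructed image. I'm using additive white Gaussian noise (AWGN) and a feedforward neural network. The program displays the true image, the noisy image and the reconstructed image. These are the results I usually get:

(screenshot of the true, noisy and reconstructed images omitted)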

Does anyone know how to fix this? Here is my code:

app.py

from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import gradio as gr
from mnist_denoising import train_model

app = Flask(__name__)
CORS(app)

@app.route('/train', methods=['POST'])
def train():
    try:
        data = request.get_json()
        dataset_name = data['dataset_name']
        noise_level = float(data['noise_level'])
        model_type = "Feed Forward Neural Network"  # Since only one reconstruction model is available

        true_image, noisy_image, reconstructed_image = train_model(dataset_name, noise_level, model_type)
        # PIL images are not JSON-serializable, so send them as nested lists of pixel values
        return jsonify({
            "true_image": np.asarray(true_image).tolist(),
            "noisy_image": np.asarray(noisy_image).tolist(),
            "reconstructed_image": np.asarray(reconstructed_image).tolist()
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 400

def gradio_interface():
    def gradio_train_model(dataset_name, simulation, reconstruction_model, noise_level):
        true_image, noisy_image, reconstructed_image = train_model(dataset_name, noise_level, reconstruction_model)
        return true_image, noisy_image, reconstructed_image

    iface = gr.Interface(
        fn=gradio_train_model,
        inputs=[
            gr.Dropdown(choices=["MNIST", "CIFAR10", "PathMNIST", "DermaMNIST", "OrganAMNIST"], label="Dataset"),
            gr.Dropdown(choices=["AWGN"], label="Simulation Model"),
            gr.Dropdown(choices=["Feed Forward Neural Network"], label="Reconstruction Model"),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Noise Level")
        ],
        outputs=[gr.Image(label="True Image"), gr.Image(label="Noisy Image"), gr.Image(label="Reconstructed Image")]
    )

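    # launch() blocks the main thread by default; prevent_thread_lock=True lets Flask start too
    iface.launch(prevent_thread_lock=True)

if __name__ == '__main__':
    gradio_interface()
    # The reloader would re-run this script in a child process and launch Gradio a second time
    app.run(debug=True, use_reloader=False)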

model.py

import os
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from medmnist import INFO, Evaluator
from medmnist.dataset import PathMNIST, DermaMNIST, OrganAMNIST

class DenseNet(nn.Module):
    def __init__(self, input_shape, output_shape, hidden_channels_list, activation='relu'):
        super(DenseNet, self).__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape

        self.input_size = int(torch.prod(torch.tensor(input_shape)))
        self.output_size = int(torch.prod(torch.tensor(output_shape)))

        layers = []
        if isinstance(activation, str):
            activation = [activation] * len(hidden_channels_list)
        elif isinstance(activation, list) and len(activation) != len(hidden_channels_list):
            raise ValueError("Length of activation functions list must match the length of hidden_channels_list")

        in_size = self.input_size
        for out_size, act in zip(hidden_channels_list, activation):
            layers.append(nn.Linear(in_size, out_size))
            layers.append(nn.BatchNorm1d(out_size))
            layers.append(self.get_activation(act))
            in_size = out_size

        layers.append(nn.Linear(in_size, self.output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        batch_shape = x.shape[:-len(self.input_shape)]
        batch_size = int(torch.prod(torch.tensor(batch_shape)))
        x = x.view(batch_size, self.input_size)  # Ensure the input is reshaped correctly

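        for layer in self.model:
            # BatchNorm cannot compute batch statistics from a single sample in training mode;
            # in eval mode it uses its running statistics, so it must not be skipped there
            if isinstance(layer, nn.BatchNorm1d) and batch_size == 1 and self.training:
                continue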
            x = layer(x)

        x = x.view(*batch_shape, *self.output_shape)
        return x

    def get_activation(self, activation):
        if activation == 'relu':
            return nn.ReLU()
        elif activation == 'prelu':
            return nn.PReLU()
        elif activation == 'leaky_relu':
            return nn.LeakyReLU()
        elif activation == 'sigmoid':
            return nn.Sigmoid()
        elif activation == 'tanh':
            return nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation function: {activation}")

def get_densenet_model(model_type, input_shape):
    if model_type == "DenseNetSmall":
        return DenseNet(input_shape=input_shape, output_shape=input_shape, hidden_channels_list=[1024])
    elif model_type == "DenseNetLarge":
        return DenseNet(input_shape=input_shape, output_shape=input_shape, hidden_channels_list=[1024, 2048, 1024])
    elif model_type == "Feed Forward Neural Network":
        return DenseNet(input_shape=input_shape, output_shape=input_shape, hidden_channels_list=[1024, 2048, 1024])
    else:
        raise ValueError("Unsupported model type")

def get_dataset(dataset_name):
    transform = transforms.Compose([transforms.ToTensor()])
    root_dir = 'data'
    
    if dataset_name == "MNIST":
        return datasets.MNIST(root=os.path.join(root_dir, 'mnist'), train=True, transform=transform, download=True)
    elif dataset_name == "CIFAR10":
        return datasets.CIFAR10(root=os.path.join(root_dir, 'cifar10'), train=True, transform=transform, download=True)
    elif dataset_name == "PathMNIST":
        dataset_dir = os.path.join(root_dir, 'pathmnist')
        os.makedirs(dataset_dir, exist_ok=True)
        return PathMNIST(root=dataset_dir, split='train', transform=transform, download=True)
    elif dataset_name == "DermaMNIST":
        dataset_dir = os.path.join(root_dir, 'dermamnist')
        os.makedirs(dataset_dir, exist_ok=True)
        return DermaMNIST(root=dataset_dir, split='train', transform=transform, download=True)
    elif dataset_name == "OrganAMNIST":
        dataset_dir = os.path.join(root_dir, 'organamnist')
        os.makedirs(dataset_dir, exist_ok=True)
        return OrganAMNIST(root=dataset_dir, split='train', transform=transform, download=True)
    else:
        raise ValueError("Unsupported dataset")

def add_awgn_noise(image, noise_level=0.1):
    noise = noise_level * torch.randn_like(image)
    noisy_image = image + noise
    return noisy_image

mnist_denoising.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model import get_densenet_model, add_awgn_noise, get_dataset
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_model(dataset_name, noise_level, model_type, epochs=300, batch_size=64, learning_rate=1e-4):
    dataset = get_dataset(dataset_name)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

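    if dataset_name == "MNIST":
        input_shape = (1, 28, 28)
    elif dataset_name == "CIFAR10":
        input_shape = (3, 32, 32)
    elif dataset_name in ["PathMNIST", "DermaMNIST"]:
        input_shape = (3, 28, 28)
    elif dataset_name == "OrganAMNIST":
        input_shape = (1, 28, 28)  # OrganAMNIST images are grayscale
    else:
        raise ValueError("Unsupported dataset")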

    model = get_densenet_model(model_type, input_shape).to(device)  

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    for epoch in range(epochs):
        model.train()
        for images, _ in dataloader:
            images = images.to(device)
            noisy_images = add_awgn_noise(images, noise_level).to(device)

            outputs = model(noisy_images)
            loss = loss_fn(outputs, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

    # Evaluate on a sample image in inference mode: eval() makes BatchNorm use its
    # running statistics, and no_grad() disables gradient tracking
    model.eval()
    sample_image, _ = dataset[0]
    with torch.no_grad():
        noisy_sample_image = add_awgn_noise(sample_image.unsqueeze(0).to(device), noise_level)
        reconstructed_image = model(noisy_sample_image).squeeze(0)

    # ToPILImage scales float tensors by 255 and casts to 8 bits without clipping,
    # so clamp to [0, 1] first to avoid corrupted pixels
    to_pil = transforms.ToPILImage()
    sample_image_pil = to_pil(sample_image)
    noisy_image_pil = to_pil(noisy_sample_image.clamp(0, 1).cpu().squeeze(0))
    reconstructed_image_pil = to_pil(reconstructed_image.clamp(0, 1).cpu())

    return sample_image_pil, noisy_image_pil, reconstructed_image_pil

There are several places in your code that can be improved to get better denoising performance.

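  1. Network architecture: Although the fully connected network you are using (DenseNet) can be used for image denoising, a convolutional neural network (CNN) is usually better suited to this task. CNNs capture the spatial structure of an image better because they apply small convolution kernels whose weights are shared across the whole image, instead of connecting every pixel to every hidden unit. You could try a simple CNN-based autoencoder for denoising.

  2. Loss function: Mean squared error (MSE) may not be the best choice for image denoising, because it tends to produce blurry images. Consider other losses, such as one based on the structural similarity index measure (SSIM) or a perceptual loss, which better reflect how humans perceive image quality.

  3. Hyperparameter tuning: Hyperparameters such as the learning rate, batch size and number of epochs can significantly affect the model's performance. Try techniques such as grid search or random search to find the best values for your specific problem.

  4. Data augmentation: Applying random transformations (such as rotation, flipping and cropping) to the training data can improve the model's ability to generalize and thus its performance.

  5. Noise model: You are using additive white Gaussian noise (AWGN), but other noise types (such as salt-and-pepper or Poisson noise) may match your data better. Using a noise model appropriate to your data can improve denoising performance.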
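As a minimal sketch of point 5, the two alternative noise models could be written as drop-in replacements for `add_awgn_noise`. The function names, the `amount` and `peak` parameters and their defaults are illustrative assumptions, not part of any library; both functions assume image tensors with values in [0, 1], as produced by `transforms.ToTensor()`.

```python
import torch


def add_salt_and_pepper_noise(image: torch.Tensor, amount: float = 0.05) -> torch.Tensor:
    """Hypothetical helper: set a random fraction `amount` of elements to 0 (pepper) or 1 (salt).

    The mask is drawn independently per element, so for RGB images each channel
    is corrupted separately. Assumes values in [0, 1]; the input is not modified.
    """
    noisy = image.clone()
    mask = torch.rand_like(image)
    noisy[mask < amount / 2] = 0.0                        # pepper
    noisy[(mask >= amount / 2) & (mask < amount)] = 1.0   # salt
    return noisy


def add_poisson_noise(image: torch.Tensor, peak: float = 30.0) -> torch.Tensor:
    """Hypothetical helper: Poisson (shot) noise.

    Intensities in [0, 1] are scaled to an expected count of at most `peak`,
    sampled from a Poisson distribution, and scaled back. A smaller `peak`
    means stronger noise. Results may exceed 1, so clamp before display.
    """
    counts = torch.poisson(image.clamp(min=0) * peak)
    return counts / peak
```

For example, `noisy_images = add_poisson_noise(images, peak=30.0)` could replace the AWGN line in the training loop. Unlike AWGN, Poisson noise is signal-dependent: brighter pixels get larger absolute fluctuations.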

Here is a modified code example that uses a simple CNN-based autoencoder for image denoising:

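import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# Use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')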

# Define a CNN-based autoencoder
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        # Encoder: 1x28x28 -> 16x28x28 -> (pool) 16x14x14 -> 32x14x14 -> (pool) 32x7x7
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Decoder: 32x7x7 -> 16x14x14 -> 1x28x28
        self.t_conv1 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.t_conv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)

    def forward(self, x):
        # Encoder
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        # Decoder; the sigmoid keeps the output in [0, 1], matching ToTensor() images
        x = torch.relu(self.t_conv1(x))
        x = torch.sigmoid(self.t_conv2(x))
        return x

# Hyperparameters
epochs = 10
batch_size = 64
learning_rate = 1e-3
noise_level = 0.5

# Load the MNIST dataset
train_dataset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the model, optimizer and loss function
model = ConvAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

# Training loop
for epoch in range(epochs):
    model.train()
    for images, _ in train_loader:
        images = images.to(device)
        noisy_images = images + noise_level * torch.randn_like(images)
        outputs = model(noisy_images)
        loss = loss_fn(outputs, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# Test the model
model.eval()
with torch.no_grad():
    sample_image, _ = train_dataset[0]
    sample_image = sample_image.unsqueeze(0).to(device)
    noisy_sample_image = sample_image + noise_level * torch.randn_like(sample_image)
    reconstructed_image = model(noisy_sample_image)

# Display the results
# ...
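For point 2, here is a minimal sketch of an SSIM-based loss that could replace `loss_fn = nn.MSELoss()` in the autoencoder example above. It assumes images with values in [0, 1] (both the `ToTensor()` inputs and the sigmoid output satisfy this), uses the common 11x11 Gaussian window with sigma 1.5 and the stabilizing constants K1 = 0.01 and K2 = 0.03, and evaluates SSIM only at valid window positions. The `ssim` function, the `SSIMLoss` class and its `alpha` blend with MSE are hypothetical helpers written for illustration; the default weight of 0.5 is an arbitrary starting point to tune, not a recommended value.

```python
import torch
import torch.nn.functional as F


def gaussian_window(size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    """Normalized 2-D Gaussian kernel of shape (1, 1, size, size)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    return (g[:, None] * g[None, :]).unsqueeze(0).unsqueeze(0)


def ssim(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, data_range: float = 1.0) -> torch.Tensor:
    """Mean SSIM between batches shaped (N, C, H, W); H and W must be >= window_size."""
    channels = x.shape[1]
    # One copy of the window per channel, applied separately via groups=channels
    window = gaussian_window(window_size).to(x.device, x.dtype).repeat(channels, 1, 1, 1)

    def filt(t: torch.Tensor) -> torch.Tensor:
        return F.conv2d(t, window, groups=channels)  # no padding: valid positions only

    mu_x, mu_y = filt(x), filt(y)
    # Local variances and covariance: E[a*b] - E[a]E[b] under the Gaussian window
    sigma_x = filt(x * x) - mu_x ** 2
    sigma_y = filt(y * y) - mu_y ** 2
    sigma_xy = filt(x * y) - mu_x * mu_y
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2)
    )
    return ssim_map.mean()


class SSIMLoss(torch.nn.Module):
    """alpha * (1 - SSIM) + (1 - alpha) * MSE; alpha=1.0 gives a pure SSIM loss."""

    def __init__(self, alpha: float = 0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        ssim_term = 1 - ssim(output, target)
        return self.alpha * ssim_term + (1 - self.alpha) * F.mse_loss(output, target)
```

With this, `loss_fn = SSIMLoss()` is the only line that changes in the training loop. Keeping part of the MSE term is a design choice: SSIM alone is insensitive to small uniform brightness shifts, and the MSE term still penalizes them.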

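These changes should help improve the performance of your denoising model. Keep in mind that image denoising is a challenging problem, and getting satisfactory results may take a lot of experimentation.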

From: 78776540
