首页 > 编程语言 >手写数字数据集AutoEncoder降噪算法

手写数字数据集AutoEncoder降噪算法

时间:2023-09-26 22:13:35浏览次数:36  
标签:loss AutoEncoder Linear nn 28 降噪 train 手写 data

对训练数据加噪声的方法,在训练里面对 x 做如下处理,添加椒盐噪声:

        bs, ch, h, w = x.shape
        x = x.reshape(bs, ch, h*w) + 0.2*np.random.normal(size=28*28)
        x = x.to(torch.float32)

数据集里面的标签 label 无用,因为 AutoEncoder 去噪是无监督方法。

一、读取数据

import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

EPOCH = 5
BATCH_SIZE = 64
LR = 0.001
DOWNLOAD_MNIST = True
N_TEST_IMG = 5

train_data = torchvision.datasets.MNIST(
    root='../mnist_data/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
    )

test_data = torchvision.datasets.MNIST(
    root='../mnist_data/',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
    )
print(train_data.train_data.size())
print(train_data.train_labels.size())

train_loader=Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE,shuffle=True)
test_loader=Data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE,shuffle=False)

二、前3步:构建模型、设置优化器、损失函数

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
        nn.Linear(28*28, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 12),
        nn.ReLU(),
        #             nn.Linear(12, 3),
        )
        self.decoder = nn.Sequential(
        #             nn.Linear(3, 12),
        #             nn.Tanh(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28*28),
        #             nn.Sigmoid(),
        )
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
AE = AutoEncoder()
optimizer = optim.Adam(AE.parameters(), lr=LR)
loss_func = nn.MSELoss()
# 1 2 3

三、后5步:前向计算、计算损失、no_grad, backward, step,如果有验证集的话,每到一定步数在no_grad下进行验证,不需要zer_grad和backward

for epoch in range(EPOCH):
    for step, (x, _) in enumerate(train_loader):
        bs, ch, h, w = x.shape
        x = x.reshape(bs, ch, h*w) + 0.2*np.random.normal(size=28*28)
        x = x.to(torch.float32)
        # 4 5
        code = AE.encoder(x)  # https://blog.csdn.net/weixin_55191433/article/details/121402942
        recon = AE.decoder(code)
        loss = loss_func(recon, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 100 == 0:
            print('Epoch:', epoch, ' | train loss: %.4f'%loss.item())

四、查看结果(测试集)

cnt = 16
idx = 0
plt.figure(1)
with torch.no_grad():
    for step, (x, _) in enumerate(test_loader):
        bs, ch, h, w = x.shape
        x = x.reshape(bs, ch, h*w) + 0.2*np.random.normal(size=28*28)
        x = x.to(torch.float32)
        # 4 5
        code = AE.encoder(x)  # https://blog.csdn.net/weixin_55191433/article/details/121402942
        recon = AE.decoder(code)
        print(recon.shape)
        for i in range(16):
            plt.subplot(4,4,step+1)
            img = recon[i].squeeze().reshape(28, 28)
            plt.imshow(img)
        # loss = loss_func(recon, x)
        idx += 1
        if idx == 16:
            break

结果如下:加噪声后,和通过AE去噪后。

 

标签:loss,AutoEncoder,Linear,nn,28,降噪,train,手写,data
From: https://www.cnblogs.com/zhaoke271828/p/17731370.html

相关文章

  • 怎么制作手写电子签名?
    https://zhuanlan.zhihu.com/p/157419337年初在家办公,多次遇到需要在电子版文档上手写签名,以前的我习惯了打印出来再签字,但家里又没有打印机,可难倒了我…直到一个程序员朋友告诉了我几个傻瓜式操作方法,才发现手写电子签名也没那么难嘛。记得当时也去网上寻找解决方式,发现大家都......
  • 手写promise核心代码(一)
     classmyPromise{staticPENDING='pending'staticREJECT='reject'staticRESOLVE='resolve';constructor(executor){this.value=nullthis.status=myPromise.PENDINGtry{executor(this.resolve1.bind......
  • Ansible专栏文章之十六:成就感源于创造,自己动手写Ansible模块
    回到:Ansible系列文章各位读者,请您:由于Ansible使用Jinja2模板,它的模板语法{%raw%}{{}}{%endraw%}和{%raw%}{%%}{%endraw%}和博客系统的模板使用的符号一样,在渲染时会产生冲突,尽管我尽我努力地花了大量时间做了调整,但无法保证已经全部都调整。因此,如果各位阅读时发......
  • speex降噪算法移植及测试
    下载libspeexdspPC上,编译。修改内置demo输入in.pcm,输出out.pcm,用音频分析软件及实测效果OK.#ifdefHAVE_CONFIG_H#include"config.h"#endif#include"speex/speex_preprocess.h"#include<stdio.h>#defineNN160intmain(){  shortin[NN];  inti;  SpeexPre......
  • Kingbase中手写Mysql底层函数DATE_FORMAT()
    Kingbase中手写Mysql底层函数DATE_FORMAT()分析底层函数的实现逻辑MySQL的DATE_FORMAT()函数其底层逻辑涉及多个组件和模块。以下是DATE_FORMAT()函数的大致实现逻辑:解析日期格式字符串:DATE_FORMAT()函数接受两个参数,一个是日期值,另一个是格式字符串。首先,MySQL解析格......
  • 手写Promise
    //excutor:可以理解为传入一个函数为执行器functionmyPromise(excutor){//1.执行结构letself=thisself.status='pending'//状态self.value=null//成功的值self.reason=null//失败原因......
  • 如何高效地进行告警降噪
    在事件处理方面,一般我们会遇到两个痛点,一个是告警事件太多,被过度打扰,另一个是重要告警疏漏,无法闭环处理。告警太多的常见原因最常见的原因,是告警规则设置得不合理。比如很多规则触发了告警之后,实际没有后续动作,只是起到常态化通知的效果,不需要排查,也不需要止损,甚至连个长线的TODO......
  • Python给你一个字符串,你怎么判断是不是ipv4地址?手写这段代码,并写出测试用例【杭州多测
    ipv4地址的格式:(1~255).(0 ~255).(0 ~255).(0 ~255)1.正则表达式importredefcheck_ip(one_str):compile_ip=re.compile('^(([1-9]|[1-9]\d|1\d{2}|2[0-4]\d|25[0-5])\.){3}(\d|[1-9]\d|1\d{2}|2[0-4]\d|25[0-5])$')ifcompile_ip.match(one_str......
  • 自己动手写一个C++日志库
    自己动手写一个C++日志库logger.h////CreatedbyFkkton2023/9/8.//#pragmaonce#include<string>#include<iostream>#include<fstream>#include<chrono>#include<sstream>namespacefkkt{classlogger{public:......
  • 降噪与恢复——各种降噪类型综合讲解
    选择减少混响减少混响之后,声音明显比较薄了如何消除嗡嗡音呢消除嗡嗡音打开之后,嗡嗡的声音就会变得很薄......