首页 > 其他分享 > 深度学习从入门到精通——PyTorch 实现生成手写数字

深度学习从入门到精通——PyTorch 实现生成手写数字

时间：2022-11-01 18:03:59　浏览次数：54
标签:loss 入门 img torch nn pytorch fake net 手写


网络构建

该版本中，判别器和生成器全部由全连接层构成；隐藏层的激活函数采用 LeakyReLU，判别器的输出层使用 Sigmoid 将结果映射为概率，生成器的输出层不加激活函数。

from torch import nn

class D_Net(nn.Module):
    """Fully connected discriminator.

    Maps a 784-feature input (a flattened 28x28 image in the training
    script) to a single score in (0, 1) via a final Sigmoid.

    Layout: 784 -> 512 -> LeakyReLU -> 256 -> LeakyReLU -> 1 -> Sigmoid.
    The layers live in one ``nn.Sequential`` named ``dnet`` with the same
    indices as before, so saved state_dict keys (``dnet.0.weight`` ...)
    remain loadable.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 512, 256)
        stack = []
        # Hidden layers: Linear followed by LeakyReLU for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU()])
        # Output head: one logit squashed to a probability for BCELoss.
        stack.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])
        self.dnet = nn.Sequential(*stack)

    def forward(self, x):
        """Return the discriminator score for ``x``, shape (batch, 1)."""
        return self.dnet(x)

class G_Net(nn.Module):
    """Fully connected generator.

    Maps a 128-dimensional latent vector to 784 output features (reshaped
    to 1x28x28 images by the training and inference scripts).

    Layout: 128 -> 256 -> LeakyReLU -> 512 -> LeakyReLU -> 784.
    The final Linear layer has no activation, so outputs are unbounded.
    Layers are kept in ``gnet`` with the original indices so saved
    state_dict keys (``gnet.0.weight`` ...) remain loadable.
    """

    def __init__(self):
        super().__init__()
        widths = (128, 256, 512)
        stack = []
        # Hidden layers: Linear followed by LeakyReLU for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU()])
        # Output projection to 28*28 pixels, deliberately without an activation.
        stack.append(nn.Linear(widths[-1], 784))
        self.gnet = nn.Sequential(*stack)

    def forward(self, x):
        """Return generated flat images for latent batch ``x``, shape (batch, 784)."""
        return self.gnet(x)

模型训练

from torch.utils.data import DataLoader
from torchvision import transforms,datasets
from torchvision.utils import save_image
import os
import torch
from torch import nn
from model import D_Net,G_Net

if __name__ == '__main__':
    # Training script: fits the fully connected GAN (D_Net / G_Net) on MNIST,
    # periodically logging losses, saving sample grids to ./img and
    # checkpointing both networks to ./params.
    batch_size = 100
    num_epoch = 100
    latent_dim = 128  # must match the input width of G_Net's first Linear layer

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("img", exist_ok=True)
    os.makedirs("./params", exist_ok=True)

    # ToTensor() yields pixel values in [0, 1].
    mnist_data = datasets.MNIST("/data", train=True, transform=transforms.ToTensor(), download=True)
    train_loader = DataLoader(mnist_data, batch_size, shuffle=True)

    d_net = D_Net().to(device)
    g_net = G_Net().to(device)

    # Resume from earlier checkpoints if present. map_location lets a
    # checkpoint saved on a GPU be loaded on a CPU-only machine.
    if os.path.exists("./params/d_net.pth"):
        d_net.load_state_dict(torch.load("./params/d_net.pth", map_location=device))
    if os.path.exists("./params/g_net.pth"):
        g_net.load_state_dict(torch.load("./params/g_net.pth", map_location=device))

    loss_fun = nn.BCELoss()
    d_opt = torch.optim.Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_opt = torch.optim.Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    k = 0  # running index used to name the saved sample images
    for epoch in range(num_epoch):
        # Class labels are not needed: this is an unconditional GAN.
        for i, (img, _) in enumerate(train_loader):
            n = img.size(0)
            real_img = img.reshape(-1, 784).to(device)
            # Targets: 1 for real images, 0 for generated images.
            real_label = torch.ones(n, 1, device=device)
            fake_label = torch.zeros(n, 1, device=device)

            # ---- Discriminator step: classify real as 1 and fake as 0 ----
            real_out = d_net(real_img)
            d_loss_real = loss_fun(real_out, real_label)
            z = torch.randn(n, latent_dim, device=device)
            # detach(): the discriminator loss must not backpropagate into
            # the generator (that only wasted compute before).
            fake_img = g_net(z).detach()
            fake_out = d_net(fake_img)
            d_loss_fake = loss_fun(fake_out, fake_label)

            d_loss = d_loss_real + d_loss_fake
            d_opt.zero_grad()
            d_loss.backward()
            d_opt.step()

            # ---- Generator step: make D label fresh fakes as real ----
            z = torch.randn(n, latent_dim, device=device)
            fake_img = g_net(z)
            g_fake_out = d_net(fake_img)
            g_loss = loss_fun(g_fake_out, real_label)

            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()

            if i % 10 == 0:
                # .item() prints plain numbers instead of tensor reprs
                # carrying device/grad_fn noise.
                print("Epoch:{0},d_loss:{1},g_loss:{2}".format(epoch, d_loss.item(), g_loss.item()))
                real_img = real_img.reshape(-1, 1, 28, 28)
                fake_img = fake_img.detach().reshape(-1, 1, 28, 28)
                save_image(real_img, "img/{}-real_img.jpg".format(k), nrow=10, normalize=True, scale_each=True)
                save_image(fake_img, "img/{}-fake_img.jpg".format(k), nrow=10, normalize=True, scale_each=True)
                torch.save(d_net.state_dict(), "./params/d_net.pth")
                torch.save(g_net.state_dict(), "./params/g_net.pth")
                k += 1

模型运行

from torchvision.utils import save_image
import os
import torch

from model import G_Net

if __name__ == '__main__':
    # Inference script: loads the trained generator and writes image grids
    # sampled from random latent vectors to ./test_img.
    batch_size = 100  # images per saved grid (10 x 10 with nrow=10)
    num_epoch = 10    # number of grids to generate (not training epochs)
    latent_dim = 128  # must match the input width of G_Net's first Linear layer

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("test_img", exist_ok=True)
    os.makedirs("./params", exist_ok=True)
    g_net = G_Net().to(device)

    if os.path.exists("./params/g_net.pth"):
        # map_location allows a GPU-saved checkpoint to load on CPU.
        g_net.load_state_dict(torch.load("./params/g_net.pth", map_location=device))
    else:
        # Without a checkpoint the generator is randomly initialised and the
        # saved images are noise; say so instead of failing silently.
        print("warning: ./params/g_net.pth not found, using an untrained generator")
    g_net.eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for i in range(num_epoch):
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_img = g_net(z).reshape(-1, 1, 28, 28)
            save_image(fake_img, "test_img/{}-fake_img.jpg".format(i), nrow=10, normalize=True, scale_each=True)
            print(i)

训练过程图

深度学习从入门到精通——pytorch实现生成手写数字_全连接

阶段过程图呈现

深度学习从入门到精通——pytorch实现生成手写数字_pytorch_02

生成结果

深度学习从入门到精通——pytorch实现生成手写数字_python_03


标签:loss,入门,img,torch,nn,pytorch,fake,net,手写
From: https://blog.51cto.com/u_13859040/5814622

相关文章