GAN初步-生成1010格式规律的向量
构建和训练GAN的推荐步骤:
(1)从真实数据集预览数据;
(2)测试鉴别器至少具备从随机噪声中区分真实数据的能力;
(3)测试未经训练的生成器能否创建正确格式的数据;
(4)可视化观察损失值,了解训练进展。
#真实的数据源
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
import random
import numpy
def synthetic_data():
    """Return one "real" training sample following the 1010 pattern.

    The sample is a 1-D float32 tensor of four values drawn independently
    with ``random.uniform``: high, low, high, low.

    Returns:
        torch.Tensor: shape ``(4,)``, dtype ``torch.float32``, with
        element 0 in [0.8, 1.0], element 1 in [0.0, 0.1],
        element 2 in [0.8, 0.9] and element 3 in [0.0, 0.1].
    """
    # The draw order is kept fixed so seeded runs stay reproducible.
    values = [
        random.uniform(0.8, 1.0),
        random.uniform(0.0, 0.1),
        # NOTE(review): this "high" slot tops out at 0.9 while the first one
        # goes up to 1.0 -- kept as written; confirm whether it is intended.
        random.uniform(0.8, 0.9),
        random.uniform(0.0, 0.1),
    ]
    return torch.tensor(values, dtype=torch.float32)
#Generator
class Generator(nn.Module):
    """Small fully connected generator: 1 input value -> 4 output values.

    The network is Linear(1, 3) -> Sigmoid -> Linear(3, 4) -> Sigmoid, so
    every output element lies strictly between 0 and 1.

    The generator has no loss function of its own; it is scored with the
    ``loss_function`` of the discriminator passed to ``train``.

    Attributes:
        model: the ``nn.Sequential`` network.
        optimiser: SGD optimiser over this module's parameters.
        counter: number of training steps performed so far.
        progress: loss value (float) recorded on every 10th training step.
    """

    def __init__(self, lr=0.01):
        """Build the network and its optimiser.

        Args:
            lr: learning rate for the SGD optimiser (default 0.01).
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 3),
            nn.Sigmoid(),
            nn.Linear(3, 4),
            nn.Sigmoid(),
        )
        self.optimiser = torch.optim.SGD(self.parameters(), lr=lr)
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Run ``inputs`` through the network and return its output."""
        return self.model(inputs)

    def train(self, D=True, inputs=None, targets=None):
        """Run one training step, or switch training mode.

        This method shadows ``nn.Module.train(mode)``, which ``eval()``
        relies on. To keep both uses working it dispatches on its arguments:

        * ``train(D, inputs, targets)`` -- one optimisation step against the
          discriminator ``D`` (see ``train_step``); returns ``None``.
        * ``train()`` / ``train(mode)`` -- the standard ``nn.Module``
          behaviour; returns ``self``.
        """
        if inputs is None and targets is None:
            # Called the nn.Module way (e.g. from eval()); ``D`` is the mode.
            return super().train(D)
        return self.train_step(D, inputs, targets)

    def train_step(self, D, inputs, targets):
        """Update the generator so that ``D`` scores its output as ``targets``.

        Args:
            D: discriminator module; must be callable on the generator's
                output and expose a ``loss_function`` attribute.
            inputs: tensor fed to the generator.
            targets: tensor the discriminator's output is compared against.
        """
        g_outputs = self(inputs)
        d_output = D(g_outputs)
        loss = D.loss_function(d_output, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # Gradients flow back through D as well, but only the generator's
        # own parameters are stepped here.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss values (y axis limited to 0..1)."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.',
                grid=True, yticks=(0, 0.25, 0.5))
# Discriminator
class Discriminator(nn.Module):
    """Small fully connected discriminator: 4 input values -> 1 score.

    The network is Linear(4, 3) -> Sigmoid -> Linear(3, 1) -> Sigmoid, so
    the score lies strictly between 0 and 1. Training uses mean squared
    error against the supplied target.

    Attributes:
        model: the ``nn.Sequential`` network.
        loss_function: ``nn.MSELoss`` instance used by ``train``.
        optimiser: SGD optimiser over this module's parameters.
        counter: number of training steps performed so far.
        progress: loss value (float) recorded on every 10th training step.
    """

    def __init__(self, lr=0.01):
        """Build the network, its loss and its optimiser.

        Args:
            lr: learning rate for the SGD optimiser (default 0.01).
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(4, 3),
            nn.Sigmoid(),
            nn.Linear(3, 1),
            nn.Sigmoid(),
        )
        self.loss_function = nn.MSELoss()
        self.optimiser = torch.optim.SGD(self.parameters(), lr=lr)
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Run ``inputs`` through the network and return its score."""
        return self.model(inputs)

    def train(self, inputs=True, targets=None):
        """Run one training step, or switch training mode.

        This method shadows ``nn.Module.train(mode)``, which ``eval()``
        relies on. To keep both uses working it dispatches on its arguments:

        * ``train(inputs, targets)`` -- one optimisation step
          (see ``train_step``); returns ``None``.
        * ``train()`` / ``train(mode)`` -- the standard ``nn.Module``
          behaviour; returns ``self``.
        """
        if targets is None:
            # Called the nn.Module way (e.g. from eval()); ``inputs`` is
            # the boolean mode.
            return super().train(inputs)
        return self.train_step(inputs, targets)

    def train_step(self, inputs, targets):
        """Take one SGD step pushing the score for ``inputs`` to ``targets``.

        Args:
            inputs: tensor fed to the discriminator.
            targets: tensor the discriminator's output is compared against.
        """
        outputs = self(inputs)
        loss = self.loss_function(outputs, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print('counter = ', self.counter)

        # zero_grad() also clears any gradients left on these parameters by
        # an earlier backward pass that ran through this module.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss values (y axis limited to 0..1)."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.',
                grid=True, yticks=(0, 0.25, 0.5))
# Train the GAN and record how the generator's output evolves.

# Both networks must exist before the loop below can use them.
D = Discriminator()
G = Generator()

# These tensors never change, so build them once instead of per iteration.
generator_seed = torch.FloatTensor([0.5])
real_label = torch.FloatTensor([1.0])
fake_label = torch.FloatTensor([0.0])

image_list = []
for i in range(10000):
    # 1) Discriminator learns that real samples should score 1.0.
    D.train(synthetic_data(), real_label)
    # 2) Discriminator learns that generated samples should score 0.0;
    #    detach() keeps this step from backpropagating into the generator.
    D.train(G(generator_seed).detach(), fake_label)
    # 3) Generator learns to make the discriminator output 1.0.
    G.train(D, generator_seed, real_label)
    if i % 1000 == 0:
        # Snapshot the generator's output; detach so the stored tensor does
        # not keep its autograd graph alive.
        image_list.append(G(generator_seed).detach())

# One column per snapshot, one row per output element.
image_list_ = [snapshot.numpy() for snapshot in image_list]
plt.imshow(numpy.array(image_list_).T, interpolation='none', cmap='Blues')
标签:loss,counter,nn,self,torch,GAN,pass,1010,向量
From: https://www.cnblogs.com/afengblog/p/16794588.html