变分自编码器（VAE）的 PyTorch 实现。
import torch
import torch.nn.functional as F
from torch import nn


class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``input_dim -> hidden_dim -> (mu, logvar)``, each of size ``latent_dim``.
    Decoder: ``latent_dim -> hidden_dim -> input_dim`` with a sigmoid output.
    The defaults reproduce the original 784-400-20 network. Layer attribute
    names (``fc1``, ``fc21``, ``fc22``, ``fc3``, ``fc4``) are kept, so
    existing ``state_dict`` checkpoints still load.

    Args:
        input_dim: Features per flattened input sample (784 by default,
            i.e. a flattened 28x28 image).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs ``x`` to ``(mu, logvar)`` of the latent Gaussian."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through ``eps`` keeps ``z`` differentiable w.r.t. ``mu`` and
        ``logvar`` (the reparameterization trick).
        """
        # logvar is log(sigma^2), so exp(0.5 * logvar) is sigma.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        # NOTE(review): samples in eval mode too; return ``mu`` when
        # ``not self.training`` if deterministic inference is wanted.
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions in (0, 1) via a sigmoid."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to ``(-1, input_dim)``. ``reshape`` is used rather
        than ``view`` so that non-contiguous inputs (e.g. transposed images)
        are accepted.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


def loss_function_original(recon_x, x, mu, logvar):
    """Negative ELBO: summed binary cross-entropy plus KL divergence.

    The KL term is the closed form of ``KL(N(mu, exp(logvar)) || N(0, I))``.
    Both terms are summed (not averaged) over the batch.

    Args:
        recon_x: Decoder output with values in [0, 1].
        x: Target; must have the same number of elements as ``recon_x`` and
            values in [0, 1] (required by ``binary_cross_entropy``).
        mu: Latent means from the encoder.
        logvar: Latent log-variances from the encoder.
    """
    # Match the reconstruction's shape; reshape tolerates non-contiguous x.
    target = x.reshape(recon_x.shape)
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
CVAE（条件变分自编码器）及相关参考资料：
https://www.cnblogs.com/amazingter/p/14696251.html
https://www.cnblogs.com/boyknight/p/16290582.html
https://baileyswu.github.io/2019/11/disentangling-disentanglement-in-vae/
https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/116246208
标签：nn, 实现, self, VAE, mu, Pytorch, logvar, 400, def
来源：https://www.cnblogs.com/zxcayumi/p/16727691.html