首页 > 其他分享 >深度学习(VAE)

深度学习(VAE)

时间:2024-11-11 22:42:54浏览次数:1  
标签:mu log nn self torch VAE 学习 深度 var

变分自编码器(VAE,Variational Auto-Encoder)是一种生成模型,它通过学习数据的潜在表示来生成新的样本。

在学习潜空间时,需要保持生成样本与真实数据的相似性,并尽量让潜变量的分布接近标准正态分布。

VAE的基本结构:

1. 编码器(Encoder):将输入数据转换为潜在空间的分布,输出潜在变量的均值和方差。

2. 重参数化层(Reparameterization Layer):从编码器输出的均值和方差中进行重参数化采样,生成潜在变量。

3. 解码器(Decoder):接收潜在变量并将其转换回原始数据的分布。

为了让生成样本接近原始数据,最终loss是样本与真实数据相似度和潜变量与标准高斯分布相似度之和。

生成样本和真实数据相似度可以通过mse计算。

潜变量与标准高斯分布相似度可以通过KL散度计算。

下面是两个高斯分布计算KL散度的推导:

设其中一个为标准高斯函数:

下面代码是用FashionMNIST作为数据集,生成样本的示例:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms,datasets
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
dataset = datasets.FashionMNIST(root='./fasion_data',train=True,transform=transforms.ToTensor(),download=True)

data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=128, 
                                          shuffle=True)

class VAE(nn.Module):
    def __init__(self, image_size=784, h=400, z=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h)
        self.fc2 = nn.Linear(h, z)
        self.fc3 = nn.Linear(h, z)

        self.fc4 = nn.Linear(z, h)
        self.fc5 = nn.Linear(h, image_size)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        return mu,log_var 
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        reconst_x = F.sigmoid(self.fc5(h))
        return reconst_x
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        reconst_x = self.decode(z)
        return reconst_x, mu, log_var

def loss_function(reconst_x, x, mu, log_var): 
    mse = F.binary_cross_entropy(reconst_x, x, size_average=False)
    kld = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return mse+kld


model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for i, (x, _) in enumerate(data_loader):

        x = x.to(device).view(-1, 784)
        reconst_x, mu, log_var = model(x)
     
        loss = loss_function(reconst_x,x,mu,log_var) 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print("epoch : ",epoch, "loss:", loss.item())
    
    with torch.no_grad():
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join('./', '{}.png'.format(epoch)))

 结果如下:

标签:mu,log,nn,self,torch,VAE,学习,深度,var
From: https://www.cnblogs.com/tiandsp/p/18453132

相关文章

  • 2024-11-11-Linux学习-基础篇(1)(鸟哥的LINUX私房菜 第四章)
    Linux的学习,也是一本大厚书,学起来。文章目录一、前言二、知识点2.1开始执行命令2.2日期与时间2.3日历2.4计算器2.4重要的热键2.4.1[TAB]2.4.2[Ctrl]-c2.4.3[Ctrl]-d2.4.4[Shift]+{[PageUP]l[PageDown]}按键小结一、前言  Linux命令学习,开始。二、......
  • 并查集+最小生成树 学习笔记+杂题 2
    图论系列:前言:相关题单:戳我算法讲解:戳我CF1829ETheLakes给定一张\(n*m\)的矩阵,询问正整数四联通块权值和的最大值。并查集维护即可,记录一下集合内的点的权值和。代码:constintM=1005;intT,n,m,ans;inta[M][M],fa[M*M],siz[M*M];intfx[5]={0,1,-1,0,0};intfy[5]......
  • (1) Pytorch深度学习—数值处理
    (1)Pytorch——数值处理参考于李沐“动手学深度学习”系列以及网上各路大佬的博客资料,感谢大家的分享,如错改,如侵删。torch中的数值处理数值处理是深度学习中极其重要的一部分,张量(tensor)是后续进行处理和计算的基本单位。张量表示一个由数值组成的数组,这个数组可能有多个维度。......
  • 学习笔记(三十五):[email protected] (线性容器ArrayList)
    概述:一种线性数据结构,底层基于数组实现 一、导入import{ArrayList}from'@kit.ArkTS'; 二、定义letarrayList:ArrayList<string|number>=newArrayList(); 三、常用函数1、add,在ArrayList尾部插入元素 2、insert,在长度范围内任意位置插入指定元素......
  • 学习笔记(三十六):[email protected] (非线性容器HashMap)
    概述:HashMap底层使用数组+链表+红黑树的方式实现,查询、插入和删除的效率都很高。HashMap存储内容基于key-value的键值对映射,不能有重复的key,且一个key只能对应一个value一、导入import{HashMap}from'@kit.ArkTS' 二、定义lethashMap:HashMap<string,number>=ne......
  • 基于Java+SpringBoot+Mysql在线课程学习教育系统功能设计与实现三
    一、前言介绍:[免费获取]1.1项目摘要随着信息技术的飞速发展和互联网的普及,教育领域正经历着深刻的变革。传统的面对面教学模式逐渐受到挑战,而在线课程学习教育系统作为一种新兴的教育形式,正逐渐受到广泛关注和应用。在线课程学习教育系统的出现,不仅为学生提供了更加灵活、便......
  • 基于Java+SpringBoot+Mysql在线课程学习教育系统功能设计与实现四
    一、前言介绍:免费获取:猿来入此1.1项目摘要随着信息技术的飞速发展和互联网的普及,教育领域正经历着深刻的变革。传统的面对面教学模式逐渐受到挑战,而在线课程学习教育系统作为一种新兴的教育形式,正逐渐受到广泛关注和应用。在线课程学习教育系统的出现,不仅为学生提供了更加灵......
  • 机器学习-34-对ML的思考之PAC学习理论和标准数据集对机器学习的影响
    1研究目标1.1科学研究的目标科学研究的目标就是发现有用的知识,以提高人类认识自然改造自然的能力。1.2机器学习的最初目标机器学习最初的目标与科学研究的目标其实是相同的,也是为了发现有用的知识。然而,今天的机器学习研究,与最初的机器学习研究有着非常大的不同。之......
  • min_25筛法学习
    min_25筛学习算法min_25筛是解决如下问题的:设\(f\)为一个积性的数论函数,问求\(\sum_{i=1}^nf(i)\)。其中\(f\)满足若\(i\)为质数那么\(f(i^k)\)可以快速计算。min_25筛算法可以在\(O\left(\frac{n^{\frac34}}{\logn}\right)\)(通常情况下)的时间复杂度内解决......
  • WEB 漏洞 - SQL 注入之 MySQL 注入深度解析
    目录WEB漏洞-SQL注入之MySQL注入深度解析一、从宇宙奇想到SQL注入二、SQL注入原理回顾(一)基本概念(二)以简单PHP代码示例说明三、MySQL注入步骤(一)确定注入点(二)判断注入类型(三)利用注入获取信息或执行恶意操作四、防御MySQL注入的方法(一)使用参数化查询(二)......