首页 > 编程问答 >一维变分自动编码器的错误重建

一维变分自动编码器的错误重建

时间:2024-08-06 03:45:44浏览次数:11  
标签:python machine-learning deep-learning neural-network vae

我想实现一个变分自动编码器,它将一维 Numpy 数组(声音文件的波形)作为输入。运行该文件不会引发错误,但损失收敛到 2000 左右,并且重建看起来像纯粹的噪声。

我使用了 Goffinet 等人的代码 并尝试重写它以采用一维输入,因为我之前使用过他们的(二维)VAE。这是网络和转发功能的代码:

def _build_network(self):
        # Encoder
        self.conv1 = nn.Conv1d(1, 8, 3,1,padding=1)
        self.conv2 = nn.Conv1d(8, 8, 3,2,padding=1)
        self.conv3 = nn.Conv1d(8, 16,3,1,padding=1)
        self.conv4 = nn.Conv1d(16,16,3,2,padding=1)
        self.conv5 = nn.Conv1d(16,24,3,1,padding=1)
        self.conv6 = nn.Conv1d(24,24,3,2,padding=1)
        self.conv7 = nn.Conv1d(24,32,3,1,padding=1)
        self.bn1 = nn.BatchNorm1d(1)
        self.bn2 = nn.BatchNorm1d(8)
        self.bn3 = nn.BatchNorm1d(8)
        self.bn4 = nn.BatchNorm1d(16)
        self.bn5 = nn.BatchNorm1d(16)
        self.bn6 = nn.BatchNorm1d(24)
        self.bn7 = nn.BatchNorm1d(24)
        self.fc1 = nn.Linear(1800,1024)
        self.fc2 = nn.Linear(1024,256)
        self.fc31 = nn.Linear(256,64)
        self.fc32 = nn.Linear(256,64)
        self.fc33 = nn.Linear(256,64)
        self.fc41 = nn.Linear(64,self.z_dim)
        self.fc42 = nn.Linear(64,self.z_dim)
        self.fc43 = nn.Linear(64,self.z_dim)
        # Decoder
        self.fc5 = nn.Linear(self.z_dim,64)
        self.fc6 = nn.Linear(64,256)
        self.fc7 = nn.Linear(256,1024)
        self.fc8 = nn.Linear(1024,1800)
        self.convt1 = nn.ConvTranspose1d(32,24,3,1,padding=1)
        self.convt2 = nn.ConvTranspose1d(24,24,3,2,padding=1,output_padding=1)
        self.convt3 = nn.ConvTranspose1d(24,16,3,1,padding=1)
        self.convt4 = nn.ConvTranspose1d(16,16,3,2,padding=1,output_padding=1)
        self.convt5 = nn.ConvTranspose1d(16,8,3,1,padding=1)
        self.convt6 = nn.ConvTranspose1d(8,8,3,2,padding=1,output_padding=1)
        self.convt7 = nn.ConvTranspose1d(8,1,3,1,padding=1)
        self.bn8 = nn.BatchNorm1d(32)
        self.bn9 = nn.BatchNorm1d(24)
        self.bn10 = nn.BatchNorm1d(24)
        self.bn11 = nn.BatchNorm1d(16)
        self.bn12 = nn.BatchNorm1d(16)
        self.bn13 = nn.BatchNorm1d(8)
        self.bn14 = nn.BatchNorm1d(8)

    def encode(self, x):
        #print("encoder x:",x.shape)
        x = x.unsqueeze(1)
        #nn.Flatten(x)
        #print("encoder x:",x.shape)
        x = F.relu(self.conv1(self.bn1(x)))
        x = F.relu(self.conv2(self.bn2(x)))
        x = F.relu(self.conv3(self.bn3(x)))
        x = F.relu(self.conv4(self.bn4(x)))
        x = F.relu(self.conv5(self.bn5(x)))
        x = F.relu(self.conv6(self.bn6(x)))
        x = F.relu(self.conv7(self.bn7(x)))
        #print(" x:",x.shape)
        x = x.view(-1, 1800)
        #print(" x:",x.shape)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = F.relu(self.fc31(x))
        mu = self.fc41(mu)
        u = F.relu(self.fc32(x))
        u = self.fc42(u).unsqueeze(-1) # Last dimension is rank \Sigma = 1.
        d = F.relu(self.fc33(x))
        d = torch.exp(self.fc43(d)) # d must be positive.
        return mu, u, d
    
    def decode(self, z):
        #print(z.shape)
        z = F.relu(self.fc5(z))
        z = F.relu(self.fc6(z))
        z = F.relu(self.fc7(z))
        z = F.relu(self.fc8(z))
        #print("z shape before view", z.shape)
        z = z.view(-1,32,1800)#16,16)
        #print("z shape after view", z.shape)
        z = F.relu(self.convt1(self.bn8(z)))
        z = F.relu(self.convt2(self.bn9(z)))
        z = F.relu(self.convt3(self.bn10(z)))
        z = F.relu(self.convt4(self.bn11(z)))
        z = F.relu(self.convt5(self.bn12(z)))
        z = F.relu(self.convt6(self.bn13(z)))
       
        z = self.convt7(self.bn14(z))
        #print("z shape end decode", z.shape)
        return z.view(-1, 14400)
    
    def forward(self, x, return_latent_rec=False):
        #print("forward", x.shape)
        mu, u, d = self.encode(x)
        #print(mu.shape,u.shape,d.shape)
        latent_dist = LowRankMultivariateNormal(mu, u, d)
        z = latent_dist.rsample()
        #print("z:",z.shape)
        x_rec = self.decode(z)
        # E_{q(z|x)} p(z)
        elbo = -0.5 * (torch.sum(torch.pow(z,2)) + self.z_dim * np.log(2*np.pi))
        # E_{q(z|x)} p(x|z)
        # pxz_term = -0.5 * X_DIM * (np.log(2*np.pi/self.model_precision))
        pxz_term = -0.5 * 14400 * (np.log(2*np.pi/self.model_precision))
        #print("x(input)",x.shape)
        #print("x_rec(vae)",x_rec.shape)
        l2s = torch.sum(torch.pow(x.view(x.shape[0],-1) - x_rec, 2), dim=1)
        pxz_term = pxz_term - 0.5 * self.model_precision * torch.sum(l2s)
        elbo = elbo + pxz_term
        # H[q(z|x)]
        elbo = elbo + torch.sum(latent_dist.entropy())
        if return_latent_rec:
            return -elbo, z.detach().cpu().numpy(), \
                x_rec.view(-1, 128, 128).detach().cpu().numpy()
        return -elbo

我的重构图如下所示:

如何改进我的 VAE?顺便说一句,使用任何其他方式来显示声音(如频谱图)而不是波形并不是一个可悲的选择。

我已经玩过潜在维度并注意到损失略有改善,但仍然不令人满意(1800 而不是2000)而重建工作根本没有改善


在尝试改进的 1D VAE 以用于音频波形重建时,以下是一些需要考虑的领域:

  1. 输入表示:

  2. 归一化: 确保的输入波形数据在 [-1, 1] 或 [0, 1] 范围内进行归一化。这可以通过以下方式实现: ```python from sklearn.preprocessing import MinMaxScaler

    scaler = MinMaxScaler(feature_range=(-1, 1)) # 将 scaler 拟合到训练数据并进行变换 normalized_waveform = scaler.fit_transform(waveform.reshape(-1, 1)).flatten() ``` * 音频预处理: 考虑使用音频领域知识进行预处理。例如,可以提高模型性能的操作包括转换为频谱图(即使不能将它们用于可视化)、应用 Mel 滤波器组或使用 MFCC(Mel 频率倒谱系数)。

  3. 架构:

  4. 内核大小和步幅: 对于 1D 卷积,请尝试不同的内核大小(例如,5、7 或 9)和步幅。较大的内核可以捕获更多上下文,但也会增加计算成本。

  5. 网络深度: 尝试在编码器和解码器中添加更多卷积层以增加模型的容量。这可能有助于捕获复杂的非线性关系。
  6. 跳过连接: 在编码器和解码器中的相应层之间引入跳过连接(如 U-Net 中)可以帮助在解码过程中保留信息。
  7. 激活函数: ReLU 激活在编码器中可能有效,但在解码器中,考虑使用其他激活函数,如 tanh 或 sigmoid,以匹配波形的范围。

  8. 损失函数:

  9. 重建损失: 虽然均方误差 (MSE) 是一种常见的重建损失,但对于音频,考虑使用感知损失函数,例如皮尔逊相关系数或结构相似性指数 (SSIM),它们更好地与人类对音频质量的感知相一致。

  10. 训练:

  11. 批量大小和学习率: 尝试不同的批量大小和学习率。较小的批量大小和学习率可以提高收敛性,但训练时间会更长。

  12. 优化器: 尝试不同的优化器,例如 Adam 或 RMSprop,以查看哪种优化器最适合的模型。
  13. 早期停止: 实施早期停止以防止过度拟合。如果验证损失停止改善,请停止训练。

  14. 其他技巧:

  15. 潜在空间正则化: 增加 KL 散度系数以鼓励更强的潜在空间正则化。

  16. 更丰富的潜在分布: 探索使用更复杂的潜在分布,例如高斯混合模型,而不是简单的各向同性高斯分布。

调试:

  • 从一个更小的数据集开始: 从一个小的音频样本子集开始,以确保的模型可以过度拟合到这个小子集。
  • 可视化重建: 随着训练的进行,定期绘制重建结果和原始音频,以观察是否存在任何有意义的变化。
  • 监控潜在空间: 可视化潜在空间以查看训练数据是否已很好地编码。

请记住,训练 VAE 可能具有挑战性,并且通常需要大量的实验来找到适合特定数据集和任务的最佳超参数和架构。

标签:python,machine-learning,deep-learning,neural-network,vae
From: 78834575

相关文章

  • 【香橙派系列教程】(七)香橙派下的Python3安装
    【七】香橙派下的Python3安装为接下来的Linux图像识别智能垃圾桶做准备。图像处理使用京东SDK只支持pyhton和Java接口,目的是引入C语言的Python调用,感受大厂做的算法bug此接口是人工智能接口,京东识别模型是通过训练后的模型,精准度取决于训练程度,人工智能范畴在常规嵌入式......
  • vnpy,一个不可思议的Python库!
    vn.py是一个开源的Python交易编程框架,旨在帮助程序员快速搭建属于自己的量化交易平台。该框架支持股票、期货、外汇等多种金融产品的交易,提供了从数据获取、策略开发到交易执行的全流程支持。如何安装vnpy首先,要使用vnpy,您需要通过Python的包管理工具pip来安装它。以下......
  • Python回溯算法
    回溯算法回溯算法是一种系统的搜索算法,用于解决诸如排列组合、子集生成、图的路径、棋盘问题等问题。其核心思想是通过递归尝试各种可能的解决方案,遇到不满足条件的解时则回退(回溯),继续尝试其他可能性,直到找到所有的解决方案或确认无解。主要步骤:选择路径:在当前步骤选择一个可......
  • [python]使用gunivorn部署fastapi服务
    前言Gunicorn是一种流行的WSGIHTTP服务器,常用于部署Django和Flask等PythonWeb框架程序。Gunicorn具有轻量级、高稳定性和高性能等特性,可以轻易提高PythonWSGIApp运行时的性能。基本原理Gunicorn采用了pre-fork模型,也就是一个工作进程和多个worker进程的工作模式。在这个模......
  • python十六进制编辑器
    源代码:importtkinterastkfromtkinterimportfiledialogimportstructimportbinasciiimportosclassHexEditor:def__init__(self,master):self.master=masterself.master.title("十六进制编辑器")self.master.configure(bg......
  • python项目学习 mediapipe手势识别 opencv可视化显示
    importcv2importmediapipeimportnumpydefget_angle(vector1,vector2):#角度计算angle=numpy.dot(vector1,vector2)/(numpy.sqrt(numpy.sum(vector1*vector1))*numpy.sqrt(numpy.sum(vector2*vector2)))#cos(angle)=向量的点乘/向量的模angle=nump......
  • 【优秀python大屏】基于python flask的广州历史天气数据应用与可视化大屏
    摘要气象数据分析在各行各业中扮演着重要的角色,尤其对于农业、航空、海洋、军事、资源环境等领域。在这些领域中,准确的气象数据可以对预测未来的自然环境变化和采取行动来减轻负面影响的决策起到至关重要的作用。本系统基于PythonFlask框架,通过对气象数据的分析和处理来提供......
  • Python-MNE全套教程(官网翻译)-入门01:概述篇
    目的以牺牲深度为代价进行入门学习,简易学习基本方法开始导入相关库:#License:BSD-3-Clause#CopyrighttheMNE-Pythoncontributors.importnumpyasnpimportmne加载数据MNE-Python数据结构式基于fif格式的,但是对于其他格式也有阅读方法,如https://mne.tools/s......
  • Python-MNE全套教程(官网翻译)-入门05:关于传感器位置
    本教程描述了如何读取和绘制传感器位置,以及MNE-Python如何处理传感器的物理位置。像往常一样,我们将从导入我们需要的模块开始:frompathlibimportPathimportmatplotlib.pyplotaspltimportnumpyasnpimportmne关于montage和layout(蒙太奇和传感器布局)montage......
  • Codeforces Round 963 (Div. 2) A - C 详细题解(思路加代码,C++,Python) -- 来自灰名
    比赛链接:Dashboard-CodeforcesRound963(Div.2)-Codeforces之后有实力了再试试后面的题目,现在要做那些题,起码要理解一个多小时题目A:链接:Problem-A-Codeforces题目大意理解:        极少数不考翻译能读懂的cf题目(bushi)每个测试用例第一行一个n,......