https://www.dandelioncloud.cn/article/details/1601780566695559170
目录结构
本教程实验环境为Google Colab,文件目录结构如下
ALL
└── tacotron2
├── audio_processing.py
├── checkpoint_269000
├── data_utils.py
├── demo.wav
├── distributed.py
├── Dockerfile
├── filelists
│ ├── ljs_audio_text_test_filelist.txt
│ ├── ljs_audio_text_train_filelist.txt
│ └── ljs_audio_text_val_filelist.txt
├── hparams.py
├── inference.ipynb
├── layers.py
├── LICENSE
├── logger.py
├── loss_function.py
├── loss_scaler.py
├── model.py
├── multiproc.py
├── plotting_utils.py
├── __pycache__
│ ├── audio_processing.cpython-36.pyc
│ ├── data_utils.cpython-36.pyc
│ ├── distributed.cpython-36.pyc
│ ├── hparams.cpython-36.pyc
│ ├── layers.cpython-36.pyc
│ ├── logger.cpython-36.pyc
│ ├── loss_function.cpython-36.pyc
│ ├── model.cpython-36.pyc
│ ├── plotting_utils.cpython-36.pyc
│ ├── stft.cpython-36.pyc
│ ├── train.cpython-36.pyc
│ └── utils.cpython-36.pyc
├── README.md
├── requirements.txt
├── stft.py
├── tensorboard.png
├── text
│ ├── cleaners.py
│ ├── cmudict.py
│ ├── __init__.py
│ ├── LICENSE
│ ├── numbers.py
│ ├── __pycache__
│ │ ├── cleaners.cpython-36.pyc
│ │ ├── cmudict.cpython-36.pyc
│ │ ├── __init__.cpython-36.pyc
│ │ ├── numbers.cpython-36.pyc
│ │ └── symbols.cpython-36.pyc
│ └── symbols.py
├── train.py
├── utils.py
└── waveglow
├── config.json
├── convert_model.py
├── denoiser.py
├── distributed.py
├── glow_old.py
├── glow.py
├── inference.py
├── LICENSE
├── mel2samp.py
├── __pycache__
│ ├── denoiser.cpython-36.pyc
│ └── glow.cpython-36.pyc
├── README.md
├── requirements.txt
├── tacotron2
├── train.py
├── waveglow_256channels_universal_v5.pt
└── waveglow_logo.png
文件准备
首先请读者创建一个名为ALL
的空文件夹,通过git clone https://github.com/NVIDIA/tacotron2.git
命令将tacotron2完整的代码文件下载下来。此时ALL
文件夹里面会多出一个名为tacotron2
的文件夹,在这个文件夹里有一个inference.ipynb
文件,就是等会要用到的推理部分的代码
接着将预训练好的WaveGlow模型保存到waveglow
文件夹中(该模型名为waveglow_256channels_universal_v5.pt
)
最后还需要一个最重要的文件,就是tacotron2训练时保存的模型文件,一般在训练过程中,它会自动命名为checkpoint_xxxx
,将其放到tacotron2
文件夹下。如果你自己没有训练tacotron2,官方也提供了一个训练好的模型文件
修改Inference代码
再次强调,我的实验环境是Colab,以下内容均为,文字解释在上,对应代码在下
首先需要确保tensorflow版本为1.x,否则会报错
%tensorflow_version 1.x
import tensorflow as tf
tf.__version__
然后进入ALL/tacotron2
目录
%cd ALL/tacotron2
执行代码前需要确保已经安装了unidecode
库
!pip install unidecode
导入库,定义函数
import matplotlib
%matplotlib inline
import matplotlib.pylab as plt
import IPython.display as ipd
import sys
sys.path.append('waveglow/')
import numpy as np
import torch
from hparams import create_hparams
from model import Tacotron2
from layers import TacotronSTFT, STFT
from audio_processing import griffin_lim
from train import load_model
from text import text_to_sequence
from denoiser import Denoiser
def plot_data(data, figsize=(16, 4)):
fig, axes = plt.subplots(1, len(data), figsize=figsize)
for i in range(len(data)):
axes[i].imshow(data[i], aspect='auto', origin='bottom',
interpolation='none')
hparams = create_hparams()
hparams.sampling_rate = 21050 # 该参数会影响生成语音的语速,越大则语速越快
checkpoint_path = "checkpoint_269000"
model = load_model(hparams)
model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
_ = model.cuda().eval().half()
接着进入waveglow
目录加载waveglow模型
%cd ALL/tacotron2/waveglow
waveglow_path = 'waveglow_256channels_universal_v5.pt'
waveglow = torch.load(waveglow_path)['model']
waveglow.cuda().eval().half()
for k in waveglow.convinv:
k.float()
denoiser = Denoiser(waveglow)
输入文本
text = "WaveGlow is really awesome!"
sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
sequence = torch.autograd.Variable(
torch.from_numpy(sequence)).cuda().long()
生成梅尔谱输出,以及画出attention图
mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
plot_data((mel_outputs.float().data.cpu().numpy()[0],
mel_outputs_postnet.float().data.cpu().numpy()[0],
alignments.float().data.cpu().numpy()[0].T))
使用waveglow将梅尔谱合成为语音
with torch.no_grad():
audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)
ipd.Audio(audio[0].data.cpu().numpy(), rate=hparams.sampling_rate)
(可选)移除waveglow的bias
audio_denoised = denoiser(audio, strength=0.01)[:, 0]
ipd.Audio(audio_denoised.cpu().numpy(), rate=hparams.sampling_rate)