多进程（多 worker）加载
- 在 `DataLoader` 中指定 `num_workers > 0`，由多个 worker 子进程并行加载数据集；一般设置为不超过 CPU 核数。
- 设置 `pin_memory=True`，将取出的 batch 放入锁页内存（pinned memory），加快从 CPU 到 GPU 的数据拷贝。
- 示例如下:
loader = DataLoader(
    dataset=dataset,
    collate_fn=dataset.collate_fn,
    # Loader-level batch is batch_size * group_size; presumably collate_fn
    # re-splits it into `group_size` groups — confirm against collate_fn.
    batch_size=batch_size * group_size,
    shuffle=True,
    # Worker subprocesses prepare samples in parallel with training.
    num_workers=2,
    # Batches are returned in page-locked memory to speed host->GPU copies.
    pin_memory=True,
)
预加载数据集
实践中，上面这些设置带来的提升有限，效果最明显的还是预加载。
- 原理：在构造 Dataset 时就把整个数据集预先 load 到内存中，之后读取直接访问内存，不再有磁盘 I/O 开销；前提是内存足以容纳整个数据集。
- 构建自己的 Dataset 类
- 示例如下:
class My_Dataset(Dataset):
    """Dataset that preloads every mel/pitch/energy/duration array into memory.

    All ``.npy`` feature files are read once in ``__init__`` and kept in
    per-feature lists, so ``__getitem__`` only indexes into those lists
    instead of reading from disk.

    NOTE(review): memory use grows with the whole dataset size; confirm the
    preprocessed features fit in RAM before using this class.

    Args:
        filename: Metadata file handed to ``self.process_meta`` (defined
            elsewhere; assumed to return four parallel sequences of
            basenames, speakers, texts and raw texts — TODO confirm).
        preprocess_config: Mapping with ``"dataset"``,
            ``["path"]["preprocessed_path"]`` and
            ``["preprocessing"]["text"]["text_cleaners"]`` entries.
        train_config: Mapping with an ``["optimizer"]["batch_size"]`` entry.
        sort: Stored on the instance; not used in this class body
            (presumably consumed by a collate function — verify).
        drop_last: Stored on the instance; same caveat as ``sort``.
    """

    def __init__(
        self, filename, preprocess_config, train_config, sort=False, drop_last=False
    ):
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]
        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
            filename
        )
        # Explicit encoding so speaker names (possibly non-ASCII) decode the
        # same way regardless of the platform's default locale.
        with open(
            os.path.join(self.preprocessed_path, "speakers.json"), encoding="utf-8"
        ) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last

        # Preload every feature array once so __getitem__ never touches disk.
        self.mel_list = []
        self.pitch_list = []
        self.energy_list = []
        self.duration_list = []
        # Index-based loop keeps the "len(self.text) defines the dataset
        # size" contract shared with __len__ explicit.
        for idx in range(len(self.text)):
            basename = self.basename[idx]
            speaker = self.speaker[idx]
            self.mel_list.append(self._load_feature("mel", speaker, basename))
            self.pitch_list.append(self._load_feature("pitch", speaker, basename))
            self.energy_list.append(self._load_feature("energy", speaker, basename))
            self.duration_list.append(
                self._load_feature("duration", speaker, basename)
            )

    def _load_feature(self, feature, speaker, basename):
        """Load ``<preprocessed_path>/<feature>/<speaker>-<feature>-<basename>.npy``."""
        path = os.path.join(
            self.preprocessed_path,
            feature,
            "{}-{}-{}.npy".format(speaker, feature, basename),
        )
        return np.load(path)

    def __len__(self):
        """Return the number of utterances (one per entry of ``self.text``)."""
        return len(self.text)

    def __getitem__(self, idx):
        """Return the sample dict for utterance ``idx`` from preloaded arrays.

        Keys: ``id``, ``speaker`` (mapped through ``speakers.json``),
        ``text`` (``text_to_sequence`` output as a NumPy array), ``raw_text``,
        ``mel``, ``pitch``, ``energy``, ``duration``.
        """
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        # Text is still converted per access; only the acoustic features are cached.
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
        mel = self.mel_list[idx]
        pitch = self.pitch_list[idx]
        energy = self.energy_list[idx]
        duration = self.duration_list[idx]
        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }
        return sample
- 在 `__init__` 函数里，将所有数据 load 进内存；
- 在 `__getitem__(self, idx)` 函数里，直接通过列表下标 `idx` 访问每一条数据。