PyTorch training is slow and stuck on data loading: an effective fix for the I/O bottleneck

Posted: 2022-10-29 19:31:46
Tags: data loading, idx, solution, self, pitch, pytorch, speaker, text, path


Loading with multiple workers

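  • Set num_workers > 0 in the DataLoader to load the dataset in parallel. Each worker is a separate process rather than a thread, and the number of CPU cores is a reasonable upper bound.
  • Set pin_memory=True so that batches are placed in page-locked (pinned) host memory, which speeds up copies from host memory to the GPU.
  • Example: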
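from torch.utils.data import DataLoader

# dataset, batch_size and group_size are defined elsewhere in the training script
loader = DataLoader(
    dataset,
    batch_size=batch_size * group_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
    num_workers=2,    # worker processes that load batches in parallel
    pin_memory=True,  # page-locked host memory for faster GPU copies
)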
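
Before tuning num_workers, it can help to measure how long the training loop actually waits for each batch. The following is a minimal sketch, not part of the original post: the helper name average_batch_wait and its max_batches parameter are hypothetical, and it only assumes that the loader is an ordinary Python iterable (which a DataLoader is).

```python
import time


def average_batch_wait(loader, max_batches=50):
    """Return the average number of seconds spent waiting for one batch.

    `loader` can be any iterable; at most `max_batches` batches are timed.
    """
    waits = []
    iterator = iter(loader)
    for _ in range(max_batches):
        start = time.perf_counter()
        try:
            next(iterator)  # blocks until the next batch is ready
        except StopIteration:
            break
        waits.append(time.perf_counter() - start)
    # An empty loader yields no measurements; report 0.0 in that case.
    return sum(waits) / len(waits) if waits else 0.0
```

Running this once per candidate num_workers value (for example 0, 2, 4) and comparing the averages gives a rough indication of whether more workers reduce the wait. If the average barely changes, the bottleneck is probably the disk rather than the CPU.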

Preloading the dataset

In my experience nothing else helps much; preloading is what really makes the difference.

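  • How it works: load the entire dataset into memory up front, so every read is served from RAM and training no longer waits on disk I/O. This only works when the whole dataset fits in memory.
  • Write your own Dataset class.
  • Example: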
import json
import os

import numpy as np
from torch.utils.data import Dataset

# Note: text_to_sequence and self.process_meta are not defined in this excerpt;
# they come from the rest of the author's project.


class My_Dataset(Dataset):
    def __init__(
        self, filename, preprocess_config, train_config, sort=False, drop_last=False
    ):
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]

        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
            filename
        )
        with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last

        # Added: preload every feature array into memory once, up front.
        self.mel_list = []
        self.pitch_list = []
        self.energy_list = []
        self.duration_list = []
        for idx in range(len(self.text)):
            basename = self.basename[idx]
            speaker = self.speaker[idx]
            mel_path = os.path.join(
                self.preprocessed_path,
                "mel",
                "{}-mel-{}.npy".format(speaker, basename),
            )
            mel = np.load(mel_path)
            pitch_path = os.path.join(
                self.preprocessed_path,
                "pitch",
                "{}-pitch-{}.npy".format(speaker, basename),
            )
            pitch = np.load(pitch_path)
            energy_path = os.path.join(
                self.preprocessed_path,
                "energy",
                "{}-energy-{}.npy".format(speaker, basename),
            )
            energy = np.load(energy_path)
            duration_path = os.path.join(
                self.preprocessed_path,
                "duration",
                "{}-duration-{}.npy".format(speaker, basename),
            )
            duration = np.load(duration_path)
            self.mel_list.append(mel)
            self.pitch_list.append(pitch)
            self.energy_list.append(energy)
            self.duration_list.append(duration)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))

        # No disk access here: the arrays were loaded in __init__.
        mel = self.mel_list[idx]
        pitch = self.pitch_list[idx]
        energy = self.energy_list[idx]
        duration = self.duration_list[idx]

        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }

        return sample
  • The __init__ method loads all of the data into memory.
  • The __getitem__(self, idx) method then simply indexes into those lists to fetch each sample.
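
Since preloading only pays off when everything fits in RAM, it may be worth estimating the dataset's footprint first. The sketch below is an illustration rather than part of the original post: the helper name total_size_bytes and its suffix parameter are hypothetical. It sums the on-disk size of the feature files, which for uncompressed .npy files is close to their in-memory size (the file is a small header followed by the raw array data).

```python
import os


def total_size_bytes(root, suffix=".npy"):
    """Sum the sizes of all files under `root` whose names end with `suffix`."""
    total = 0
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.endswith(suffix):
                total += os.path.getsize(os.path.join(dirpath, name))
    return total
```

For example, calling total_size_bytes on the preprocessed feature directory and dividing by 1024 ** 3 gives the approximate size in GiB, which can be compared with the free memory on the training machine before committing to preloading.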


From: https://blog.51cto.com/u_15365984/5806481
