我想更改 torch DataLoader,并在其中使用消费者/生产者模式。
我有一个队列,一个线程将文件放入其中,这些项目由框架使用
__getitem__
使用。
这是我的代码:
import glob
import time
from torch.utils.data import DataLoader, Dataset
import threading
import queue
class CustomDataset(Dataset):
size = 25
def __init__(self):
self.voices_path = "/home/hatef/mmm/data/voice/"
file_list = glob.glob(self.voices_path + "*")
self.data = []
for f in file_list:
self.data.append(f)
self.q = queue.Queue(CustomDataset.size)
def fill_queue(self):
dataz = self.data
it = 0
while it < len(dataz):
path = dataz[it]
if not self.q.full():
with open(path,"rb") as f:
self.q.put(f.read())
it += 1
else:
print("queue size: " + str(self.q.qsize()))
time.sleep(2)
def __len__(self):
return len(self.data)
def __getitem__(self, _):
print("before get item")
item = self.q.get()
self.q.task_done()
return [item, 1]
params = {'batch_size': 16,
'shuffle': True,
'num_workers': 4}
epoch = 100
dataset = CustomDataset()
training_generator = DataLoader(dataset, **params)
producer = threading.Thread(target=dataset.fill_queue, daemon=True)
producer.start()
start_time = time.time()
for i in range(epoch):
for f in training_generator:
print(f)
print("............")
print(f"**********************************run epoch***************************************: {i}")
# time.sleep(10)
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time: {execution_time} seconds")
__getitem__
我的代码永远不会取得进展。问题是什么?
代码中的问题在于,
__getitem__
方法中的
self.q.get()
调用会在队列为空时阻塞。这意味着,当工作线程尝试从队列中获取项目时,如果队列为空,它们就会被阻塞,直到有项目可用为止。由于没有在主线程中将任何项目放入队列,因此工作线程会无限期地阻塞,从而导致死锁。
以下是如何使用生产者/消费者模式修复代码的方法:
-
在单独的线程中启动生产者 :已经使用
threading.Thread
在单独的线程中启动了生产者,这是正确的。 -
使用
q.get(block=True, timeout=None)
:修改__getitem__
方法,以便它在从队列中获取项目时使用block=True
和timeout
参数。这将确保工作线程在队列为空时不会无限期地阻塞。 -
在
fill_queue
方法中处理异常 :当队列已满时,q.put(item)
调用会引发queue.Full
异常。需要在fill_queue
方法中处理此异常,例如,通过休眠一小段时间,然后重试。 -
在主线程中加入生产者线程 :在主线程中使用
producer.join()
方法等待生产者线程完成。
以下是更新后的代码:
import glob
import time
from torch.utils.data import DataLoader, Dataset
import threading
import queue
class CustomDataset(Dataset):
size = 25
def __init__(self):
self.voices_path = "/home/hatef/mmm/data/voice/"
file_list = glob.glob(self.voices_path + "*")
self.data = file_list # 不需要将文件内容存储在内存中
self.q = queue.Queue(CustomDataset.size)
def fill_queue(self):
for path in self.data:
try:
with open(path, "rb") as f:
self.q.put(f.read(), block=True, timeout=5) # 设置超时时间
except queue.Full:
print("队列已满,稍后再试...")
time.sleep(1)
print("生产者已完成")
def __len__(self):
return len(self.data)
def __getitem__(self, _):
item = self.q.get(block=True, timeout=5) # 设置超时时间
self.q.task_done()
return [item, 1]
params = {'batch_size': 16,
'shuffle': True,
'num_workers': 4}
epoch = 100
dataset = CustomDataset()
training_generator = DataLoader(dataset, **params)
producer = threading.Thread(target=dataset.fill_queue, daemon=True)
producer.start()
start_time = time.time()
for i in range(epoch):
for f in training_generator:
print(f)
print("............")
print(f"**********************************运行时期***************************************: {i}")
# time.sleep(10)
producer.join() # 等待生产者线程完成
end_time = time.time()
execution_time = end_time - start_time
print(f"执行时间: {execution_time} 秒")
这些更改将确保:
- 生产者线程在将所有数据放入队列后完成。
- 工作线程在从队列中获取项目时不会无限期地阻塞。
- 主线程在生产者线程完成后结束。