def train_model(model,
dataset,
cfg,
validate=False,
test=dict(test_best=False, test_last=False),
timestamp=None,
meta=None):
"""Train model entry function.
Args:
model (nn.Module): The model to be trained.
dataset (:obj:`Dataset`): Train dataset.
cfg (dict): The config dict for training.
validate (bool): Whether to do evaluation. Default: False.
test (dict): The testing option, with two keys: test_last & test_best.
The value is True or False, indicating whether to test the
corresponding checkpoint.
Default: dict(test_best=False, test_last=False).
timestamp (str | None): Local time for runner. Default: None.
meta (dict | None): Meta dict to record some important information.
Default: None
"""
logger = get_root_logger(log_level=cfg.get('log_level', 'INFO'))
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
dataloader_setting = dict(
videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
persistent_workers=cfg.data.get('persistent_workers', False),
seed=cfg.seed)
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('train_dataloader', {}))
data_loaders = [
build_dataloader(ds, **dataloader_setting) for ds in dataset
]
data_loaders的数据类型是列表,[<torch.utils.data.dataloader.DataLoader object at 0x7f4b3bfc7b90>]
dataset 变量是列表,[<pyskl.datasets.dataset_wrappers.RepeatDataset object at 0x7fe7b3668550>]
既然 dataset
变量是一个包含 RepeatDataset
对象的列表,那么这段代码的实现就更加清晰了。
RepeatDataset
是一个数据集包装器,可以对原始数据集进行重复采样,以增加训练样本的数量。
在这种情况下,build_dataloader()
函数会为 dataset
列表中的每个 RepeatDataset
对象创建一个对应的 DataLoader
对象,并将它们存储在 data_loaders
列表中。
这种设计方式有以下好处:
- 可以同时处理多个数据集,每个数据集都有独立的
DataLoader
。 - 通过使用
RepeatDataset
,可以对数据集进行重复采样,增加训练样本的数量。 - 将
DataLoader
对象存储在data_loaders
列表中,方便后续的训练和评估操作。
总的来说,这段代码的实现逻辑是合理的,能够灵活地处理多个数据集,并对数据进行重复采样。
怎么发现data_loaders不应该是DataLoader类进行实例化的吗?比如train_dataloader = DataLoader(train_data, batch-size=16)这样?
事实上,在这段代码中,data_loaders
是正确地使用 DataLoader
类进行实例化的。
让我们回顾一下您提到的示例代码:
train_dataloader = DataLoader(train_data, batch_size=16)
这段代码是正确的,它创建了一个 DataLoader
对象,用于加载训练数据。
在前面的代码中,data_loaders
是一个列表,其中包含了多个 DataLoader
对象。这是因为原始的 dataset
变量可能包含多个数据集对象,因此需要为每个数据集创建一个 DataLoader
。
所以,data_loaders
是正确的,它是一个包含多个 DataLoader
对象的列表,而不是直接使用 DataLoader
类进行实例化。
相比之下,您提到的 train_dataloader = DataLoader(train_data, batch_size=16)
是一个单独的 DataLoader
对象,用于加载训练数据。这种方式也是正确的,只是与原始代码中的 data_loaders
实现有所不同。
总结来说,原始代码中的 data_loaders
是正确的,它是一个 DataLoader
对象的列表,用于处理可能包含多个数据集的情况。您提到的 train_dataloader = DataLoader(train_data, batch_size=16)
也是正确的,只是处理了单个训练数据集的情况。两种实现方式都是合理的,只是根据具体的需求而采用不同的方式。
pyskl/pyskl/apis/train.py at main · kennymckormick/pyskl · GitHub
标签:dataloader,训练,DataLoader,dataset,工作,train,准备,loaders,data From: https://blog.csdn.net/qq_34425255/article/details/139887939