First, let's look at the official API of Dataset:
CLASS torch.utils.data.Dataset(*args, **kwds)
An abstract class representing a Dataset.
- All datasets that represent a map from keys to data samples should subclass it.
- All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key.
- Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.
DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
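The two required overrides described above can be sketched with a minimal map-style dataset. The class name SquareDataset and its contents are hypothetical, chosen only to illustrate __getitem__() and __len__():

```python
import torch
from torch.utils.data import Dataset

class SquareDataset(Dataset):
    """Hypothetical map-style dataset: sample at key i is the tensor i**2."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Fetch a single data sample for the given integer key.
        return torch.tensor(index, dtype=torch.float32) ** 2

    def __len__(self):
        # Expected by Sampler implementations and DataLoader defaults.
        return self.n

ds = SquareDataset(5)
```

Because the keys here are the integers 0..n-1, the default index sampler of DataLoader works with this dataset out of the box.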
The official API of TensorDataset:
CLASS torch.utils.data.TensorDataset(*tensors)
Dataset wrapping tensors. Each sample will be retrieved by indexing tensors along the first dimension.
Parameters
*tensors (Tensor) – tensors that have the same size of the first dimension.
As the name suggests, torch.utils.data.TensorDataset builds a dataset from a sequence of tensors. The tensors may have different shapes, but their first dimensions must be the same size; this guarantees that a DataLoader can correctly return a batch of data.
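A short sketch of this batching behavior, using hypothetical toy data: features and labels share the first dimension (6 samples) while their remaining shapes differ:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data (hypothetical): 6 samples of 3 features each, plus 6 scalar labels.
features = torch.randn(6, 3)
labels = torch.arange(6)

ds = TensorDataset(features, labels)
loader = DataLoader(ds, batch_size=2)

for x, y in loader:
    # Each batch is a tuple of tensors, all sliced along the first dimension.
    assert x.shape == (2, 3) and y.shape == (2,)
```

Indexing the dataset directly, e.g. ds[0], returns a tuple (features[0], labels[0]) rather than a single tensor.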
The source code of TensorDataset is as follows:
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
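The assert in __init__ is what enforces the first-dimension constraint. A quick sketch showing it in action, with deliberately mismatched sizes (4 vs. 3, chosen arbitrarily for illustration):

```python
import torch
from torch.utils.data import TensorDataset

raised = False
try:
    # First dimensions differ (4 vs. 3), so the constructor's
    # "Size mismatch between tensors" assertion fires.
    TensorDataset(torch.zeros(4, 2), torch.zeros(3))
except AssertionError:
    raised = True
```

With matching first dimensions, construction succeeds and __len__ reports the size of that shared dimension.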
From: https://www.cnblogs.com/zjuhaohaoxuexi/p/16758239.html