我正在使用 PyTorch 进行图像分类。我的数据集是目录格式。我已经设置了数据管道和模型。尽管如此,我在训练验证分割中遇到了一个问题,其中某些类在训练或验证数据集中的样本为零。这是我的代码和设置的相关部分:
class CustomDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.classes = os.listdir(root_dir)
self.image_paths = []
self.labels = []
for label, class_name in enumerate(self.classes):
class_dir = os.path.join(root_dir, class_name)
for img_path in glob.glob(os.path.join(class_dir, '*.png')) + \
glob.glob(os.path.join(class_dir, '*.jpg')) + \
glob.glob(os.path.join(class_dir, '*.jpeg')):
self.image_paths.append(img_path)
self.labels.append(label)
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx]
if self.transform:
image = self.transform(image)
else:
image = transforms.ToTensor()(image) # Convert PIL image to tensor if no transform is provided
return image, label
class AugmentedDataset(Dataset):
def __init__(self, base_dataset, transforms_list):
self.base_dataset = base_dataset
self.transforms_list = transforms_list if isinstance(transforms_list, list) else [transforms_list]
def __len__(self):
return len(self.base_dataset) * len(self.transforms_list)
def __getitem__(self, idx):
base_idx = idx // len(self.transforms_list)
transform_idx = idx % len(self.transforms_list)
image, label = self.base_dataset[base_idx]
transform = self.transforms_list[transform_idx]
if transform:
image = transform(image)
return image, label
训练测试分割
base_dataset = CustomDataset(train_dir, v2.Compose(basic_transformations))
train_size = int(0.8 * len(base_dataset))
val_size = len(base_dataset) - train_size
train_base_dataset, val_dataset = random_split(base_dataset, [train_size, val_size])
train_dataset = AugmentedDataset(train_base_dataset, augmentations)
val_dataset = AugmentedDataset(train_base_dataset, v2.Compose(final_transformation))
trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
执行分割后,我注意到某些类在训练或验证数据集中的样本为零。
-
为什么 pytorch 会发生这种情况
random_split
方法? - 确保训练和验证集中所有类的平衡分割的最佳实践是什么?
- 我可以使用任何特定的技术或库来在分割期间保持类平衡?| ||这是我的笔记本:
EfficientNet with Augmentation EfficientNet with Augmentation
在训练/验证数据集中遇到某些类缺少样本的问题,这是由于 PyTorch 的
random_split
函数的工作方式造成的,该函数不考虑数据集中的类分布。
以下是解决此问题并确保在 PyTorch 中进行均衡分割的方法:
- 按类拆分:
与其将整个数据集随机拆分为训练集和验证集,不如考虑先按类拆分数据。这样,就可以控制每个类的训练集和验证集的比率。
```python from torch.utils.data import Subset import numpy as np
def stratified_split(dataset, val_split=0.2): train_indices, val_indices = [], [] for class_label in np.unique(dataset.labels): class_indices = np.where(np.array(dataset.labels) == class_label)[0] np.random.shuffle(class_indices) split_idx = int(len(class_indices) * val_split) val_indices.extend(class_indices[:split_idx]) train_indices.extend(class_indices[split_idx:]) return Subset(dataset, train_indices), Subset(dataset, val_indices)
train_base_dataset, val_dataset = stratified_split(base_dataset, val_split=0.2) ```
此
stratified_split
函数将确保训练集和验证集中每个类都有代表性。
-
使用
StratifiedShuffleSplit
进行更强大的拆分:
对于更复杂的情况或更大的数据集,可以使用 scikit-learn 的
StratifiedShuffleSplit
保留类分布。
```python from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42) train_indices, val_indices = next(sss.split(np.zeros_like(base_dataset.labels), base_dataset.labels))
train_base_dataset = Subset(base_dataset, train_indices) val_dataset = Subset(base_dataset, val_indices) ```
这会根据标签数组创建分层分割。
- 加权采样:
解决类不平衡问题的另一种方法是在训练期间使用加权采样。这涉及根据其在数据集中出现的频率为不同类分配不同的权重。PyTorch 的
WeightedRandomSampler
可以帮助做到这一点:
```python from torch.utils.data import WeightedRandomSampler
class_counts = [len(np.where(np.array(train_base_dataset.labels) == i)[0]) for i in np.unique(train_base_dataset.labels)] weights = 1. / torch.tensor(class_counts, dtype=torch.float) sample_weights = weights[train_base_dataset.labels] sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler) ```
这将在训练期间对欠表示的类进行过度采样,从而有助于缓解类不平衡问题。
选择哪种方法?
- 按类拆分 简单直观,适合大多数情况。
-
StratifiedShuffleSplit
提供了更大的灵活性,可以在保持类分布的同时创建多个拆分(例如,用于交叉验证)。 - 加权采样 直接解决了训练期间的类不平衡问题,并且在处理高度不平衡的数据集时可能非常有效。
选择最适合特定需求和数据集特征的方法。请记住,在机器学习中保持训练集和验证集之间的类分布平衡对于避免模型偏差和确保对所有类进行可靠评估至关重要。
标签:python,machine-learning,deep-learning,pytorch,training-data From: 78808028