首页 > 编程问答 >PyTorch 数据集中某些类的训练验证拆分结果为零样本

PyTorch 数据集中某些类的训练验证拆分结果为零样本

时间:2024-07-30 08:12:24浏览次数:11  
标签:python machine-learning deep-learning pytorch training-data

我正在使用 PyTorch 进行图像分类。我的数据集是目录格式。我已经设置了数据管道和模型。尽管如此,我在训练验证分割中遇到了一个问题,其中某些类在训练或验证数据集中的样本为零。这是我的代码和设置的相关部分:

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.image_paths = []
        self.labels = []
        
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_path in glob.glob(os.path.join(class_dir, '*.png')) + \
                           glob.glob(os.path.join(class_dir, '*.jpg')) + \
                           glob.glob(os.path.join(class_dir, '*.jpeg')):
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)  # Convert PIL image to tensor if no transform is provided

        return image, label
class AugmentedDataset(Dataset):
    def __init__(self, base_dataset, transforms_list):
        self.base_dataset = base_dataset
        self.transforms_list = transforms_list if isinstance(transforms_list, list) else [transforms_list]

    def __len__(self):
        return len(self.base_dataset) * len(self.transforms_list)

    def __getitem__(self, idx):
        base_idx = idx // len(self.transforms_list)
        transform_idx = idx % len(self.transforms_list)

        image, label = self.base_dataset[base_idx]
        transform = self.transforms_list[transform_idx]

        if transform:
            image = transform(image)

        return image, label

训练测试分割

base_dataset = CustomDataset(train_dir, v2.Compose(basic_transformations))

train_size = int(0.8 * len(base_dataset))
val_size = len(base_dataset) - train_size

train_base_dataset, val_dataset = random_split(base_dataset, [train_size, val_size])
train_dataset = AugmentedDataset(train_base_dataset, augmentations)
val_dataset = AugmentedDataset(train_base_dataset, v2.Compose(final_transformation))

trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

执行分割后,我注意到某些类在训练或验证数据集中的样本为零。

  1. 为什么 pytorch 会发生这种情况 random_split 方法?
  2. 确保训练和验证集中所有类的平衡分割的最佳实践是什么?
  3. 我可以使用任何特定的技术或库来在分割期间保持类平衡?| ||这是我的笔记本:

EfficientNet with Augmentation EfficientNet with Augmentation


在训练/验证数据集中遇到某些类缺少样本的问题,这是由于 PyTorch 的 random_split 函数的工作方式造成的,该函数不考虑数据集中的类分布。

以下是解决此问题并确保在 PyTorch 中进行均衡分割的方法:

  1. 按类拆分:

与其将整个数据集随机拆分为训练集和验证集,不如考虑先按类拆分数据。这样,就可以控制每个类的训练集和验证集的比率。

```python from torch.utils.data import Subset import numpy as np

def stratified_split(dataset, val_split=0.2): train_indices, val_indices = [], [] for class_label in np.unique(dataset.labels): class_indices = np.where(np.array(dataset.labels) == class_label)[0] np.random.shuffle(class_indices) split_idx = int(len(class_indices) * val_split) val_indices.extend(class_indices[:split_idx]) train_indices.extend(class_indices[split_idx:]) return Subset(dataset, train_indices), Subset(dataset, val_indices)

train_base_dataset, val_dataset = stratified_split(base_dataset, val_split=0.2) ```

stratified_split 函数将确保训练集和验证集中每个类都有代表性。

  1. 使用 StratifiedShuffleSplit 进行更强大的拆分:

对于更复杂的情况或更大的数据集,可以使用 scikit-learn 的 StratifiedShuffleSplit 保留类分布。

```python from sklearn.model_selection import StratifiedShuffleSplit

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42) train_indices, val_indices = next(sss.split(np.zeros_like(base_dataset.labels), base_dataset.labels))

train_base_dataset = Subset(base_dataset, train_indices) val_dataset = Subset(base_dataset, val_indices) ```

这会根据标签数组创建分层分割。

  1. 加权采样:

解决类不平衡问题的另一种方法是在训练期间使用加权采样。这涉及根据其在数据集中出现的频率为不同类分配不同的权重。PyTorch 的 WeightedRandomSampler 可以帮助做到这一点:

```python from torch.utils.data import WeightedRandomSampler

class_counts = [len(np.where(np.array(train_base_dataset.labels) == i)[0]) for i in np.unique(train_base_dataset.labels)] weights = 1. / torch.tensor(class_counts, dtype=torch.float) sample_weights = weights[train_base_dataset.labels] sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler) ```

这将在训练期间对欠表示的类进行过度采样,从而有助于缓解类不平衡问题。

选择哪种方法?

  • 按类拆分 简单直观,适合大多数情况。
  • StratifiedShuffleSplit 提供了更大的灵活性,可以在保持类分布的同时创建多个拆分(例如,用于交叉验证)。
  • 加权采样 直接解决了训练期间的类不平衡问题,并且在处理高度不平衡的数据集时可能非常有效。

选择最适合特定需求和数据集特征的方法。请记住,在机器学习中保持训练集和验证集之间的类分布平衡对于避免模型偏差和确保对所有类进行可靠评估至关重要。

标签:python,machine-learning,deep-learning,pytorch,training-data
From: 78808028

相关文章

  • Python多重处理,如何避免创建具有百万个对象的元组
    python多处理新手。我有一项任务,涉及访问网络服务数百万次并将响应保存在文件中(每个请求都有单独的文件)。我已经得到了高级工作代码,但对一些事情没有感到困惑。以下两种语法有什么区别?pool=Pool(processes=4)pool.starmap(task,listOfInputParametersTu......
  • Python OpenCV - 显示坏像素检查测试
    我想找到显示器中存在的每个坏像素。坏像素可能是颜色不正确的像素,或者像素只是黑色。显示屏的尺寸为160x320像素。所以如果显示效果好的话,必须有160*320=51200像素。如果显示器没有51200像素,那就是坏的。另外,我想知道每个坏像素的位置。一旦拍摄的图像太大,我将共享一个......
  • 在python日志输出的每一行前面添加变量缩进
    我正在将日志记录构建到一个Python应用程序中,我希望它是人类可读的。目前,调试日志记录了调用的每个函数以及参数和返回值。这意味着,实际上,嵌套函数调用的调试日志可能如下所示:2024-07-2916:52:26,641:DEBUG:MainController.initialize_componentscalledwithargs<control......
  • 使用 DQN 实现 pong,使用 python 中的特征向量而不是像素。我的 DQNA 实现代码正确吗,因
    我正在致力于使用OpenAI的Gym为Pong游戏实现强化学习(RL)环境。目标是训练人工智能代理通过控制球拍来打乒乓球。代理收到太多负面奖励,即使它看起来移动正确。具体来说,奖励函数会惩罚远离球的智能体,但这种情况发生得太频繁,即使球朝球拍移动时似乎也会发生。观察......
  • Python CDLL 无法加载两次
    我正在尝试用python创建一个密码管理器,但遇到了一个问题,一旦加载了一种类型的dll,我就无法加载不同的dll,在这个示例中,我加载了一个dll,并尝试解密加密的密码数据,它工作正常,直到我加载另一个不同的nss3.dll文件,此时它给我一个错误:“过程入口点HeapAlloc无法位于动态链......
  • 你能将 HTTPS 功能添加到 python Flask Web 服务器吗?
    我正在尝试构建一个Web界面来模拟网络设备上的静态接口,该网络设备使用摘要式身份验证和HTTPS。我想出了如何将摘要式身份验证集成到Web服务器中,但我似乎无法找到如何使用FLASK获取https,如果您可以向我展示如何实现,请评论我需要使用下面的代码做什么来实现这一点。from......
  • Python:比较 csv 文件并打印相似之处
    我需要比较两个csv文件并打印出它们的相似之处。第一个文件有名称和浓度,第二个文件就像只有名称的“最佳”列表,我需要绘制相似性图表。例如,这就是我的列表的样子:file1-old_file.csvname_id,conc_test1,conc_test2name1,####,####name2,###......
  • Python 类交叉引用
    我用Python创建了一个数独游戏。我有一个:单元格类-“保存”数字可能性单元格组-保存单元格类实例我使用这些组在数独中运行行、列和正方形功能。每个单元格包含所有组,他属于classCell:def__init__(groups):self.groups=groupscla......
  • 如何修复我的 Python Azure Function DevOps Pipeline 上的“找到 1 个函数(自定义)加载
    我正在尝试使用AzureDevOps构建管道将PythonAzureFunction部署到Azure门户。由于某种原因,代码被部署到服务器,但我在尝试访问端点时收到404错误。我收到一个错误,显示1functionsfound(Custom)0functionsloaded,以及在服务器上显示ModuleNotFound......
  • 使用 kivy 从 python 脚本的 buildozer 构建 android apk 时出错
    我想从使用kivy包构建的Python脚本构建apk为此,我使用googlecollab.这里是main.py脚本:importyoutube_dlfromkivy.appimportAppfromkivy.uix.boxlayoutimportBoxLayoutfromkivy.uix.buttonimportButtonfromkivy.uix.tex......