首页 > 其他分享 >机器学习常见的sampling策略 附PyTorch实现

机器学习常见的sampling策略 附PyTorch实现

时间:2024-04-09 21:11:05浏览次数:30  
标签:采样 机器 sampling PyTorch balanced counts Class cls

简单的采样策略

首先介绍三种简单采样策略:

  1. Instance-balanced sampling, 实例平衡采样。
  2. Class-balanced sampling, 类平衡采样。
  3. Square-root sampling, 平方根采样。

它们可抽象为:

\[p_j=\frac{n_j^q}{\sum_{i=1}^Cn_i^q}, \]

\(p_j\)表示从j类采样数据的概率;\(C\)表示类别数量;\(n_j\)表示j类样本数;\(q\in\{1,0,\frac{1}{2}\}\)
Instance-balanced sampling
最常见的数据采样方式,其中每个训练样本被选择的概率相等(\(q=1\))。j类被采样的概率\(p^{\mathbf{IB}}_j\)与j类样本数\(n_j\)成正比,即\(p^{\mathbf{IB}}_j=\frac{n_j}{\sum_{i=1}^Cn_i}\)。

Class-balanced sampling
实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样让所有的类有相同的被采样概率:\(p^{\mathbf{CB}}_j=\frac{1}{C}\)。采样可分为两个阶段:1. 从类集中统一选择一个类;2. 对该类中的实例进行统一采样。
Square-root sampling
平方根采样最常见的变体,\(q=\frac{1}{2}\)

由于这三种采样策略都是调整类别的采样概率(权重),因此可用PyTorch提供的WeightedRandomSampler实现:

import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler
def get_sampler(sampling_type, targets):
    cls_counts = np.bincount(targets)
    if sampling_type == 'instance-balanced':
        cls_weights = cls_counts / np.sum(cls_counts)
        
    elif sampling_type == 'class-balanced':
        cls_num = len(cls_counts)
        cls_weights = [1. / cls_num] * cls_num
        
    elif sampling_type == 'square-root':
        sqrt_and_sum = np.sum([num**0.5 for num in cls_counts])
        cls_weights = [num**0.5 / sqrt_and_sum for num in cls_counts]
    else:
        raise ValueError('sampling_type should be instance-balanced, class-balanced or square-root')
    
    cls_weights = np.array(cls_weights)
    return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)

WeightedRandomSampler,第一个参数表示每个样本的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。

在模拟的长尾数据集测试下:

import torch
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(0)
np.random.seed(0)
class LongTailDataset(Dataset):
    def __init__(self, num_classes, max_samples_per_class):
        self.num_classes = num_classes
        self.max_samples_per_class = max_samples_per_class

        # Generate number of samples for each class inversely proportional to class index
        self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]
        self.total_samples = sum(self.samples_per_class)

        # Generate targets for the dataset
        self.targets = torch.cat([torch.full((samples,), i, dtype=torch.long) for i, samples in enumerate(self.samples_per_class)])

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # For simplicity, just return the index as the data
        return idx, self.targets[idx]

# Parameters
num_classes = 25
max_samples_per_class = 1000

# Create dataset
dataset = LongTailDataset(num_classes, max_samples_per_class)

# Create dataloader
batch_size = 64
sampler1 = get_sampler('instance-balanced', dataset.targets.numpy())
sampler2 = get_sampler('class-balanced', dataset.targets.numpy())
sampler3 = get_sampler('square-root', dataset.targets.numpy())
dataloader1 = DataLoader(dataset, batch_size=64, sampler=sampler1)
dataloader2 = DataLoader(dataset, batch_size=64, sampler=sampler2)
dataloader3 = DataLoader(dataset, batch_size=64, sampler=sampler3)

for (_, target1), (_, target2), (_, target3)  in zip(dataloader1, dataloader2, dataloader3):
    print('Instance-balanced:')
    cls_idx, cls_counts = np.unique(target1.numpy(), return_counts=True)
    print(f'Class indices: {cls_idx}')
    print(f'Class counts: {cls_counts}')
    print('-'*20)
    print('Class-balanced:')
    cls_idx, cls_counts = np.unique(target2.numpy(), return_counts=True)
    print(f'Class indices: {cls_idx}')
    print(f'Class counts: {cls_counts}')
    print('-'*20)
    print('Square-root:')
    cls_idx, cls_counts = np.unique(target3.numpy(), return_counts=True)
    print(f'Class indices: {cls_idx}')
    print(f'Class counts: {cls_counts}')
    break # just show one batch 

Output:

Instance-balanced:
Class indices: [ 0  1  2  3  5 16 22 23]
Class counts: [43  9  5  2  2  1  1  1]
--------------------
Class-balanced:
Class indices: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 16 17 20 21 23]
Class counts: [21  8  6  4  2  1  2  2  3  3  1  2  1  1  1  1  2  1  1  1]
--------------------
Square-root:
Class indices: [ 0  1  2  3  4  5  6  9 10 21 22 23]
Class counts: [37  8  3  6  3  1  1  1  1  1  1  1]

混合采样策略

最早的混合采样是在 \(0\le epoch\le t\)时采用Instance-balanced采样,\(t\le epoch\le T\)时采用Class-balanced采样,这需要设置合适的超参数t。在[1]中,作者提出了soft版本的混合采样策略:Progressively-balanced sampling。随着epoch的增加每个类的采样概率(权重)\(p_j\)也发生变化:

\[p_j^{\mathbf{PB}}(t)=(1-\frac tT)p_j^{\mathbf{IB}}+\frac tTp_j^{\mathbf{CB}} \]

t表示当前epoch,T表示总epoch数。

不平衡数据集下的采样策略

不平衡的数据集,特别是长尾数据集,为了照顾尾部类,通常设置每个类的采样概率(权重)为样本数的倒数,即\(p_j=\frac{1}{n_j}\)。

...
elif sampling_type == 'inverse':
    cls_weights = 1. / cls_counts
...

在[3]中提出了有效数(effective number)的概念,分母的位置不是简单的样本数,而是经过一定计算得到的,这里直接给出结果,证明请详见原论文。关于effective number的计算方式:

\[E_n=(1-\beta^n)/(1-\beta),\ \mathrm{where~}\beta=(N-1)/N. \]

这里N表示数据集样本总数。

相关代码:

...
elif sampling_type == 'effective':
    beta = (len(targets) - 1) / len(targets)
    cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
...

Output

Effective number:
Class indices: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 16 17 18 20 21 22 23 24]
Class counts: [2 1 2 3 1 1 4 2 3 4 4 2 3 5 2 4 1 3 1 4 5 6 1]

在和上面一样的模拟长尾数据集上,采样的结果更加均衡。

参考文献

  1. Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2019.
  2. torch.utils.data.WeightedRandomSampler
  3. Cui, Yin, et al. "Class-balanced loss based on effective number of samples." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.

标签:采样,机器,sampling,PyTorch,balanced,counts,Class,cls
From: https://www.cnblogs.com/zh-jp/p/18124824

相关文章

  • 机器学习&深度学习 操作tips
    1.在运行程序时,报错如下:usage:run.py[-h]--modelMODEL[--embeddingEMBEDDING][--wordWORD]run.py:error:thefollowingargumentsarerequired:--model答:出现这个问题是因为对于代码不够理解,对于在代码包中有多个models时,举例如下:不同的model类似于定义了不......
  • 人形机器人灵巧手的核心部件 —— 直流永磁伺服电动机
    相关:https://baijiahao.baidu.com/s?id=1770834878767236111&wfr=spider&for=pchttps://news.sohu.com/a/709452614_121742556https://www.bilibili.com/video/BV1gp4y1G7xg/?vd_source=f1d0f27367a99104c397918f0cf362b7......
  • 机器学习 —— MNIST手写体识别
    本文使用工具    Anaconda下载安装与使用    JupyterNotebook的使用    pytorch配置        Jupyternotebook        Pycharm本文使用数据集    机器学习实验所需内容.zip        点击跳转至正文......
  • 【机器学习】2. 支持向量机
    2.支持向量机对偶优化拉格朗日乘数法可用于解决带条件优化问题,其基本形式为:\[\begin{gather}\min_wf(w),\\\mathrm{s.t.}\quad\cases{g_i(w)\le0,\\h_i(w)=0.}\end{gather}\]该问题的拉格朗日函数为\[L(w,\alpha,\beta)=f(w)+\sum_{i}\alpha_ig_i(w)+\sum_j\beta_j......
  • 【机器学习】深入解析机器学习基础
    在本篇深入探讨中,我们将揭开机器学习背后的基础原理,这不仅包括其数学框架,更涵盖了从实际应用到理论探索的全方位视角。机器学习作为数据科学的重要分支,其力量来源于算法的能力,这些算法能够从数据中学习并做出预测或决策。下面,我们将根据提供的目录详细探讨每个部分。学习算法......
  • 人形机器人第三方方案供应商应该具备哪些能力
    和某家人形机器人公司沟通了合作意向,给出了几个合作的可能:给一些简单的API,如控制机器人挥手,控制机器人向前走一步,等等。这些提供的API只能调用机器人公司给定的动作,也就是使用动作规划和正反动力学建立好的一些动作库,然后将这些动作库提供过来,但是这种级别的API或许可以作为教育......
  • 深度探索:机器学习Deep Belief Networks(DBN)算法原理及其应用
    目录1.引言与背景2.定理3.算法原理4.算法实现5.优缺点分析优点:缺点:6.案例应用7.对比与其他算法8.结论与展望1.引言与背景深度学习在近年来取得了显著进展,其在图像识别、语音识别、自然语言处理等多个领域的成功应用引发了广泛的关注。其中,DeepBeliefNetworks......
  • 深度探索:机器学习神经图灵机(Neural Turing Machines, NTMs)原理及其应用
    目录1.引言与背景2.定理3.算法原理4.算法实现5.优缺点分析优点:缺点:6.案例应用7.对比与其他算法8.结论与展望1.引言与背景在人工智能与机器学习的前沿研究中,如何赋予计算机系统更强大的学习与推理能力,使其能模拟人类大脑的复杂认知过程,一直是科学家们不懈探索的......
  • 深度探索:机器学习堆叠泛化(Stacked Generalization, Blending)算法原理及其应用
    目录1.引言与背景2.集成学习定理3.算法原理4.算法实现5.优缺点分析优点:缺点:6.案例应用7.对比与其他算法8.结论与展望1.引言与背景机器学习领域中,模型性能的提升往往依赖于对数据特征的深入理解、恰当的模型选择以及有效的超参数调整。然而,在面对复杂且高度非线性......
  • 深度探索:机器学习多维尺度(MDS)算法原理及其应用
    目录1.引言与背景2.MDS定理3.算法原理4.算法实现5.优缺点分析优点:缺点:6.案例应用7.对比与其他算法8.结论与展望1.引言与背景多维尺度分析(Multi-DimensionalScaling,MDS)是一种统计学方法,用于将复杂、高维的相似性或距离数据转化为直观的、低维的可视化表示。MD......