
OOD testing with a ResNet18 trained on CIFAR-100


CIFAR-100 is used as the closed-set (in-distribution) dataset to train a ResNet18 model, which is then evaluated for OOD detection on several common open-set (out-of-distribution) datasets. MSP (Maximum Softmax Probability) is used as the detection score.
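As a quick reminder of the scoring rule: MSP simply takes the largest softmax probability of the classifier output as a confidence score, and in-distribution inputs tend to receive higher scores than OOD inputs. A minimal sketch (assuming a trained model, an input batch x on the same device, and a threshold picked on validation data; these names are illustrative, not from the code below):

import torch.nn.functional as F

logits = model(x)                                 # hypothetical trained classifier and input batch
msp = F.softmax(logits, dim=1).max(dim=1).values  # one MSP confidence per sample
is_ood = msp < threshold                          # hypothetical threshold: flag low-confidence samples as OOD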

Four open-set noise sources are used: gaussian, rademacher, blob, and svhn. Gaussian, Rademacher, and blob are synthetically generated noise (Rademacher noise takes the values ±1 with equal probability), while SVHN is a real dataset introduced as additional OOD data.

Results

The first two output lines are the closed-set results: the top-1 error rate and how well MSP separates correct from incorrect predictions. The remaining lines report how well MSP separates the CIFAR-100 test set from each OOD source.

Error Rate 46.3000
AUROC: 81.9790, AUPR: 85.7377, FPR95: 73.3909
ood type: gaussian
AUROC: 68.1596, AUPR: 92.9277, FPR95: 99.4000
ood type: rademacher
AUROC: 69.9099, AUPR: 93.1788, FPR95: 96.1500
ood type: blob
AUROC: 68.0615, AUPR: 92.7477, FPR95: 97.5500
ood type: svhn
AUROC: 66.9684, AUPR: 91.6508, FPR95: 89.0500

As the results show, a ResNet18 trained with plain cross-entropy loss and no additional techniques does not perform well at open-set (OOD) detection.

Training a ResNet18 on the closed-set dataset

# train.py
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets.cifar import CIFAR100
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from sklearn.metrics import accuracy_score
import torch.nn.functional as F


def get_transform(train=True):
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
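    # Note: these are the widely reused CIFAR-10 normalization statistics; CIFAR-100's own
    # mean/std differ slightly, but this choice is common and works fine in practice.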
    if train:
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    return transform


def get_loader(train=True):
    transform = get_transform(train)
    dataset = CIFAR100(root='~/data', train=train, transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=128, shuffle=train, num_workers=8, pin_memory=True)
    return loader


def train_model():

    loader = get_loader(train=True)
    test_loader = get_loader(train=False)

    model = resnet18(num_classes=100)
    model = model.cuda()

    epochs = 100
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)

    for epoch in range(epochs):
        model.train()
        print('Training')
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f'Epoch[{epoch}] Iter: {i}/{len(loader)} Loss: {loss.item()}')
        scheduler.step()
        print('Testing')
        # Evaluate on the held-out test set, resetting the metric buffers each epoch
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.cuda()
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.numpy())
        accuracy = accuracy_score(all_labels, all_preds)
        print(f'Epoch[{epoch}] acc@1 {accuracy:.4f}')

    torch.save(model.state_dict(), 'cifar100_resnet18.pth')


if __name__ == '__main__':
    train_model()
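One caveat about the backbone: torchvision's resnet18 keeps its ImageNet stem (a 7x7 stride-2 convolution followed by max pooling), which discards a lot of spatial resolution on 32x32 CIFAR images and is one reason the closed-set error above is fairly high. A common tweak, not used in this post, is to shrink the stem before training; a sketch:

import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(num_classes=100)
# CIFAR-friendly stem: 3x3 stride-1 conv and no max pooling (deviates from the original post).
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()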

Building the common open-set datasets

# ood_data.py
import torch
import numpy as np
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.dataloader import DataLoader
from skimage.filters import gaussian
from torchvision.datasets import SVHN

from train import get_transform


def build_ood_loader(noise_type, ood_num_examples, batch_size, worker):
    dummy_targets = torch.ones(ood_num_examples)
    if noise_type in ['gaussian', 'rademacher', 'blob']:
        if noise_type == 'gaussian':
            ood_data = torch.from_numpy(np.float32(np.clip(
                np.random.normal(size=(ood_num_examples, 3, 32, 32), scale=0.5), -1, 1)))
        elif noise_type == 'rademacher':
            ood_data = torch.from_numpy(np.random.binomial(
                n=1, p=0.5, size=(ood_num_examples, 3, 32, 32)).astype(np.float32)) * 2 - 1
        else:
            ood_data = np.float32(np.random.binomial(n=1, p=0.7, size=(ood_num_examples, 32, 32, 3)))
            for i in range(ood_num_examples):
                ood_data[i] = gaussian(ood_data[i], sigma=1.5)
                ood_data[i][ood_data[i] < 0.75] = 0.0
            ood_data = torch.from_numpy(ood_data.transpose((0, 3, 1, 2))) * 2 - 1
        dataset = TensorDataset(ood_data, dummy_targets)
    elif noise_type == 'svhn':
        transform = get_transform(train=False)
        dataset = SVHN(root='~/data/svhn', split='test', transform=transform, download=True)
        # Keep only the first ood_num_examples samples (truncate labels as well for consistency)
        dataset.data = dataset.data[:ood_num_examples]
        dataset.labels = dataset.labels[:ood_num_examples]
    else:
        raise ValueError(f'Unknown noise type: {noise_type}')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=worker, pin_memory=True)

    return dataloader
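A quick sanity check of the loader (a sketch; the sample count and batch size here are arbitrary):

if __name__ == '__main__':
    loader = build_ood_loader('rademacher', ood_num_examples=2000, batch_size=128, worker=4)
    images, _ = next(iter(loader))
    # Expected: torch.Size([128, 3, 32, 32]) with values in {-1, 1}
    print(images.shape, images.min().item(), images.max().item())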

Common OOD detection evaluation metrics

# ood_utils.py
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve


@torch.no_grad()
def get_ood_scores(model, dataloader, closed_set=False):
    model.eval()
    scores = []
    right_scores = []
    wrong_scores = []
    for i, (data, targets) in enumerate(dataloader):
        data = data.cuda()
        output = model(data)
        smax = F.softmax(output, dim=1).cpu().numpy()

        scores.append(np.max(smax, axis=1))

        if closed_set:
            pred = np.argmax(smax, axis=1)
            targets = targets.numpy().squeeze()
            right_indices = pred == targets
            wrong_indices = np.invert(right_indices)
            right_scores.append(np.max(smax[right_indices], axis=1))
            wrong_scores.append(np.max(smax[wrong_indices], axis=1))

    if closed_set:
        return (np.concatenate(scores),
                np.concatenate(right_scores),
                np.concatenate(wrong_scores))
    else:
        return np.concatenate(scores)


def get_performance(pos, neg):
    pos = np.array(pos).reshape(-1)
    neg = np.array(neg).reshape(-1)
    scores = np.concatenate([pos, neg])
    labels = [1] * len(pos) + [0] * len(neg)
    auroc = roc_auc_score(labels, scores)
    aupr = average_precision_score(labels, scores)

    fpr, tpr, _ = roc_curve(labels, scores)
    fpr95 = fpr[np.argmax(tpr >= 0.95)]
    return auroc, aupr, fpr95


def show_performance(pos, neg):
    auroc, aupr, fpr95 = get_performance(pos, neg)
    print(f"AUROC: {auroc * 100:.4f}, AUPR: {aupr * 100:.4f}, FPR95: {fpr95 * 100:.4f}")

Evaluating the model's OOD detection performance

# test.py
import torch
from torchvision.models import resnet18


from train import get_loader
from ood_utils import get_ood_scores, show_performance
from ood_data import build_ood_loader


def evaluate():
    model = resnet18(num_classes=100)
    model.load_state_dict(torch.load('cifar100_resnet18.pth'))
    model = model.cuda()

    # closed-set test
    test_loader = get_loader(train=False)
    in_score, right_score, wrong_score = get_ood_scores(model, test_loader, True)
    num_right, num_wrong = len(right_score), len(wrong_score)

    print(f'Error Rate {100 * num_wrong / (num_right + num_wrong):.4f}')
    show_performance(right_score, wrong_score)
    # open-set test
    ood_num_examples = len(test_loader.dataset) // 5
    ood_types = ['gaussian', 'rademacher', 'blob', 'svhn']
    for i in ood_types:
        print(f'ood type: {i}')
        ood_loader = build_ood_loader(i, ood_num_examples, batch_size=128, worker=8)
        out_score = get_ood_scores(model, ood_loader)
        show_performance(in_score, out_score)


if __name__ == '__main__':
    evaluate()

Dependencies

scikit-learn       1.5.2
scipy              1.14.1
torch              2.4.1

From: https://www.cnblogs.com/zh-jp/p/18437126
