首页 > 其他分享 >构建不平衡数据集

构建不平衡数据集

时间:2024-03-19 19:44:42浏览次数:21  
标签:num img default data dataset train 构建 平衡 数据

chatgpt呆子,不知道怎么构建不平衡数据及,不会递减的构建,长尾人表示心痛
image

琐碎直接给个万能模板

import argparse
import random

from data_utils import *
from loss import *

import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import WeightedRandomSampler

import os
import torch
import scipy.io as sio

from Meta_train import ResNet32_100

parser = argparse.ArgumentParser(description='Imbalanced Example')
parser.add_argument('--dataset', default='cifar100', type=str,
                    help='dataset (cifar10 or cifar100[default])')
parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--num_classes', type=int, default=100)
parser.add_argument('--num_meta', type=int, default=0,
                    help='The number of meta data for each class.')
parser.add_argument('--imb_factor', type=float, default=0.01) #100
parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                    help='input batch size for testing (default: 100)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float,
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--split', type=int, default=1000)
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--lam', default=0.25, type=float, help='[0.25, 0.5, 0.75, 1.0]') #default=0.25
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--meta_lr', default=0.1, type=float)
parser.add_argument('--save_name', default='name', type=str)
parser.add_argument('--idx', default='0', type=str)



args = parser.parse_args()
for arg in vars(args):
    print("{}={}".format(arg, getattr(args, arg)))

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= str(args.gpu)
kwargs = {'num_workers': 1, 'pin_memory': False}
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")

train_data_meta, train_data, test_dataset = build_dataset(args.dataset, args.num_meta)

print(f'length of meta dataset:{len(train_data_meta)}')
print(f'length of train dataset: {len(train_data)}')

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True, **kwargs)

np.random.seed(42)
random.seed(42)
torch.manual_seed(args.seed)
classe_labels = range(args.num_classes)

data_list = {}


for j in range(args.num_classes):
    data_list[j] = [i for i, label in enumerate(train_loader.dataset.targets) if label == j]


img_num_list = get_img_num_per_cls(args.dataset, args.imb_factor, args.num_meta*args.num_classes)
print(img_num_list)
print(sum(img_num_list))

im_data = {}
idx_to_del = []
for cls_idx, img_id_list in data_list.items():
    random.shuffle(img_id_list)
    img_num = img_num_list[int(cls_idx)]
    im_data[cls_idx] = img_id_list[img_num:]
    idx_to_del.extend(img_id_list[img_num:])

print(len(idx_to_del))
imbalanced_train_dataset = copy.deepcopy(train_data)
imbalanced_train_dataset.targets = np.delete(train_loader.dataset.targets, idx_to_del, axis=0)
imbalanced_train_dataset.data = np.delete(train_loader.dataset.data, idx_to_del, axis=0)
print(len(imbalanced_train_dataset))
imbalanced_train_loader = torch.utils.data.DataLoader(
    imbalanced_train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)


test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)

best_prec1 = 0

def main():
 imbalanced_train_loader
 test_loader

if __name__ == '__main__':
    main()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision
import numpy as np
import copy

np.random.seed(6)

def build_dataset(dataset,num_meta):
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    if dataset == 'cifar10':
        train_dataset = torchvision.datasets.CIFAR10(root='../cifar-10', train=True, download=False, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10('../cifar-10', train=False, transform=transform_test)
        img_num_list = [num_meta] * 10
        num_classes = 10

    if dataset == 'cifar100':
        train_dataset = torchvision.datasets.CIFAR100(root='../cifar-100', train=True, download=True, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR100('../cifar-100', train=False, transform=transform_test)
        img_num_list = [num_meta] * 100
        num_classes = 100

    data_list_val = {}
    for j in range(num_classes):
        data_list_val[j] = [i for i, label in enumerate(train_dataset.targets) if label == j]

    idx_to_meta = []
    idx_to_train = []
    print(img_num_list)

    for cls_idx, img_id_list in data_list_val.items():
        np.random.shuffle(img_id_list)
        img_num = img_num_list[int(cls_idx)]
        idx_to_meta.extend(img_id_list[:img_num])
        idx_to_train.extend(img_id_list[img_num:])
    train_data = copy.deepcopy(train_dataset)
    train_data_meta = copy.deepcopy(train_dataset)

    train_data_meta.data = np.delete(train_dataset.data, idx_to_train,axis=0)
    train_data_meta.targets = np.delete(train_dataset.targets, idx_to_train, axis=0)
    train_data.data = np.delete(train_dataset.data, idx_to_meta, axis=0)
    train_data.targets = np.delete(train_dataset.targets, idx_to_meta, axis=0)

    return train_data_meta, train_data, test_dataset

def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None):

    if dataset == 'cifar10':
        img_max = (50000-num_meta)/10
        cls_num = 10

    if dataset == 'cifar100':
        img_max = (50000-num_meta)/100
        cls_num = 100

    if imb_factor is None:
        return [img_max] * cls_num
    img_num_per_cls = []
    for cls_idx in range(cls_num):
        num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
        img_num_per_cls.append(int(num))
    return img_num_per_cls


标签:num,img,default,data,dataset,train,构建,平衡,数据
From: https://www.cnblogs.com/ZarkY/p/18083782

相关文章

  • 如何处理PHP中的表单数据?
    处理PHP中的表单数据其实就像是一个老师收集学生的作业一样。当学生在作业本上写完作业并交给老师后,老师需要查看、批改这些作业。在网站上,表单就像是学生的作业本,用户填写表单并提交后,网站就需要处理这些数据。下面我会用简单的步骤来解释PHP如何处理表单数据:1.创建表单首......
  • C语言 数据在内存中的存储
    目录前言一、整数在内存中的存储二、大小端字节序和字节序判断2.1.练习一2.2练习二2.3练习三2.4练习四2.5练习五2.6练习六三、浮点数在内存中的存储3.1 浮点数存的过程3.2浮点数取的过程总结前言数据在内存中根据数据类型有不同的存储方式,今天我们......
  • 使用navicat导出查询大量数据结果集并导入到其他数据库(mysql)
    在工作中我们偶尔会遇到处理数据的问题;比如需要将处理后的结果集导出来再导入到数据;具体的的实现方案有:1、导出excel再导入在navicat中选中对应数据库的表对象,右键选择导入,导出向导,如图: 2、将查询的结果集导出成sql实现步骤:先整理查询的sql;执行查询,再选择对话框上面的......
  • Spring JdbcTemplate+Druid数据源+FreeMarker 开发代码生成器
    虽然在这个时代,几乎所有成熟的开发框架都自带代码生成器,但有时候我们难免会遇到没有代码生成器的开发框架,这个时候,自己手中有一套代码生成器,把模版文件调整一下立马就能用,这就比较惬意了。这里讲一下如何利用SpringJdbcTemplate+Druid数据源+FreeMarker开发一套代码生成器。......
  • 多数据源加密(90%来自文心一言)
    在dynamic-datasource-spring-boot-starter3.2.0中,如果你希望对加密的密码进行自定义解密,你需要实现自己的PropertySourceLocator或者自定义配置解析逻辑,以便在读取配置时能够自动解密密码。以下是实现自定义解密逻辑的一般步骤:创建自定义的解密工具类首先,你需要一个能......
  • openGauss Anomaly_detection_数据库指标采集_预测与异常监控
    Anomaly-detection:数据库指标采集、预测与异常监控可获得性本特性自openGauss1.1.0版本开始引入。特性简介anomaly_detection是openGauss集成的、可以用于数据库指标采集、预测以及异常监控与诊断的AI工具,是dbmind套间中的一个组件。支持采集的信息包括IO_Read、IO_Write、CPU......
  • 大数据分析之数据下钻上卷
    声明:本次任务简单所以没有前后端分离去做,因此不需要异步处理(cors)根据Python将数据合并清洗,分析之后,将得到的数据存入数据库,数据库中就是各行业的类别以及数量。前端用java的相关知识利用echarts绘制数据下钻和上卷图前端:<!DOCTYPEhtml><html><head><metacharset="utf-......
  • 使用Python爬取豆瓣电影影评:从数据收集到情感分析
    简介在当今数字化时代,对电影的评价和反馈在很大程度上影响着人们的选择。豆瓣作为一个知名的电影评价平台,汇集了大量用户对电影的评论和评分。本文将介绍如何使用Python编写爬虫来获取豆瓣电影的影评数据,并通过情感分析对评论进行简单的情感评价。环境准备在开始之前,我们需要......
  • 直播预约丨《袋鼠云大数据实操指南》No.1:从理论到实践,离线开发全流程解析
    近年来,新质生产力、数据要素及数据资产入表等新兴概念犹如一股强劲的浪潮,持续冲击并革新着企业数字化转型的观念视野,昭示着一个以数据为核心驱动力的新时代正稳步启幕。面对这些引领经济转型的新兴概念,为了更好地服务于客户并提供切实可行的实践指导,自3月20日起,袋鼠云将推出全新......
  • WPF —— 控件模版和数据模版
    1:控件模版简介:自定义控件模版:自己添加的样式、标签,控件模版也是属于资源的一种,    每一个控件模版都有一唯一的key,在控件上通过template属性进行绑定什么场景下使用自定义控件模版,当项目里面多个地方使用到相同效果,这时候可以把相同    效果封装成一个......