ChatGPT is hopeless at this: ask it to build an imbalanced dataset and it has no idea how to make the class counts decay. Long-tail researchers, my condolences.
Skipping the small talk, here is a general-purpose template.
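In plain terms, the recipe below is: take the balanced CIFAR training set, give class i a target count that decays exponentially,

    n_i = n_max * imb_factor ** (i / (C - 1))

where C is the number of classes and n_max is the size of the largest class, then shuffle each class's index list and np.delete the surplus rows. With imb_factor=0.01 the head class keeps n_max images and the tail keeps n_max/100, i.e. an imbalance ratio of 100. The driver script comes first; the data_utils.py it imports from is appended after it.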
import argparse
import random
from data_utils import *
from loss import *
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import WeightedRandomSampler
import os
import torch
import scipy.io as sio
from Meta_train import ResNet32_100
parser = argparse.ArgumentParser(description='Imbalanced Example')
parser.add_argument('--dataset', default='cifar100', type=str,
                    help='dataset (cifar10 or cifar100 [default])')
parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--num_classes', type=int, default=100)
parser.add_argument('--num_meta', type=int, default=0,
                    help='number of meta samples to hold out per class')
parser.add_argument('--imb_factor', type=float, default=0.01,
                    help='tail-to-head class ratio; 0.01 = imbalance ratio 100')
parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                    help='input batch size for testing (default: 100)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float,
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
# note: argparse type=bool treats any non-empty string as True
parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--split', type=int, default=1000)
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    help='print frequency (default: 100)')
parser.add_argument('--lam', default=0.25, type=float, help='choose from [0.25, 0.5, 0.75, 1.0]')
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--meta_lr', default=0.1, type=float)
parser.add_argument('--save_name', default='name', type=str)
parser.add_argument('--idx', default='0', type=str)
args = parser.parse_args()
for arg in vars(args):
    print("{}={}".format(arg, getattr(args, arg)))
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= str(args.gpu)
kwargs = {'num_workers': 1, 'pin_memory': False}
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
train_data_meta, train_data, test_dataset = build_dataset(args.dataset, args.num_meta)
print(f'length of meta dataset:{len(train_data_meta)}')
print(f'length of train dataset: {len(train_data)}')
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True, **kwargs)
# Seed every RNG that the index shuffling below depends on.
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
class_labels = range(args.num_classes)
# Group training-set indices by class label.
data_list = {}
for j in range(args.num_classes):
    data_list[j] = [i for i, label in enumerate(train_loader.dataset.targets) if label == j]

# Target number of images per class under the exponential decay.
img_num_list = get_img_num_per_cls(args.dataset, args.imb_factor, args.num_meta * args.num_classes)
print(img_num_list)
print(sum(img_num_list))
# For each class, keep the first img_num shuffled indices and mark the rest
# for deletion. (The original stored the deleted slice in im_data; given the
# naming, the kept slice [:img_num] is what was evidently intended.)
im_data = {}
idx_to_del = []
for cls_idx, img_id_list in data_list.items():
    random.shuffle(img_id_list)
    img_num = img_num_list[int(cls_idx)]
    im_data[cls_idx] = img_id_list[:img_num]   # surviving indices
    idx_to_del.extend(img_id_list[img_num:])   # surplus indices to drop
print(len(idx_to_del))
# Materialize the long-tailed training set by deleting the surplus rows.
imbalanced_train_dataset = copy.deepcopy(train_data)
imbalanced_train_dataset.targets = np.delete(train_loader.dataset.targets, idx_to_del, axis=0)
imbalanced_train_dataset.data = np.delete(train_loader.dataset.data, idx_to_del, axis=0)
print(len(imbalanced_train_dataset))
imbalanced_train_loader = torch.utils.data.DataLoader(
    imbalanced_train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)
best_prec1 = 0

def main():
    # The actual training loop is omitted in this template;
    # imbalanced_train_loader and test_loader above are ready to plug in.
    pass

if __name__ == '__main__':
    main()
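Not part of the original template, but a handy sanity check to call from main(): count the per-class sizes of the long-tailed set you just built, and optionally turn the WeightedRandomSampler the script already imports into a rebalanced loader via inverse-frequency weights. The function name is mine; treat it as a sketch.

from collections import Counter

def sanity_check_and_rebalance(dataset, batch_size=100):
    # Per-class sample counts, head class first; for cifar10 with
    # imb_factor=0.01 these should run from 5000 down to 50.
    counts = Counter(int(t) for t in dataset.targets)
    print(sorted(counts.values(), reverse=True))

    # Inverse-frequency weights: rare classes are drawn more often, so each
    # mini-batch is roughly class-balanced in expectation.
    weights = [1.0 / counts[int(t)] for t in dataset.targets]
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)

For example, balanced_loader = sanity_check_and_rebalance(imbalanced_train_dataset) gives a loader you can swap in wherever oversampling is preferred over a reweighted loss.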
# ---- data_utils.py (pulled in above via `from data_utils import *`) ----
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision
import numpy as np
import copy

np.random.seed(6)
def build_dataset(dataset, num_meta):
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # download=True so the code also works on a machine without a local copy.
    if dataset == 'cifar10':
        train_dataset = torchvision.datasets.CIFAR10(root='../cifar-10', train=True, download=True, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10('../cifar-10', train=False, transform=transform_test)
        img_num_list = [num_meta] * 10
        num_classes = 10
    if dataset == 'cifar100':
        train_dataset = torchvision.datasets.CIFAR100(root='../cifar-100', train=True, download=True, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR100('../cifar-100', train=False, transform=transform_test)
        img_num_list = [num_meta] * 100
        num_classes = 100

    # Group training indices by class, then hold out num_meta samples per
    # class as a meta set; the remainder becomes the training set.
    data_list_val = {}
    for j in range(num_classes):
        data_list_val[j] = [i for i, label in enumerate(train_dataset.targets) if label == j]

    idx_to_meta = []
    idx_to_train = []
    print(img_num_list)

    for cls_idx, img_id_list in data_list_val.items():
        np.random.shuffle(img_id_list)
        img_num = img_num_list[int(cls_idx)]
        idx_to_meta.extend(img_id_list[:img_num])
        idx_to_train.extend(img_id_list[img_num:])

    train_data = copy.deepcopy(train_dataset)
    train_data_meta = copy.deepcopy(train_dataset)

    train_data_meta.data = np.delete(train_dataset.data, idx_to_train, axis=0)
    train_data_meta.targets = np.delete(train_dataset.targets, idx_to_train, axis=0)
    train_data.data = np.delete(train_dataset.data, idx_to_meta, axis=0)
    train_data.targets = np.delete(train_dataset.targets, idx_to_meta, axis=0)

    return train_data_meta, train_data, test_dataset
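# Worked example of the split above (my annotation, not in the original):
# with dataset='cifar10' and num_meta=10, build_dataset returns a balanced
# 100-image meta set (10 per class) and a balanced 49,900-image train set;
# the long-tail thinning happens later, in the driver script.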
def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None):
    """Per-class image counts, decaying exponentially from head to tail."""
    if dataset == 'cifar10':
        img_max = (50000 - num_meta) / 10
        cls_num = 10
    if dataset == 'cifar100':
        img_max = (50000 - num_meta) / 100
        cls_num = 100
    if imb_factor is None:
        return [int(img_max)] * cls_num
    img_num_per_cls = []
    for cls_idx in range(cls_num):
        # Class 0 keeps img_max images; class cls_num-1 keeps img_max * imb_factor.
        num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
        img_num_per_cls.append(int(num))
    return img_num_per_cls
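For concreteness, here is the decay on CIFAR-10 with num_meta=0 (so img_max = 5000) and the default imb_factor=0.01. The numbers are computed by hand from the formula above, so double-check before quoting them:

>>> get_img_num_per_cls('cifar10', imb_factor=0.01, num_meta=0)
[5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]

The head/tail ratio is exactly 1/imb_factor = 100, and the long-tailed training set shrinks from 50,000 to 12,406 images.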
From: https://www.cnblogs.com/ZarkY/p/18083782