The workflow has two steps: first train the model and save it, then load the saved model and extract and save the features.
① Train the model and save it
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import WeightedRandomSampler
import torch.nn.functional as F
import os
# Basic ResNet residual block
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
class ResNet32(nn.Module):
    def __init__(self, block, num_blocks, num_classes=100, feature_size=4096):
        super(ResNet32, self).__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The final fc layer outputs feature_size-dimensional vectors, which are used both
        # as logits for training and as features for extraction (num_classes is accepted but not used here)
        self.fc = nn.Linear(64 * block.expansion, feature_size)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward_features(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        features = out
        # print("Shape of features in forward_features:", features.shape)  # uncomment to print the feature shape
        # print("Shape of fc weight matrix:", self.fc.weight.shape)
        return out, features

    def forward(self, x):
        return self.forward_features(x)[0]  # return the first output of forward_features


def ResNet32_100():
    return ResNet32(BasicBlock, [5, 5, 5], num_classes=100, feature_size=4096)  # pass feature_size explicitly
# Imbalanced-dataset sampler: draws indices with probability inversely proportional to class frequency
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, dataset, indices=None, num_samples=None):
        self.indices = list(range(len(dataset))) if indices is None else indices
        self.num_samples = len(self.indices) if num_samples is None else num_samples
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        return dataset.targets[idx]

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples
# Data preprocessing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Load the CIFAR-100 dataset
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
# Build the imbalanced sampler
train_indices = torch.randperm(len(trainset)).tolist()
num_samples = int(len(train_indices) * 0.1)  # keep only 10% of the shuffled indices to create an imbalanced subset
imbalanced_sampler = ImbalancedDatasetSampler(trainset, indices=train_indices[:num_samples], num_samples=len(trainset))
# Create the data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, sampler=imbalanced_sampler, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
# Initialize the model, loss function and optimizer
net = ResNet32_100()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Move the model to the GPU (if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
# Train the model
best_prec1 = 0
def main():
    global best_prec1
    for epoch in range(170):  # adjust this value to change the number of training epochs
        net.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 100 == 99:  # print the loss every 100 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0
        # Evaluate the model
        prec1 = validate(testloader, net, criterion, device)
        print("epoch:{},prec1:{}".format(epoch, prec1))
        # Save the model
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, epoch + 1, prec1)  # pass the is_best flag
    print('Finished Training')
def save_checkpoint(state, is_best, epoch, accuracy):
    path = 'checkpoint/ours/'
    if not os.path.exists(path):
        os.makedirs(path)
    filename = str(epoch) + '_' + str(accuracy) + '.pth.tar'  # build the filename from epoch and accuracy
    if is_best:
        torch.save(state, os.path.join(path, filename))  # save only when the test accuracy improves
def validate(val_loader, model, criterion, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    return acc


if __name__ == '__main__':
    main()
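Before moving to step ②, it can help to confirm what a saved checkpoint actually contains. A minimal sketch, assuming the same checkpoint file that step ② below loads (substitute whatever file your own run produced under checkpoint/ours/):

import torch

# The path is only an example taken from step ② below
ckpt = torch.load('checkpoint/ours/138_33.67.pth.tar', map_location='cpu')
print(list(ckpt.keys()))                  # expected: ['epoch', 'state_dict', 'best_prec1', 'optimizer']
print(ckpt['epoch'], ckpt['best_prec1'])  # which epoch was saved and its best test accuracy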
② Load the model, extract the features, and save them
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import scipy.io as sio
# Import the ResNet32 model definition (script ① is assumed to be saved as Test1.py;
# note that its module-level code, e.g. dataset loading, runs on import)
from Test1 import ResNet32_100

def load_model(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')  # map to CPU so this also runs without a GPU
    model.load_state_dict(checkpoint['state_dict'])
    return model
# Load the CIFAR-100 dataset into data loaders
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=False)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
# Define a function that extracts features; the .mat file itself is written in the main block below
def extract_features(model, dataloader):
    model.eval()
    features_list = []
    with torch.no_grad():
        for images, _ in dataloader:
            outputs, features = model.forward_features(images)
            features_list.append(features)
    features = torch.cat(features_list, dim=0)
    print("Shape of features:", features.shape)  # print the shape of the stacked feature matrix
    return features
if __name__ == '__main__':
    # Load the saved model
    model = ResNet32_100()
    model = load_model(model, '/home/zy/pycharm/project/temp/MetaSAug_1/test/checkpoint/ours/138_33.67.pth.tar')
    # Extract features from the training set (uncomment to use)
    # features = extract_features(model, trainloader)
    # Extract features from the test set
    features = extract_features(model, testloader)
    # Save the features as a .mat file
    features_dict = {'features': features.cpu().numpy()}
    filename = 'Test_138_33.67'
    sio.savemat('/home/zy/pycharm/project/temp/MetaSAug_1/test/matFile/' + filename + '.mat', features_dict)
    # Print the size of the saved feature matrix (rows x columns)
    print("The size of the .mat file is:", features.shape[0], "x", features.shape[1])
Important print statements
# Print the feature shape inside forward_features
# print("Shape of features in forward_features:", features.shape)
# print("Shape of fc weight matrix:", self.fc.weight.shape)
# Print the size of the saved .mat feature matrix
print("The size of the .mat file is:", features.shape[0], "x", features.shape[1])
From: https://www.cnblogs.com/ZarkY/p/18082567