MobileNetV1 的整体网络结构如下（各层的具体配置见下文代码中的 `MobileNet` 类）：
MobileNetV1 最关键的改进是使用了深度可分离卷积（depthwise separable convolution）结构，将原本由一个 3×3 卷积同时完成的空间滤波与升通道操作分解成了两部分：
第一部分是逐通道卷积（depthwise convolution）：每个输入通道单独使用一个 3×3 卷积核进行卷积，通道数保持不变。
第二部分是逐点卷积（pointwise convolution）：使用 1×1 卷积融合各通道的信息，并完成通道数的提升。
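结果就是能够大幅减少运算量。设输入通道数为 $M$、输出通道数为 $N$、输出特征图尺寸为 $D_F \times D_F$、卷积核尺寸为 $D_K \times D_K$，则标准卷积的乘加次数为 $D_K^2 \cdot M \cdot N \cdot D_F^2$，深度可分离卷积为 $D_K^2 \cdot M \cdot D_F^2 + M \cdot N \cdot D_F^2$，二者之比为 $\frac{1}{N} + \frac{1}{D_K^2}$。对于 $3 \times 3$ 卷积（$D_K = 3$），运算量大约只有标准卷积的 $1/8$ 到 $1/9$。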
下面依然以猫狗大战（Dogs vs. Cats）为例给出训练程序，并增加了断点续训的处理：每个 epoch 结束后保存检查点，重新运行时会从最近的检查点继续训练。
"""MobileNetV1 for cats-vs-dogs classification, with resumable (checkpointed) training."""

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Pillow and torchvision are only needed to read and preprocess images, so they are
# imported inside CustomDataset.__getitem__ and main(). The model classes and the
# checkpoint helpers can therefore be imported without those packages installed.


class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 conv + BN + ReLU, then pointwise 1x1 conv + BN + ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels produced by the 1x1 conv.
        stride: Stride of the depthwise 3x3 conv (2 halves H and W, rounding up).
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        # groups=in_channels gives every input channel its own 3x3 filter,
        # so this conv filters spatially without mixing channels.
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=stride,
            padding=1, groups=in_channels, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        # The 1x1 conv mixes channels and changes the channel count.
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn1(self.depthwise(x)))
        return self.relu(self.bn2(self.pointwise(x)))


class MobileNet(nn.Module):
    """MobileNetV1 classifier.

    A stride-2 3x3 stem conv, 13 depthwise-separable blocks, global average
    pooling, then dropout + a linear layer.

    Args:
        num_classes: Number of output logits.

    Input is an (N, 3, H, W) image batch; output is (N, num_classes) logits.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        super().__init__()
        # NOTE: keep the layer order unchanged -- state_dict keys such as
        # 'layers.3.depthwise.weight' depend on the Sequential indices, and
        # existing checkpoints must keep loading.
        self.layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv(32, 64, 1),
            DepthwiseSeparableConv(64, 128, 2),
            DepthwiseSeparableConv(128, 128, 1),
            DepthwiseSeparableConv(128, 256, 2),
            DepthwiseSeparableConv(256, 256, 1),
            DepthwiseSeparableConv(256, 512, 2),
            # Five identical 512 -> 512 blocks.
            *[DepthwiseSeparableConv(512, 512, 1) for _ in range(5)],
            DepthwiseSeparableConv(512, 1024, 2),
            DepthwiseSeparableConv(1024, 1024, 1),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.layers(x)             # (N, 1024, 1, 1) after adaptive pooling
        out = out.view(out.size(0), -1)  # (N, 1024)
        return self.classifier(out)


class CustomDataset(Dataset):
    """Images named ``1.jpg`` ... ``<num_images>.jpg`` inside one folder.

    The first ``num_cats`` files (indices ``0 .. num_cats - 1``) are labelled
    0 (cat); all later files are labelled 1 (dog).

    Args:
        image_folder: Directory containing the numbered ``.jpg`` files.
        transform: Optional callable applied to the RGB PIL image.
        num_images: Dataset length (default 20000, as before).
        num_cats: Number of leading images that are cats (default 10000).
    """

    def __init__(self, image_folder, transform=None, num_images: int = 20000,
                 num_cats: int = 10000) -> None:
        self.image_folder = image_folder
        self.transform = transform
        self.num_images = num_images
        self.num_cats = num_cats

    def __len__(self) -> int:
        return self.num_images

    def _label(self, index: int) -> int:
        return 0 if index < self.num_cats else 1  # 0 = cat, 1 = dog

    def __getitem__(self, index: int):
        from PIL import Image  # lazy: see the note at the top of the module

        path = os.path.join(self.image_folder, f"{index + 1}.jpg")
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded image.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:  # transform defaults to None
            image = self.transform(image)
        return image, self._label(index)


def save_checkpoint(path, epoch: int, model: nn.Module, optimizer: optim.Optimizer) -> None:
    """Atomically write {epoch, model and optimizer state_dicts} to *path*.

    The data goes to ``<path>.tmp`` first and is then renamed over *path*, so an
    interrupted save never corrupts the checkpoint used for resuming.
    """
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    tmp_path = os.fspath(path) + ".tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)


def load_checkpoint(path, model: nn.Module, optimizer: optim.Optimizer, device=None) -> int:
    """Restore model/optimizer state from *path* if it exists.

    Move *model* to its target device before calling this: the optimizer casts
    the loaded state to the device of its parameters.

    Returns:
        The epoch to resume from (the saved ``'epoch'`` value), or 0 when no
        checkpoint file exists.
    """
    if not os.path.isfile(path):
        return 0
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]


def main(data_dir="./cat_vs_dog/train2", num_epochs: int = 10, batch_size: int = 32,
         lr: float = 0.001, weight_decay: float = 0.001, checkpoint_dir="./checkpoints",
         checkpoint_file="checkpoint.pt", model_path="mobilenet.pth") -> None:
    """Train MobileNet on the cats-vs-dogs folder, resuming from a checkpoint if present.

    A checkpoint is saved after every epoch. The final weights are written to
    *model_path*.
    """
    import torchvision.transforms as transforms  # lazy: see module note

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move the model before building the optimizer and loading a checkpoint, so
    # the restored optimizer state lands on the same device as the parameters.
    model = MobileNet(num_classes=2).to(device)
    # L2 regularization: Adam's weight_decay adds weight_decay * param to each
    # gradient. (The previous code only added random noise to the initial
    # weights, which is not regularization.)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
    # NOTE: a resumed optimizer keeps the hyperparameters (lr, weight_decay)
    # stored in the checkpoint.
    start_epoch = load_checkpoint(checkpoint_path, model, optimizer, device)

    # Resize to 224x224; ToTensor gives [0, 1], then Normalize maps it to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    train_dataset = CustomDataset(data_dir, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n = labels.size(0)
            # Weight by batch size so a smaller final batch does not skew the mean.
            running_loss += loss.item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n

        denom = max(total, 1)  # guard against an empty loader
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / denom:.4f}, "
              f"Accuracy: {(100 * correct / denom):.2f}%")

        # Save after every epoch so training can resume from here.
        save_checkpoint(checkpoint_path, epoch + 1, model, optimizer)

    print('Training finished.')
    torch.save(model.state_dict(), model_path)


if __name__ == "__main__":
    main()

# Source: https://www.cnblogs.com/tiandsp/p/17688491.html
# Original tags: __, checkpoint, torch, nn, self, MobileNetV1, learning, state, depth