PyTorch implementation of ViT (vit.py, imported by the training script below):
```python
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(nn.Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Linear."""
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Pre-norm multi-head self-attention."""
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # skip the output projection when a single head already matches dim
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # one linear layer produces q, k and v for all heads at once
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        # split into q, k, v and separate the heads: (b, n, h*d) -> (b, h, n, d)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # scaled dot-product attention
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # merge the heads back: (b, h, n, d) -> (b, n, h*d)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        # residual connections around each attention and feed-forward block
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # flatten each patch and project it to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # learned position embeddings for num_patches patch tokens plus the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend one class token per sample, then add position embeddings
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # classify from the class token, or from the mean over all tokens
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
```
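A quick way to sanity-check the model is a dummy forward pass. The snippet below is a usage sketch, not part of the original post; it uses the 224x224 / 16x16-patch configuration from the training script below. The patch embedding cuts each image into (224/16)^2 = 196 patches of 16*16*3 = 768 raw values, projects them to dim = 256, and the full forward pass returns one logit per class.

```python
import torch
from vit import ViT

# same configuration as in classfy_main.py below
model = ViT(image_size=(224, 224), patch_size=(16, 16), num_classes=10,
            dim=256, depth=6, heads=8, mlp_dim=512, dropout=0.1)

imgs = torch.randn(2, 3, 224, 224)   # dummy batch of 2 RGB images

tokens = model.to_patch_embedding(imgs)
print(tokens.shape)                  # torch.Size([2, 196, 256]) -- (224/16)^2 = 196 patch tokens

logits = model(imgs)
print(logits.shape)                  # torch.Size([2, 10]) -- one logit per CIFAR-10 class
```

Inside `forward`, a class token is prepended to the 196 patch tokens, so the transformer itself runs on sequences of 197 vectors.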
classfy_main.py (the CIFAR-10 training script):
```python
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode

from matplotlib import pyplot as plt

import time

from Lenet5 import Lenet5_new
# from Resnet18 import ResNet18, ResNet18_new
# from AlexNet import AlexNet
# from Vgg16 import VGGNet16
# from Densenet import DenseNet121, DenseNet169, DenseNet201, DenseNet264

# from NIN import NIN_Net
# from GoogleNet import GoogLeNet
# from MobileNet_v3 import mobilenet_v3
from shuffleNet import shuffleNet_g3_

from vit import ViT

def main():

    print("Load datasets...")

    # transforms.RandomHorizontalFlip(p=0.5) -- flip the image horizontally with probability 0.5
    # transforms.ToTensor() -- reshape (H, W, C) -> (C, H, W) and map pixel values from [0, 255] to [0, 1] by dividing by 255
    # transforms.Normalize -- standardize each channel as (x - mean) / std (ImageNet statistics here)
    transform_train = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        # transforms.RandomCrop(32, padding=4),  # zero-pad the borders, then randomly crop back to 32x32
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        # transforms.RandomCrop(32, padding=4),  # zero-pad the borders, then randomly crop back to 32x32
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    dataset_path = "/big-data/person/zhaopengpeng/deepfake_zpp/code/Transformer_code/Coco_train_code"

    # download the dataset with torchvision's built-in loader
    train_dataset = datasets.CIFAR10(root=dataset_path + "/data/Cifar10/", train=True,
                                     transform=transform_train,
                                     download=True)
    test_dataset = datasets.CIFAR10(root=dataset_path + "/data/Cifar10/",
                                    train=False,
                                    transform=transform_test,
                                    download=True)

    print(len(train_dataset), len(test_dataset))

    Batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=Batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=Batch_size, shuffle=False, num_workers=4)

    # select the device
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # initialize the model -- swap in a different model here; nothing else needs to change
    # model = Lenet5_new().to(device)
    # model = ResNet18().to(device)
    # model = ResNet18_new().to(device)
    # model = VGGNet16().to(device)
    # model = DenseNet121().to(device)
    # model = DenseNet169().to(device)

    # model = NIN_Net().to(device)

    # model = GoogLeNet().to(device)
    # model = mobilenet_v3().to(device)

    # model = ViT(image_size=(32, 32), patch_size=(4, 4), num_classes=10, dim=256, depth=6, heads=8, mlp_dim=512, dropout=0.1)
    model = ViT(image_size=(224, 224), patch_size=(16, 16), num_classes=10, dim=256, depth=6, heads=8, mlp_dim=512, dropout=0.1).to(device)

    # model = shuffleNet_g3_().to(device)
    # model = AlexNet(num_classes=10, init_weights=True).to(device)
    print("ViT train...")

    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()  # softmax cross-entropy for multi-class classification
    # opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.8, weight_decay=0.001)
    opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

    # learning-rate schedule: every step_size epochs, lr = lr * gamma
    schedule = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.6, last_epoch=-1)

    print("Start Train...")

    epochs = 100

    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []

    for epoch in range(0, epochs):

        start = time.time()

        model.train()

        running_loss = 0.0
        batch_num = 0

        for i, (inputs, labels) in enumerate(train_loader):

            inputs, labels = inputs.to(device), labels.to(device)

            # forward pass
            outputs = model(inputs)
            # compute the loss
            loss = criterion(outputs, labels)

            # reset gradients
            opt.zero_grad()
            # backpropagate
            loss.backward()
            # update parameters with the computed gradients
            opt.step()

            # accumulate the loss over the epoch
            running_loss += loss.item()
            # loss_list.append(loss.item())
            batch_num += 1

        epochs_list.append(epoch)

        # print the current learning rate at the end of each epoch
        lr_1 = opt.param_groups[0]['lr']
        print("learn_rate:%.15f" % lr_1)
        schedule.step()

        end = time.time()
        print('epoch = %d/%d, batch_num = %d, loss = %.6f, time = %.3f' % (epoch + 1, epochs, batch_num, running_loss / batch_num, end - start))
        running_loss = 0.0

        # evaluate on the train and test sets after each epoch
        model.eval()
        train_correct = 0.0
        train_total = 0

        test_correct = 0.0
        test_total = 0

        # no gradients needed during evaluation
        with torch.no_grad():

            # print("=======================train=======================")
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)

                pred = outputs.argmax(dim=1)  # index of the max logit in each row
                train_total += inputs.size(0)
                train_correct += torch.eq(pred, labels).sum().item()

            # print("=======================test=======================")
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)

                pred = outputs.argmax(dim=1)  # index of the max logit in each row
                test_total += inputs.size(0)
                test_correct += torch.eq(pred, labels).sum().item()

        print("train_total = %d, Accuracy = %.5f %%, test_total = %d, Accuracy = %.5f %%" % (train_total, 100 * train_correct / train_total, test_total, 100 * test_correct / test_total))

        train_acc_list.append(100 * train_correct / train_total)
        test_acc_list.append(100 * test_correct / test_total)

        # print("Accuracy of the network on the 10000 test images:%.5f %%" % (100 * test_correct / test_total))
        # print("===============================================")

        # saved every epoch under the same filename, so the figure always reflects training so far
        fig = plt.figure(figsize=(4, 4))

        plt.plot(epochs_list, train_acc_list, label='train_acc_list')
        plt.plot(epochs_list, test_acc_list, label='test_acc_list')
        plt.legend()
        plt.title("train_test_acc")
        # plt.savefig('shuffleNet_g3_acc_epoch_{:04d}.png'.format(epochs))
        plt.savefig('ViT_acc_epoch_{:04d}.png'.format(epochs))
        plt.close()

if __name__ == "__main__":

    main()
```
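One thing worth noting: bicubic-resizing CIFAR-10's 32x32 images up to 224x224 inflates every sample 49x and makes each step expensive, since attention cost grows quadratically with the 197-token sequence. The commented-out model line in the script already points at a native-resolution alternative; a minimal sketch of that variant follows. The transform choices here are my assumption; only the ViT arguments come from the script.

```python
# Sketch: train at CIFAR-10's native 32x32 instead of resizing to 224x224.
# The transforms below are assumptions; the ViT config matches the
# commented-out line in classfy_main.py.
from torchvision import transforms
from vit import ViT

transform_train_32 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),    # zero-pad the borders, then random 32x32 crop
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# 32/4 = 8 patches per side -> 64 patch tokens + 1 class token = 65 tokens,
# versus 197 tokens at 224/16, so attention is roughly (197/65)^2 ~ 9x cheaper per layer
model = ViT(image_size=(32, 32), patch_size=(4, 4), num_classes=10,
            dim=256, depth=6, heads=8, mlp_dim=512, dropout=0.1)
```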
From: https://www.cnblogs.com/zhaopengpeng/p/18183840