训练的数据集:
含有数据集的:链接:https://pan.baidu.com/s/1u8N_yRnxrNoIMc4aP55rcQ 提取码:6wfe
不含数据集的:链接:https://pan.baidu.com/s/1BNVj2XSajJx8u1ZlKadnmw 提取码:xrng
model.py
import numpy as np
try:
    # OpenCV is not referenced anywhere in this module.  The import is kept so
    # the name stays in scope, but it is optional so that the model can be
    # imported on machines where OpenCV is not installed.
    import cv2
except ImportError:
    cv2 = None
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class AlexNet(nn.Module):
    """AlexNet-style CNN classifier with reduced channel widths.

    The convolutional stack uses 48/128/192/192/128 output channels and the
    classifier uses two 2048-unit hidden layers.

    Args:
        num_classes: Size of the final linear layer (number of output logits).
        init_weights: When True, re-initialise conv and linear layers with
            ``_initialize_weights`` instead of PyTorch's defaults.

    Input is an ``[N, 3, H, W]`` float tensor.  For a 224x224 input the
    feature extractor yields a ``[N, 128, 6, 6]`` map; other sizes (large
    enough to survive the three pooling stages) are pooled to 6x6 by
    ``avgpool`` so the classifier input size is always ``128 * 6 * 6``.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()
        # nn.Sequential packs the layers so they run in order.
        # Shape comments below assume a 224x224 input.
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # 3 -> 48, 55x55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 27x27

            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # 48 -> 128, 27x27
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 13x13

            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # 128 -> 192, 13x13
            nn.ReLU(inplace=True),

            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # 192 -> 192, 13x13
            nn.ReLU(inplace=True),

            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # 192 -> 128, 13x13
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 6x6
        )
        # Parameter-free, so checkpoints saved without it still load.  It is
        # the identity for a 6x6 feature map (i.e. 224x224 inputs) and makes
        # other input resolutions compatible with the fixed-size classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),  # p=0.5 is also nn.Dropout's default
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape ``[N, num_classes]`` for batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv layers with Kaiming-normal and linear layers with N(0, 0.01).

        All biases that exist are set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming (He) initialisation, suited to ReLU activations.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
train.py
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    """Train AlexNet on the image-folder dataset and keep the best weights.

    Reads ``<cwd>/data_set/flower_data/{train,val}`` with
    ``torchvision.datasets.ImageFolder``, writes the index -> class-name
    mapping to ``class_indices.json``, trains for 10 epochs with Adam and
    saves the state dict to ``./AlexNet.pth`` whenever the validation
    accuracy improves.

    Raises:
        AssertionError: If the dataset directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # Resize needs the (224, 224) tuple here: a bare int would only fix
        # the shorter side and leave non-square images non-square.
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    }

    data_root = os.path.abspath(os.path.join(os.getcwd(), "./"))  # data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps folder name -> label index,
    # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
    flower_list = train_dataset.class_to_idx
    # Invert it so the prediction script can look up a name from an index.
    cla_dict = {idx: name for name, idx in flower_list.items()}
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # os.cpu_count() may return None; fall back to 1 so min() never compares
    # None with an int.
    cpu_count = os.cpu_count() or 1
    nw = min([cpu_count, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # Optional debugging aid: preview one validation batch.
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    # Size the output layer from the dataset rather than hard-coding 5, so
    # the script keeps working if the class folders change.
    net = AlexNet(num_classes=len(flower_list), init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # ---- train ----
        net.train()  # enables dropout
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # Convert once to a Python float; used for both the running
            # total and the progress-bar text.
            loss_value = loss.item()
            running_loss += loss_value

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss_value)

        # ---- validate ----
        net.eval()  # disables dropout
        acc = 0.0  # number of correct predictions in this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Only keep the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
predict.py
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    """Classify ``./1.png`` with the trained AlexNet and display the result.

    Loads the index -> class-name mapping from ``./class_indices.json`` and
    the weights from ``./AlexNet.pth``, prints the softmax probability of
    every class and shows the image titled with the top prediction.

    Raises:
        AssertionError: If the image, class-index file or weights file is
            missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "./1.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    # PNG files may be RGBA, grayscale or palette images; the 3-channel
    # Normalize and the network's first conv layer need exactly 3 channels.
    img = Image.open(img_path).convert("RGB")

    plt.imshow(img)
    # [C, H, W] after the transform
    img = data_transform(img)
    # add the batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read the index -> class-name mapping (JSON keys are strings)
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model; the output size follows the class-index file instead of a
    # hard-coded 5 so it matches whatever the mapping was written for
    model = AlexNet(num_classes=len(class_indict)).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    # map_location lets weights saved from a CUDA run load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()  # disables dropout
    with torch.no_grad():
        # Drop only the batch dimension, then turn the logits into probabilities.
        output = torch.squeeze(model(img.to(device)), dim=0).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict.tolist()):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], prob))
    plt.show()


if __name__ == '__main__':
    main()
标签:21,nn,torch,0.5,神经网络,train,import,path,AlexNet From: https://www.cnblogs.com/zhaopengpeng/p/16845584.html