
Deep Learning (Single-Machine Multi-GPU Training)

Posted: 2024-03-31 18:00:50

If a machine has more than one GPU, you can use all of them for training.

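Multi-GPU training usually gives a clear speedup when the dataset and the model are large. When they are small it can even be slower, because nn.DataParallel adds communication at every step: it copies the model to each GPU, splits the batch across them, gathers the outputs on the first GPU, and sums the gradients back onto that GPU during the backward pass. If each GPU has little computation to do, this overhead dominates.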
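To make the trade-off concrete, here is a toy model, not a benchmark. It assumes that one step's compute time splits perfectly across k GPUs and that every extra GPU adds a fixed per-step communication cost. The function name `estimated_speedup` and all timings are hypothetical.

```python
def estimated_speedup(t_compute_ms: float, t_overhead_ms: float, num_gpus: int) -> float:
    """Toy speedup estimate for DataParallel-style training.

    Hypothetical assumptions, not measurements:
    - the single-GPU compute time per step (t_compute_ms) splits perfectly across GPUs;
    - each GPU beyond the first adds t_overhead_ms of per-step communication
      (model replication, scattering inputs, gathering outputs and gradients).
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    # Step time on num_gpus GPUs = shared compute + communication overhead
    t_multi = t_compute_ms / num_gpus + t_overhead_ms * (num_gpus - 1)
    return t_compute_ms / t_multi


if __name__ == "__main__":
    # Made-up numbers, chosen only to show the two regimes
    print(f"{estimated_speedup(100.0, 10.0, 2):.2f}")  # large workload: 1.67
    print(f"{estimated_speedup(5.0, 10.0, 2):.2f}")    # small workload: 0.40, i.e. slower
```

With these invented numbers the large workload runs about 1.67x faster on two GPUs, while the small one runs at 0.4x the single-GPU speed. Real overheads depend on model size, batch size and the GPU interconnect, so timing the actual script on your own hardware is the only reliable check.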

Here is a simple example that trains ResNet-18 on CIFAR-10 with nn.DataParallel:

import time
import torch
import torchvision.models
from torchvision.transforms import transforms
from torch import nn, optim
from torchvision.datasets import CIFAR10

if __name__ == "__main__":

    device = torch.device("cuda")
    
    dataTransforms = transforms.Compose([
            transforms.ToTensor()
            , transforms.RandomCrop(32, padding=4)      # pad by 4 pixels, then take a random 32x32 crop
            , transforms.RandomHorizontalFlip(p=0.5)    # random left-right flip
            , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
        ])

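    trainset = CIFAR10(root='./data', train=True, download=True, transform=dataTransforms)
    # batch_size is the total batch per step; DataParallel splits it across the GPUs.
    # Worker processes keep the GPUs fed; pin_memory speeds up host-to-GPU copies.
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True,
                                              num_workers=4, pin_memory=True)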
 
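    # weights=None replaces the deprecated pretrained=False (torchvision >= 0.13)
    model = torchvision.models.resnet18(weights=None)
    # Adapt ResNet-18 to 32x32 inputs: 3x3 stem conv with stride 1, no downsampling max-pool
    model.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=True)
    model.maxpool = nn.MaxPool2d(1, 1, 0)  # 1x1 pool with stride 1 acts as an identity
    model.fc = nn.Linear(model.fc.in_features, 10)  # 10 CIFAR-10 classes

    # DataParallel expects the parameters on device_ids[0] (cuda:0 by default)
    model.to(device)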

    # Wrap the model in DataParallel when more than one GPU is visible
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    cross = nn.CrossEntropyLoss()
    cross.to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    start = time.time()
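    for epoch in range(10):
        model.train()

        correctSum = 0
        lossSum = 0.0
        dataLen = 0

        for inputs, labels in trainLoader:
            # Move the whole batch to cuda:0; DataParallel scatters it to the other GPUs
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(inputs)  # per-GPU outputs are gathered back onto cuda:0
            loss = cross(outputs, labels)

            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs, 1)
            # .item() turns the GPU tensors into Python numbers; loss is a batch mean,
            # so weight it by the batch size before averaging over the epoch
            correctSum += (preds == labels).sum().item()
            lossSum += loss.item() * inputs.size(0)
            dataLen += inputs.size(0)

        print(f"epoch {epoch}: loss {lossSum / dataLen:.4f}, acc {correctSum / dataLen:.4f}")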

    timeElapsed = time.time() - start
    print('Elapsed time: {:.0f}m {:.0f}s'.format(timeElapsed // 60, timeElapsed % 60))
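In the script, batch_size=256 is the total batch per step: DataParallel splits each input batch along dimension 0 across the visible GPUs, sized the way `torch.chunk` sizes its pieces, so each GPU only sees part of it. The pure-Python sketch below mimics that sizing rule to show the per-GPU batch sizes. It is an illustration, not PyTorch's own code, and the function name `dataparallel_chunk_sizes` is hypothetical.

```python
import math
from typing import List


def dataparallel_chunk_sizes(batch_size: int, num_gpus: int) -> List[int]:
    """Per-GPU batch sizes when a batch is split like torch.chunk(batch, num_gpus, dim=0)."""
    if batch_size < 1 or num_gpus < 1:
        raise ValueError("batch_size and num_gpus must be positive")
    # Every chunk except possibly the last has ceil(batch_size / num_gpus) samples
    chunk = math.ceil(batch_size / num_gpus)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= chunk
    return sizes  # may contain fewer than num_gpus entries for tiny batches


if __name__ == "__main__":
    print(dataparallel_chunk_sizes(256, 2))  # [128, 128]
    print(dataparallel_chunk_sizes(256, 3))  # [86, 86, 84]
    # CIFAR-10 has 50000 training images, so the last batch of 256 holds 50000 % 256 = 80
    print(dataparallel_chunk_sizes(50000 % 256, 3))  # [27, 27, 26]
```

Because the global batch is divided, each GPU processes fewer samples per step as GPUs are added, which is one reason small models gain little. A common adjustment is to scale batch_size by torch.cuda.device_count() so that the per-GPU batch stays the same.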

From: https://www.cnblogs.com/tiandsp/p/18095358
