首页 > 编程问答 >并行 GPU 中的 Pytorch 推理

并行 GPU 中的 Pytorch 推理

时间:2024-07-23 04:24:41浏览次数:10  
标签:python pytorch parallel-processing

我有一个经过训练的模型和 4 个 GPU。我有一个数据集,我想将其分割为 4 个(并在每个 GPU 上使用相同的批量大小进行处理),并且基本上将我从每个 GPU 获得的结果相加。一般来说,我对并行处理很陌生,我对 torch.multiprocessing 等的谷歌搜索没有帮助。

我的问题的一个最小示例如下: 我尝试使用 $python 和 $torchrun 运行以下脚本。

import torch 
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp

net = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10,1), nn.Flatten(0)) # some model 

x = torch.linspace(0, 1, 100).unsqueeze(1) # the full dataset
s0 = net(x).sum() # some action of the model on the full data  
# I want to preform the same action on the data in parallel and get the same result


def run_inference(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    net.to(rank)

    x = torch.linspace(0, 1, 100).unsqueeze(1) # re-define data here? (Based on google results)
    data = TensorDataset(x) 

    world_size = torch.cuda.device_count()
    batch_size = 10
    data_loader = DataLoader(data, batch_size=batch_size, drop_last=False, num_workers=world_size)

    sum_i = 0 
    for idx, data in enumerate(data_loader):
        sum_i += net(data[0]).sum() # how to store and add results from different GPUs?


world_size = torch.cuda.device_count() # 4 in my case 
mp.spawn(run_inference,
    args=(world_size,),
    nprocs=world_size,
    join=True)

以下是使用 PyTorch 在 4 个 GPU 上并行进行推理并对结果求和的细分方法:

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(1, 10)
        self.tanh = nn.Tanh()
        self.linear2 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.tanh(x)
        x = self.linear2(x)
        return x.squeeze()

def run_inference(rank, world_size):
    # 初始化进程组
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # 设置设备
    torch.cuda.set_device(rank)

    # 模型和数据
    net = Net().to(rank)
    net = nn.parallel.DistributedDataParallel(net, device_ids=[rank])

    x = torch.linspace(0, 1, 100).unsqueeze(1).to(rank)
    dataset = TensorDataset(x)

    # 创建分布式采样器
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False
    )

    data_loader = DataLoader(
        dataset,
        batch_size=10,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        sampler=sampler
    )

    # 推理循环
    sum_i = torch.tensor(0.0).to(rank)
    for data in data_loader:
        input_data = data[0]
        output = net(input_data)
        sum_i += output.sum()

    # 将所有 GPU 上的总和相加
    dist.all_reduce(sum_i, op=dist.ReduceOp.SUM)

    # 在主进程上打印结果
    if rank == 0:
        print(f"总和: {sum_i.item()}")

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)

说明:

  1. 导入必要的库: 我们需要 torch torch.nn torch.utils.data torch.distributed torch.multiprocessing

  2. 定义模型: 我们将相同的模型架构 ( net ) 定义为之前。

  3. run_inference 函数:

  4. 此函数现在接受 rank (进程的 ID)和 world_size (进程总数)作为参数。
  5. 初始化进程组: 使用 nccl 后端初始化 dist.init_process_group ,该后端通常适用于多 GPU 通信。
  6. 设置设备: torch.cuda.set_device(rank) 确保每个进程在其指定的 GPU 上运行。
  7. 将模型移至设备: net 实例移至当前设备。
  8. 创建分布式数据加载器:
    • 我们将数据划分为多个工作进程。
    • 使用 DistributedSampler 确保每个进程获得唯一的数据子集。
  9. 执行推理: 对每个批次运行模型,并累积输出的总和( sum_i )。
  10. 使用 dist.all_reduce 减少总和: 此函数将所有进程上的 sum_i 值相加,并将结果存储在所有进程的 sum_i 中。
  11. 打印结果: rank 0 进程(主进程)上打印最终的总和。

  12. if __name__ == '__main__': 块:

  13. 获取可用的 GPU 数量 ( world_size )。
  14. 使用 mp.spawn 函数生成多个进程,每个进程运行 run_inference 函数。

如何运行:

  1. 确保有 4 个可用的 GPU。
  2. 使用 python -m torch.distributed.launch --nproc_per_node=4 your_script.py 命令运行脚本,其中 your_script.py 是的 Python 文件名。

此脚本将在 4 个 GPU 上并行运行推理,并将结果聚合以计算最终的总和。 每个 GPU 处理数据的不同部分,并且 dist.all_reduce 确保最终结果是所有部分结果的总和。

标签:python,pytorch,parallel-processing
From: 78780627

相关文章

  • 无法在 python 中安装 pip install expliot - bluepy 的 Building Wheel (pyproject.t
    在此处输入图像描述当我尝试在Windows计算机中通过cmd安装pipinstallexpliot包时,我收到2个错误名称×Buildingwheelforbluepy(pyproject.toml)didnotrunsuccessfully.│exitcode:1**AND**opt=self.warn_dash_deprecation......
  • python 用单斜杠-反斜杠替换url字符串中的双斜杠
    我的URL包含错误的双斜杠(“//”),我需要将其转换为单斜杠。不用说,我想保持“https:”后面的双斜杠不变。可以在字符串中进行此更改的最短Python代码是什么?我一直在尝试使用re.sub,带有冒号否定的正则表达式(即,[^:](//)),但它想要替换整个匹配项(包括前面......
  • 如何使用 Selenium Python 搜索 Excel 文件中的文本
    我有一些数据在Excel文件中。我想要转到Excel文件,然后搜索文本(取自网站表),然后获取该行的所有数据,这些数据将用于在浏览器中填充表格。示例:我希望selenium搜索ST0003然后获取名称,该学生ID的父亲姓名,以便我可以在大学网站中填写此信息。我想我会从网站......
  • Python 套接字请求在很多情况下都会失败
    我在python中尝试了超过5种不同的方法,尽管人们说它在其他论坛上有效,但所有这些方法都惨遭失败。importsocketmessage="test"clientsocket=socket.socket(socket.AF_INET,socket.SOCK_STREAM)clientsocket.connect(('1.1.1.1',80))clientsocket.send(mes......
  • Python 网络套接字
    我一直尝试通过Python访问该网站的websocket,但是需要绕过CloudFlare,现在我尝试通过cookie进行绕过,但是这不起作用。我已经尝试在没有cookie的情况下执行此操作,但这也不起作用。importwebsocketimportbase64importosdriver=selenium.webdriver.Firefox()driver.ge......
  • 如何在Python中使用Selenium提取data-v-xxx?
    因为我想查看每个class='num'内的文本是否大于0。如果测试通过,那么我需要获取venuen-name内的文本。我观察到,data-v是相同的。所以我的方法是获取相同的data-v-<hashvalue>来查找场地名称。我尝试了不同的方法来提取,但仍然无法提取。有什么建议吗?这是DOM<div......
  • Python:添加异常上下文
    假设我想提出一个异常并提供额外的处理信息;最好的做法是什么?我想出了以下方法,但对我来说有点可疑:definternal_function():raiseValueError("smellysocks!")defcontext_function():try:internal_function()exceptExceptionase:......
  • 【视频】Python遗传算法GA优化SVR、ANFIS预测证券指数ISE数据-CSDN博客
    全文链接:https://tecdat.cn/?p=37060本文旨在通过应用多种机器学习技术,对交易所的历史数据进行深入分析和预测。我们帮助客户使用了遗传算法GA优化的支持向量回归(SVR)、自适应神经模糊推理系统(ANFIS)等方法,对数据进行了特征选择、数据预处理、模型训练与评估。实验结果表明,这些方法......
  • Python学习笔记42:游戏篇之外星人入侵(三)
    前言在之前我们已经创建好了目录,并且编写好了游戏入口的模块。今天的内容主要是讲讲需求的分析以及项目各模块的代码初步编写。在正式编写代码前,碎碎念几句。在正式编写一个项目代码之前,实际是有很多工作要做的。就项目而言,简单的定项,需求对齐,项目架构设计,实际的代码编写,......
  • Python入门知识点 5--流程控制语句
    先来分享一个pycharm使用小技巧   红色波浪线:提醒可能报错   黄色波浪线:提醒书写不规范,ctrl+alt+l去掉黄线   code--Reformatcode,就可以去掉黄线,调整代码格式1、程序三大执行流程(1)顺序执行        程序执行时,代码从上往下,从左往右执行,中间......