首页 > 其他分享 >大模型--训练加速之deepspeed demo-13

大模型--训练加速之deepspeed demo-13

时间:2024-11-11 15:41:32浏览次数:1  
标签:engine deepspeed nn -- demo self args model data

目录

1. config.json

{
  "train_batch_size": 4,
  "steps_per_print": 2000,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [0.8, 0.999],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 0.001,
      "warmup_num_steps": 1000
    }
  },
  "wall_clock_breakdown": false
}

2. main.py

import torch
import torchvision

# 用于构建神经网络模型
import torch.nn as nn
# 提供了各种神经网络层的函数版本,如激活函数、损失函数等
import torch.nn.functional as F

import argparse
import deepspeed


# 创建训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        )

# 创建数据加载器,批量加载数据并处理数据加载的并行化
trainloader = torch.utils.data.DataLoader(trainset,
                                          # 每个批次包含16张图像
                                          batch_size=16,
                                          # 在每次迭代开始时随机打乱训练数据的顺序,有助于模型训练
                                          shuffle=True,
                                          # 开启2个子进程来并行加载数据,提高效率
                                          num_workers=2)

# 创建测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True,
                                       )

testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=4,
                                         # 测试数据通常不需要打乱顺序
                                         shuffle=False,
                                         num_workers=2)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 卷积层1: 输入图像通道数为3 (例如RGB图像),输出6个特征图,卷积核大小为5x5
        self.conv1 = nn.Conv2d(3, 6, 5)
        # 最大池化层: 池化窗口大小为2x2,步长也为2
        self.pool = nn.MaxPool2d(2, 2)
        # 卷积层2: 接收上一层6个输入通道,输出16个特征图,卷积核大小为5x5
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 全连接层1: 输入为16*5*5(假设输入图像大小为32x32,经过两次池化后得到的特征图大小)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        # 全连接层2
        self.fc2 = nn.Linear(120, 84)
        # 输出层: 10个输出节点,对应于10个类别
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 第一次卷积 + ReLU激活 + 池化
        x = self.pool(F.relu(self.conv1(x)))
        # 第二次卷积 + ReLU激活 + 池化
        x = self.pool(F.relu(self.conv2(x)))
        # 将特征图展平成一维向量
        x = x.view(-1, 16 * 5 * 5)
        # 第一个全连接层 + ReLU激活
        x = F.relu(self.fc1(x))
        # 第二个全连接层 + ReLU激活
        x = F.relu(self.fc2(x))
        # 输出层,没有激活函数
        x = self.fc3(x)
        return x


# 实例化网络模型
net = Net()

# 设置损失函数:多分类交叉熵损失函数,适用于监督学习中的分类任务
criterion = nn.CrossEntropyLoss()


def add_argument():
    # 创建一个ArgumentParser对象,设置描述为"CIFAR"
    parser = argparse.ArgumentParser(description='CIFAR')

    # 设置训练时的批大小,默认值为32
    parser.add_argument(
        '-b', '--batch_size',
        default=32,
        type=int,
        help='mini-batch size (default: 32)'
    )

    # 设置总的训练轮数,默认值为30
    parser.add_argument(
        '-e', '--epochs',
        default=30,
        type=int,
        help='number of total epochs (default: 30)'
    )

    # 传递分布式训练中的排名,默认值为-1,表示未使用分布式训练
    parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='local rank passed from distributed launcher'
    )

    # 设置输出日志信息的间隔,默认值为2000,即每2000次迭代打印一次日志
    parser.add_argument(
        '--log-interval',
        type=int,
        default=2000,
        help='output logging information at a given interval'
    )

    # 添加与DeepSpeed相关的配置参数
    parser = deepspeed.add_config_arguments(parser)

    # 解析命令行参数,返回一个Namespace对象,其中包含了所有定义的参数及其对应的值
    args = parser.parse_args()

    # 返回解析后的参数对象args,供后续的训练脚本使用
    return args


# 调用之前定义的add_argument函数,解析命令行参数,并将结果存储在args变量中
args = add_argument()

# 创建Net类的实例
net = Net()

# 筛选出模型中需要梯度计算的参数
parameters = filter(lambda p: p.requires_grad, net.parameters())

# 使用deepspeed.initialize初始化模型引擎、优化器、数据加载器以及其他可能的组件
model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    args=args,
    model=net,
    model_parameters=parameters,
    training_data=trainset
)


def train():
    # 定义进行2个epoch的训练
    for epoch in range(2):
        running_loss = 0.0

        # 对于每个epoch,遍历训练数据加载器trainloader中的每一个小批量数据
        # 同时提供索引i和数据data
        for i, data in enumerate(trainloader):
            # 将输入数据inputs和标签labels移动到当前GPU设备上,
            # 具体是哪个GPU由model_engine.local_rank决定,
            # 这对于分布式训练非常重要,确保数据被正确地分配到各个参与训练的GPU上
            inputs, labels = data[0].to(model_engine.local_rank), data[1].to(
                model_engine.local_rank)

            # 通过model_engine执行前向传播,计算模型预测输出
            outputs = model_engine(inputs)

            # 计算预测输出outputs与真实标签labels之间的损失
            loss = criterion(outputs, labels)

            # 反向传播计算梯度
            model_engine.backward(loss)

            # 更新模型参数
            model_engine.step()

            # 计算并累加每个小批量的损失值
            running_loss += loss.item()

            # 当达到args.log_interval指定的迭代次数时,打印平均损失值,
            # 然后重置running_loss为0,以便计算下一个区间的平均损失
            if i % args.log_interval == (args.log_interval - 1):
                print(
                    f'[{epoch + 1}, {i + 1}] loss: {running_loss / args.log_interval:.3f}')
                running_loss = 0.0


def test():
    # 初始化计数器
    # correct用于记录分类正确的样本数量
    # total用于记录评估的总样本数
    correct = 0
    total = 0

    # 上下文管理器,关闭梯度计算,
    # 因为在验证阶段我们不需要计算梯度,这可以提高计算效率
    with torch.no_grad():
        # 遍历测试数据加载器testloader中的每个小批量数据
        for data in testloader:
            # 获取当前小批量数据的图像和标签
            images, labels = data

            # 在当前GPU上执行模型的前向传播
            # 这里将图像数据移动到与模型相同的GPU上,然后通过模型得到预测输出
            outputs = net(images.to(model_engine.local_rank))

            # 找到每个样本的最大概率对应的类别
            _, predicted = torch.max(outputs.data, 1)

            # 增加总样本数,同时计算分类正确的样本数。
            # 注意,这里将标签也移动到与模型相同的GPU上进行比较
            total += labels.size(0)
            correct += (predicted == labels.to(
                model_engine.local_rank)).sum().item()

    # 遍历完整个测试集后,计算并打印模型在测试集上的准确率
    print('Accuracy of the network on the 10000 test images: %d %%' %
          (100 * correct / total))


if __name__ == '__main__':
    train()
    test()


.env

export CUDA_VISIBLE_DEVICES="6,7,8,9"

3. start.sh

deepspeed main.py --deepspeed_config config.json

标签:engine,deepspeed,nn,--,demo,self,args,model,data
From: https://www.cnblogs.com/cavalier-chen/p/18539870

相关文章

  • Linux中文件系统层次结构简述
    在Linux操作系统中,并没有像Windows那样的“盘符”概念。相反,Linux使用一个统一的文件系统层次结构,所有的文件和目录都挂载在一个单一的根目录/下。这种设计使得文件系统的管理更加灵活和一致。文件系统层次结构在Linux中,文件系统通常按照以下层次结构组织:/(根目录):文件系......
  • 11.11随笔
    这里是11.11随笔。课堂作业留档:简单的判断分数,给出等级代码:importjava.util.Scanner;publicclassThrowDemo{publicstaticvoidmain(String[]args){//doubledata;System.out.println("输入分数:");Scannersc=newScanner(System.in);data=sc.nextDouble();......
  • Nginx的一些基本配置
    1.基本配置首先,我们需要编辑Nginx的主配置文件nginx.conf。这个文件通常位于/etc/nginx/nginx.conf或/usr/local/nginx/conf/nginx.conf。示例:基本配置usernginx;worker_processesauto;error_log/var/log/nginx/error.logwarn;pid/var/run/nginx.p......
  • linux系统的简单介绍
    一个项目的工作流程:1.linux系统Linux,全称GNU/Linux,是一种免费使用和自由传播的类UNIX操作系统,其内核由林纳斯·本纳第克特·托瓦兹(LinusBenedictTorvalds)于1991年10月5日首次发布,它主要受到Minix和Unix思想的启发,是一个基于POSIX的多用户、多任务、支持多线程和多CPU的操作......
  • 一键安装yum-utils安装包
    一键安装yum-utils安装包使用yum下载离线安装包及依赖说明:1.方式1:使用yum-yinstall的方式将锁下载包及依赖进行备份更改yum配置文件,将下载的包进行保存vim/etc/yum.conf[main]cachedir=/var/cache/yum/$basearch/$releaseverkeepcache=0#将此处改为1将会保......
  • 实录:电话咨询数据库数据迁移“100” 个问题
    参加“央企”项目改造会后的,“数据库瞎想”这段时间国产数据库的话题频繁出现,新时代新需求,最近研究如何替换MySQL到国产数据库的过程中,发现有这样的需求。不乏一些老的系统,软件没人维护,之前编写软件的开发人员已经找不到踪影,应用系统的数据一直增长上涨,这些客户共同的特点,数......
  • 【双11最后一天】活动商品低至8折!DL32逻辑分析仪Pro、加热台、电烙铁、开发板等活动产
    【双11最后一天】活动商品低至8折!DL32逻辑分析仪、加热台、电烙铁、开发板等活动产品限时优惠!各种爆款产品火热售卖中!双11最后一天,全店活动商品低至8折!全新DL32逻辑分析仪、HP15加热台,还有爆款产品T80智能电烙铁、DS100Mini手持示波器,开发板等活动产品限时优惠!各种新品及爆款产......
  • Spring Security 防止 CSRF 攻击
    使用security是3.3.2版本1、启用CSRF,security自带功能1@Bean2publicSecurityFilterChainfilterChain(HttpSecurityhttpSecurity)throwsException{3//禁用默认的登录和退出4httpSecurity.formLogin(AbstractHttpConfigurer::di......
  • CF 1257 题解
    CF1257题解ATwoRivalStudents每次交换都可以让距离增加\(1\),上界是\(n-1\).题目说至多而不是恰好交换\(x\)次,于是不需要考虑边界.BMagicStick一个重要的观察是:如果能够得到\(x\),那么就能得到任意小于等于\(x\)的数,这是操作二保证的.考虑操作\(1\)......
  • apropos——在 whatis 数据库中查找字符串
    转自于:https://github.com/jaywcjlove/linux-command,后不赘述apropos在whatis数据库中查找字符串补充说明apropos命令在一些特定的包含系统命令的简短描述的数据库文件里查找关键字,然后把结果送到标准输出。如果你不知道完成某个特定任务所需要命令的名称,可以使用一个关......