首页 > 其他分享 >一个单机多卡训练模型的例子

一个单机多卡训练模型的例子

时间:2024-08-14 19:17:12浏览次数:12  
标签:__ torch 单机 args 多卡 batch 例子 model seed

"""My demo train script."""

import argparse
import logging
import os
import random
import time
import numpy as np
import torch

from torch import nn, optim, Tensor
from torch.utils.data import DataLoader, Dataset


def parse_args() -> argparse.Namespace:
    """Build the CLI parser from an option table and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the training hyper-parameters.
    """
    parser = argparse.ArgumentParser(description="Training")
    # (flag, type, help text, default) — one row per CLI option.
    option_table = [
        ("--seed", int, "fix random seed", 123),
        ("--log_file", str, "log file", "test_train.log"),
        ("--log_path", str, "model path", "./training_log/"),
        ("--train_epochs", int, "epochs of training", 5),
        ("--batch_size", int, "batch size", 32),
        ("--learning_rate", float, "learning rate", 1e-3),
        ("--device", str, "run on which device (default: cuda)", "cuda"),
        ("--cuda_visible_devices", str, "cuda visible devices", "0"),
    ]
    for flag, arg_type, help_text, default in option_table:
        parser.add_argument(flag, type=arg_type, help=help_text, default=default)
    return parser.parse_args()


def init_logging(log_file: str, level: str = "INFO") -> None:
    """Configure root logging: overwrite *log_file* and echo to stderr.

    Args:
        log_file: Path of the log file (opened in write mode, truncating).
        level: Logging level name, e.g. "INFO" or "DEBUG".
    """
    record_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(
        filename=log_file,
        filemode="w",
        format=record_format,
        level=level,
    )
    # Mirror every record to the console as well as the file.
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())


def set_seed(seed: int) -> None:
    """Seed every RNG in use and force fully deterministic execution.

    Covers Python's hash seed, `random`, NumPy, and PyTorch (CPU + all
    CUDA devices), then disables the non-deterministic cuDNN/algorithm
    paths so repeated runs reproduce bit-for-bit.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Required by CUDA >= 10.2 for deterministic cuBLAS behaviour.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


def seed_worker(worker_id: int) -> None:
    """Seed NumPy/random inside a DataLoader worker process.

    Bug fix: the original seeded with the bare worker id, so the worker
    RNG streams were identical for every run and every epoch and ignored
    ``--seed`` completely.  The PyTorch-recommended recipe derives the
    worker seed from ``torch.initial_seed()``, which the DataLoader sets
    per worker from its (seeded) base generator.

    Args:
        worker_id: Worker index supplied by the DataLoader (unused; the
            per-worker variation is already baked into initial_seed()).
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class DatasetClass(Dataset):
    """Synthetic regression dataset: target = input + 2.

    100k scalar samples drawn uniformly from [0, 1) as float32.
    """

    def __init__(self):
        features = np.random.rand(100000, 1).astype(np.float32)
        self.input = features
        self.target = features + 2

    def __len__(self):
        # Number of rows in the feature matrix.
        return self.input.shape[0]

    def __getitem__(self, idx: int) -> tuple:
        """Return the (feature, label) pair at position *idx*."""
        return self.input[idx], self.target[idx]


class ModelClass(torch.nn.Module):
    """Single-linear-layer demo model mapping 1 feature to 1 output."""

    def __init__(self):
        super().__init__()
        # NOTE: attribute name `my_layer` is part of the checkpoint
        # state_dict key layout — do not rename.
        self.my_layer = nn.Linear(1, 1)

    def forward(self, inputs: Tensor) -> Tensor:
        """Apply the linear layer to *inputs* of shape (batch, 1)."""
        return self.my_layer(inputs)


def get_loss(model_output: Tensor, target: Tensor) -> Tensor:
    """Sum of per-sample L2 norms of the prediction error.

    Args:
        model_output: Predictions, shape (..., dim).
        target: Ground truth, same shape as *model_output*.

    Returns:
        Scalar tensor: sum over samples of ||model_output - target||_2.
    """
    residual = model_output - target
    return residual.norm(dim=-1).sum()


def training() -> None:
    """Train the demo model and save the checkpoint.

    Relies on the module-level ``args`` namespace for all
    hyper-parameters.  Uses nn.DataParallel for single-machine
    multi-GPU training; the saved state_dict therefore carries the
    "module." key prefix, which testing() matches by wrapping its model
    the same way.
    """
    train_set = DatasetClass()
    # Dedicated, seeded generator so shuffling order is reproducible.
    g = torch.Generator()
    g.manual_seed(args.seed)
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=os.cpu_count(),
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    model = ModelClass()
    model = nn.DataParallel(model)
    model.to(args.device)
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    for epoch in range(args.train_epochs):
        model.train()
        for batch_index, (features, labels) in enumerate(train_loader):
            features = features.to(args.device)
            labels = labels.to(args.device)
            model_outputs = model(features)
            loss = get_loss(model_outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_index % 1000 == 0:
                logging.info(
                    "Epoch: %s, Batch index: %s, Loss: %s",
                    epoch,
                    batch_index,
                    loss.item(),
                )
    # Bug fix: the checkpoint directory (default "./training_log/") is
    # never created anywhere else — without this, torch.save raises
    # FileNotFoundError after the whole training run has completed.
    os.makedirs(args.log_path, exist_ok=True)
    torch.save(model.state_dict(), f"{args.log_path}/trained_model.pth")


def testing() -> None:
    """Evaluate the saved checkpoint on a fresh synthetic dataset.

    Relies on the module-level ``args`` namespace.  The model is wrapped
    in nn.DataParallel before loading so the state_dict's "module."
    prefix (produced by training()) matches.
    """
    test_set = DatasetClass()
    g = torch.Generator()
    g.manual_seed(args.seed)
    test_loader = DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=os.cpu_count(),
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    model = ModelClass()
    model = nn.DataParallel(model)
    # Bug fix: without map_location, a checkpoint saved on CUDA cannot
    # be loaded when running with --device cpu (or a different GPU set).
    model.load_state_dict(
        torch.load(f"{args.log_path}/trained_model.pth", map_location=args.device)
    )
    model.to(args.device)
    model.eval()
    with torch.no_grad():
        for batch_index, (features, labels) in enumerate(test_loader):
            features = features.to(args.device)
            labels = labels.to(args.device)
            model_outputs = model(features)
            loss = get_loss(model_outputs, labels)
            if batch_index % 1000 == 0:
                # Bug fix: divide by the actual batch size, not the
                # nominal one — the last batch may be smaller.
                logging.info(
                    "Batch index: %s, Loss: %s",
                    batch_index,
                    loss.item() / labels.size(0),
                )


if __name__ == "__main__":
    args = parse_args()
    # Bug fix: CUDA_VISIBLE_DEVICES must be exported before anything can
    # initialise the CUDA context.  set_seed() calls
    # torch.cuda.manual_seed_all(), which may initialise CUDA, so the
    # original ordering (env var set after seeding) could leave the
    # device mask without effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
    set_seed(args.seed)
    init_logging(args.log_file)
    main_start_time = time.time()
    training()
    main_end_time = time.time()
    logging.info("Main time: %s", main_end_time - main_start_time)
    testing()

  

标签:__,torch,单机,args,多卡,batch,例子,model,seed
From: https://www.cnblogs.com/qiandeheng/p/18359617

相关文章

  • milvus调用阿里云大模型例子
    环境:OS:Windowspycharm:2022.1python:3.11.9 1.安装依赖模块pipinstallpymilvustqdmdashscope或是分别单独安装pipinstalldashscope--timeout=100pipinstalltqdm--timeout=100pipinstallpymilvus--timeout=100 2.导入文本报道内容将如下文本文件解压到项目的......
  • 转义字符及例子
    转义字符简单来说就是转变原来字符的意思文章目录转义字符1.转义字符例子2.常用转义字符总览3.常用转义字符及例子3.1\?3.2\’3.3\“小拓展:3.4\\\3.5\a3.6\n3.7\r3.8\t3.9\ddd3.10\xdd1.转义字符例子1.例子代码用换行符来举例子#include<stdio.h>......
  • 【C++】protobuf的简单使用(通讯录例子)
    protobuf的简单使用(通讯录例子).proto文件的编写保留字段字段唯一编号protobuf的类型enum类型Any类型oneof类型map类型完整通讯录代码.proto文件write文件read文件运行结果.proto文件的编写syntax用于指定protobuf的语法;package当.proto文件编译后再*.pb.h文件中会......
  • Docker-Compose单机容器集群编排工具
    目录容器编排管理与传统的容器管理的区别什么Docker-Compose?Docker-Compose的简介Docker-Compose的作用Docker-compose的三大概念什么YAML文件?YAML文件介绍使用YAML时的注意事项YAML文件的基本数据结构Docker-Compose配置常用字段Docker-Compose常用命令我们知道......
  • 传奇单机版:复古三职业+无需虚拟机一键安装
    今天给大家带来一款单机游戏的架设:传奇单机版。沉默版本三职业数值不变态,没有花里胡哨的东西(比如切割,生肖,时装等功能),客户端为16周年客户端。另外:本人承接各种游戏架设(单机+联网)本人为了学习和研究软件内含的设计思想和原理,带了单机架设教程,不适用于联网,仅供娱乐。教程是本人......
  • 多人同屏渲染例子——1、思路分析
    Unity引擎制作万人同屏效果  大家好,我是阿赵。  经常在各种渠道看到游戏的广告,会经常看到一些很宏大的场景,比如什么万人国战、千人同屏之类的说法。多人同屏是某些游戏的卖点,比如经典的割草游戏无双系列,或者是末日类主题的丧尸围城游戏。  说出来很失败,阿赵......
  • ObjectARX 判断实体是否是在位编辑块对象简单例子
    判断使用acdbIsInLongTransaction应该就可以。ads_nameent;ads_pointpt;if(RTNORM!=acedEntSel(_T("\n选择对象:"),ent,pt)){return;}AcDbObjectIdobjId;acdbGetObjectId(objId,ent);//直接判断//if(acdbIsInLongTransaction......
  • macos Cpp webserver的例子
    一、hello.h#include<iostream>usingnamespacestd;intns__hello(std::string*name,std::string&greeting);~二、helloclient.cpp#include"soapH.h"#include"ns.nsmap"......
  • CMAKE 《多模块例子》
    概述生成sort\calc的静态库,并生成app1.exeapp2.exe目录结构CMakeLists.txt位置以及配置根CMakeLists.txtcmake_minimum_required(VERSION3.15)project(mulitiple_modules)set(CMAKE_CXX_STANDARD17)#definevariables#LIBPATH库存储位置set(LIBPATH${PROJECT......
  • webservice 的参考例子 sample
    一、参考https://blog.csdn.net/Ikaros_521/article/details/103232677二、hello.hh__hello(char*&);三、//helloclient.cpp#include"soapH.h"#include"h.nsmap"intmain(){char*s;structsoap*soap=soap_new();so......