torch_utils.py
utils\torch_utils.py
目录
2.def torch_distributed_zero_first(local_rank: int):
3.def init_torch_seeds(seed=0):
4.def date_modified(path=__file__):
5.def git_describe(path=Path(__file__).parent):
6.def select_device(device='', batch_size=None):
8.def profile(x, ops, n=100, device=None):
10.def intersect_dicts(da, db, exclude=()):
11.def initialize_weights(model):
12.def find_modules(model, mclass=nn.Conv2d):
14.def prune(model, amount=0.3):
15.def fuse_conv_and_bn(conv, bn):
16.def model_info(model, verbose=False, img_size=640):
17.def load_classifier(name='resnet101', n=2):
18.def scale_img(img, ratio=1.0, same_shape=False, gs=32):
19.def copy_attr(a, b, include=(), exclude=()):
21.class BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
22.def revert_sync_batchnorm(module):
23.class TracedModel(nn.Module):
1.所需的库和模块
# YOLOR PyTorch utils
import datetime
import logging
import math
import os
import platform
import subprocess
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision
try:
import thop # for FLOPS computation
except ImportError:
thop = None
# 它用于获取一个名为 __name__ 的日志记录器。在Python的 logging 模块中, getLogger 函数用于获取一个日志记录器对象, __name__ 是Python的一个内置变量,它返回当前模块的名称。这行代码通常用于初始化日志记录,以便在程序中记录信息。
logger = logging.getLogger(__name__)
2.def torch_distributed_zero_first(local_rank: int):
# 这段代码定义了一个名为 torch_distributed_zero_first 的上下文管理器(context manager),它用于在分布式训练中协调多个进程,确保所有进程等待某个特定的进程(通常是本地的主进程,即 local_rank 为 0 的进程)首先执行某些操作,然后其他进程再继续执行。
# @contextmanager 是一个装饰器,用于创建一个上下文管理器(context manager),它允许你使用 with 语句来管理代码执行的上下文。
# 在 Python 中,上下文管理器通常用于获取资源、确保资源在使用后被正确清理,以及在进入和退出上下文时执行特定的代码。
# 使用 @contextmanager 装饰器的函数必须使用 yield 语句,它将函数分成两部分 :在 yield 之前执行的代码(进入上下文)和 yield 之后执行的代码(退出上下文)。
# 当 with 块开始执行时,会运行到 yield 之前的代码,然后暂停执行,直到 with 块的代码执行完毕。之后,会执行 yield 之后的代码。
# @contextmanager :这是一个装饰器,用于定义一个上下文管理器,允许使用 with 语句来管理资源。
@contextmanager
# 1.local_rank :这是一个参数,表示当前进程在本地节点中的排名。
def torch_distributed_zero_first(local_rank: int):
# 装饰器使分布式训练中的所有进程等待每个 local_master 执行某项操作。
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
# 这个条件检查当前进程的 local_rank 是否不是 -1 也不是 0 。 local_rank 为 -1 通常表示非分布式训练环境,而 0 表示分布式训练中的第一个进程(主进程)。
if local_rank not in [-1, 0]:
# 如果当前进程不是主进程(即 local_rank 不是 0 ),那么它会在这里等待。 torch.distributed.barrier() 是一个同步操作,它会阻塞当前进程,直到所有进程都到达这个屏障点。
torch.distributed.barrier()
# 这是一个生成器的 yield 语句,它标志着上下文管理器的开始。当 with 块的代码执行到这里时,主进程会继续执行 yield 之后的代码,而非主进程则在 torch.distributed.barrier() 处等待。
yield
# 在 yield 之后,只有 local_rank 为 0 的主进程会继续执行。
if local_rank == 0:
# 主进程在执行完 yield 之后的代码后,会再次调用 torch.distributed.barrier() 。这确保了在主进程完成特定操作后,所有进程都会在这里同步,然后继续执行后续的代码。
torch.distributed.barrier()
# 这种同步机制确保了在分布式训练中,所有进程在执行某些关键操作之前都会等待主进程完成。这可以避免数据加载和处理的冲突,确保数据的一致性,特别是在需要主进程首先执行某些初始化或配置操作时。
3.def init_torch_seeds(seed=0):
# 这段代码定义了一个名为 init_torch_seeds 的函数,它用于初始化 PyTorch 的随机数生成器种子,以确保代码的随机性是可重复的。这个函数接受一个参数。
# 1.seed :默认值为0。
def init_torch_seeds(seed=0):
# torch.manual_seed(seed)
# torch.manual_seed() 是 PyTorch 中用于设置随机数生成器种子的函数。这个函数确保了 PyTorch 操作产生的随机数序列是可重复的,即在相同的种子下,每次运行程序时产生的随机数序列都是相同的。
# 参数 :
# seed:种子值,一个非负整数。如果为0,则 PyTorch 会使用一个随机种子。
# 作用 :
# 当你提供一个特定的种子值时, torch.manual_seed(seed) 会重置 PyTorch 的随机数生成器的状态,使得随后的随机数生成(如初始化权重、打乱数据等操作)可以预测。
# 如果不提供种子值,或者种子值为0,则 PyTorch 将使用一个随机种子,通常是根据当前时间生成的,这使得每次程序运行时产生的随机数序列都是不同的。
# 请注意, torch.manual_seed() 只影响 PyTorch 的 CPU 随机数生成器。如果你在使用 CUDA(GPU)时也需要确保随机性是可重复的,你可能还需要调用 torch.cuda.manual_seed() 或 torch.cuda.manual_seed_all() 来设置 GPU 上的随机数种子。
# 速度与可重复性的权衡 https://pytorch.org/docs/stable/notes/randomness.html。
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
# 这行代码设置了 PyTorch 的随机数种子。当提供相同的种子值时,PyTorch 将生成相同的随机数序列。
torch.manual_seed(seed)
# 这个条件判断用于决定是否使用更慢但更可重复的设置,或者是更快但可重复性较低的设置。
if seed == 0: # slower, more reproducible 速度更慢,可重复性更高。
# cudnn.benchmark = False :当设置为 False 时,PyTorch 将不使用 CUDA 的自动调优功能,这意味着每次运行代码时,卷积操作将使用相同的算法,这有助于确保结果的可重复性,但可能会牺牲一些性能。
# cudnn.deterministic = True :当设置为 True 时,PyTorch 将确保每次运行代码时,卷积操作的结果都是确定的。这也有助于确保结果的可重复性,但同样可能会牺牲一些性能。
cudnn.benchmark, cudnn.deterministic = False, True
# 如果种子值不为0,则使用更快但可重复性较低的设置。
else: # faster, less reproducible 速度更快,可重复性更差。
# cudnn.benchmark = True :当设置为 True 时,PyTorch 将使用 CUDA 的自动调优功能,这意味着每次运行代码时,卷积操作可能会使用不同的算法,这有助于提高性能,但可能会降低结果的可重复性。
# cudnn.deterministic = False :当设置为 False 时,PyTorch 将不保证每次运行代码时,卷积操作的结果都是确定的,这也有助于提高性能,但可能会降低结果的可重复性。
cudnn.benchmark, cudnn.deterministic = True, False
# 这个函数的设计考虑了速度和可重复性之间的权衡。如果你需要确保结果的可重复性,可以将种子值设置为0;如果你更关心性能,可以设置一个非零的种子值。
4.def date_modified(path=__file__):
# 定义了一个名为 date_modified 的函数,它接受一个参数 1.path ,默认值为 __file__ ,即当前文件的路径。函数的作用是返回文件的修改日期,格式为 'YYYY-MM-DD' 。
def date_modified(path=__file__):
# 返回人类可读的文件修改日期,即“2021-3-26”。
# return human-readable file modification date, i.e. '2021-3-26'
# Path(path)
# 在Python中, Path 是 pathlib 模块中的一个类,用于表示文件系统路径。 pathlib 是一个现代的文件路径操作库,它提供了面向对象的方式来处理文件和目录路径。
# 导入 Path 类 : from pathlib import Path。
# 可以使用 Path 类来创建一个路径对象。这个对象可以是一个文件或者目录的路径。
# 路径操作 :
# Path 对象提供了许多方法来操作路径。
# p.exists() :检查路径是否存在。
# p.is_file() :检查路径是否指向一个文件。
# p.is_dir() :检查路径是否指向一个目录。
# p.resolve() :解析路径,返回绝对路径。
# p.parent :返回路径的父目录。
# p.name :返回路径的最后一部分(文件名)。
# p.suffix :返回文件的后缀名。
# p.stem :返回文件名不包括后缀的部分。
# 这行代码做了几件事情 :
# Path(path) 使用了 pathlib 模块的 Path 类来表示文件路径。
# .stat() 获取了文件的状态信息。
# .st_mtime 是状态信息中的一个属性,表示文件最后修改的时间戳。
# datetime.datetime.fromtimestamp() 是 datetime 模块的一个函数,它将时间戳转换为 datetime 对象。
t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
# 这行代码返回一个格式化的字符串,包含了 datetime 对象的年、月、日。
return f'{t.year}-{t.month}-{t.day}'
# 这个函数可以被用来获取任何文件的最后修改日期,并以人类可读的格式返回。
5.def git_describe(path=Path(__file__).parent):
# 定义了一个函数 git_describe ,它有一个参数 1.path ,默认值是 __file__ 的父目录,这里 __file__ 是当前文件的路径, .parent 表示父目录。
def git_describe(path=Path(__file__).parent): # path must be a directory
# 返回人类可读的 git 描述,即 v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe。
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
# 这行代码构建了一个字符串 s ,它是一个命令行命令,用于执行 git describe 。 -C 参数用于指定Git命令的目录, --tags 表示使用标签作为参考点, --long 表示返回完整的描述, --always 表示即使没有标签也返回一个描述。
s = f'git -C {path} describe --tags --long --always'
# 开始了一个 try 块,用于捕获可能发生的异常。
try:
# 这行代码使用 subprocess.check_output 执行构建的命令 s , shell=True 允许直接执行命令行字符串, stderr=subprocess.STDOUT 将标准错误重定向到标准输出, .decode() 将输出从字节串解码为字符串, [:-1] 用于去除字符串末尾的换行符。
return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
# 捕获了 subprocess.CalledProcessError 异常,这通常发生在命令执行失败时。
except subprocess.CalledProcessError as e:
# 如果捕获到异常,函数返回一个空字符串,表示当前目录不是一个Git仓库。
return '' # not a git repository 不是 git 存储库。
# 这段代码假设你正在运行的目录或其父目录是一个Git仓库。如果不是,函数将返回一个空字符串。
6.def select_device(device='', batch_size=None):
# 这段代码定义了一个名为 select_device 的函数,它用于选择并配置 PyTorch 模型将使用的计算设备,可以是 CPU 或者一个或多个 GPU。
# 1.device :一个字符串,指定使用的设备,可以是 'cpu' 或者 GPU 编号(如 '0' 或 '0,1,2,3' )。
# 2.batch_size :一个整数,指定批量大小,用于检查是否与 GPU 数量兼容。
def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
# 创建一个字符串 s ,包含模型名称、版本信息和 PyTorch 版本。
s = f'YOLOR
标签:YOLOv7,torch,0.1,模型,py,PyTorch,参数,模块,model
From: https://blog.csdn.net/m0_58169876/article/details/143826412