一、本文介绍
本文记录的是利用DySample上采样对YOLOv9的颈部网络进行改进的方法研究。YOLOv9
采用传统的最近邻插值的方法进行上采样可能无法有效地捕捉特征的细节和语义信息,从而影响模型在密集预测任务中的性能。DySample
通过动态采样的方式进行上采样,能够更好地处理特征的细节和语义信息。
文章目录
二、DySample介绍
通过学习采样来学习上采样
DySample
是一种超轻量且有效的动态上采样器,其设计出发点、原理和优势如下:
2.1、DySample原理
- 初步设计:通过PyTorch的内置函数,假设输入特征通过双线性插值被插值为连续的特征图,然后通过生成内容感知的采样点来重新采样该连续图。具体实现为,给定特征图 X X X和上采样尺度因子 s s s,使用线性层生成偏移 O O O,并通过Pixel Shuffling将其重塑为 2 × s H × s W 2 \times sH \times sW 2×sH×sW,然后将偏移 O O O与原始采样网格 G G G相加得到采样集 S S S,最后通过网格采样函数根据采样集生成上采样后的特征图 X ′ X' X′。
- 改进步骤:
- 初始采样位置:在初步版本中,初始采样位置固定且分布不均匀,类似于“最近邻初始化”。为解决此问题,改为“双线性初始化”,即改变初始位置,使零偏移时能得到双线性插值的特征图,从而提高性能。
- 偏移范围:由于归一化层的存在,输出特征值的范围通常在 [ − 1 , 1 ] [ - 1, 1] [−1,1],导致局部采样位置的偏移范围可能重叠,影响边界预测并导致输出伪影。通过将偏移乘以0.25的“静态范围因子”,局部约束了采样位置的偏移范围,缓解了该问题。
- 分组:组向上采样,将特征图沿通道维度划分为 g g g组,并为每组生成偏移。当 g = 4 g = 4 g=4时,性能得到提升。
- 动态范围因子:为增加偏移的灵活性,通过线性投影输入特征生成点级的“动态范围因子”,动态范围因子的值在 [ 0 , 0.5 ] [0, 0.5] [0,0.5]范围内,以0.25为中心,进一步提升了性能。
- 偏移生成方式:研究了两种偏移生成方式,“线性 + 像素洗牌”(LP)和“像素洗牌 + 线性”(PL)。通过实验,根据不同模型设置了不同的组数量,并且发现PL版本在某些模型上表现更好,但在其他模型上略逊于LP版本。
- 最终变体:根据范围因子(静态/动态)和偏移生成方式(LP/PL),研究了四个变体:DySample(LP风格,静态范围因子)、DySample +(LP风格,动态范围因子)、DySample - S(PL风格,静态范围因子)、DySample - S +(PL风格,动态范围因子)。
2.2、优势
- 轻量高效:与其他动态上采样器相比,
DySample
不需要高分辨率的引导特征作为输入,也不需要除PyTorch之外的任何额外CUDA包,具有更少的推理延迟、内存占用、FLOPs和参数数量。 - 性能优越:在五个密集预测任务(语义分割、目标检测、实例分割、全景分割和单目深度估计)中,与其他上采样器相比,
DySample
报告了更好的性能。
论文:https://arxiv.org/pdf/2308.15085
源码:https://github.com/tiny-smart/dysample
三、DySample的实现代码
DySample模块
的实现代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
def normal_init(module, mean=0, std=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
class DySample(nn.Module):
def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
super().__init__()
self.scale = scale
self.style = style
self.groups = groups
assert style in ['lp', 'pl']
if style == 'pl':
assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
assert in_channels >= groups and in_channels % groups == 0
if style == 'pl':
in_channels = in_channels // scale ** 2
out_channels = 2 * groups
else:
out_channels = 2 * groups * scale ** 2
self.offset = nn.Conv2d(in_channels, out_channels, 1)
normal_init(self.offset, std=0.001)
if dyscope:
self.scope = nn.Conv2d(in_channels, out_channels, 1)
constant_init(self.scope, val=0.)
self.register_buffer('init_pos', self._init_pos())
def _init_pos(self):
h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)
def sample(self, x, offset):
B, _, H, W = offset.shape
offset = offset.view(B, 2, -1, H, W)
coords_h = torch.arange(H) + 0.5
coords_w = torch.arange(W) + 0.5
coords = torch.stack(torch.meshgrid([coords_w, coords_h])
).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
coords = 2 * (coords + offset) / normalizer - 1
coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)
def forward_lp(self, x):
if hasattr(self, 'scope'):
offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
else:
offset = self.offset(x) * 0.25 + self.init_pos
return self.sample(x, offset)
def forward_pl(self, x):
x_ = F.pixel_shuffle(x, self.scale)
if hasattr(self, 'scope'):
offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
else:
offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
return self.sample(x, offset)
def forward(self, x):
if self.style == 'pl':
return self.forward_pl(x)
return self.forward_lp(x)
四、添加步骤
4.1 修改common.py
此处需要修改的文件是models/common.py
common.py中定义了网络结构的通用模块
,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。
DySample模块
添加后如下:
注意❗:在4.2小节
中的yolo.py
文件中需要声明的模块名称为:DySample
。
4.2 修改yolo.py
此处需要修改的文件是models/yolo.py
yolo.py用于函数调用
,我们只需要将common.py
中定义的新的模块名添加到parse_model函数
下即可。
在def parse_model(d, ch)
中将DySample模块
添加后如下:
elif m in [DySample]:
args = [ch[f], *args[0:]]
五、yaml模型文件
5.1 模型改进
在代码配置完成后,配置模型的YAML文件。
此处以models/detect/yolov9-c.yaml
为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-dysamlpe.yaml
。
将yolov9-c.yaml
中的内容复制到yolov9-c-dysamlpe.yaml
文件下,修改nc
数量等于自己数据中目标的数量。