首页 > 其他分享 >YOLOV5 train模块写,没验证对错,能跑

YOLOv5 train 模块的实现(尚未验证正确性,但可以跑通)

时间:2022-11-19 18:56:10浏览次数:44  
标签:YOLOV5 num self torch anchor 对错 train targets select

import torch
import torch.nn as nn
import torch.optim as optim

import dataset
import models
import nn_utils


class YoloHead(nn.Module):
    """YOLOv5-style detection loss: GIoU box loss + objectness BCE + class BCE.

    Target assignment per output level:
      1. keep every (anchor, target) pair whose worst width/height ratio against
         the anchor is below ``anchor_t``;
      2. additionally assign each kept target to the horizontally and vertically
         adjacent grid cells that its centre is closest to;
      3. gather the matching predictions and compute GIoU against the targets;
      4. compute the box / objectness / classification losses;
      5. combine them with fixed weights.

    Args:
        num_classes: number of object classes predicted per anchor.
    """

    def __init__(self, num_classes):
        super(YoloHead, self).__init__()

        self.strides = [8, 16, 32]
        # Anchors are given in input-image pixels, but predictions live on each
        # level's feature map, so divide every level's anchors by its stride.
        # Shape: [3 levels, 3 anchors per level, (w, h)].
        self.anchors = torch.tensor([
            [10, 13, 16, 30, 33, 23],  # P3/8
            [30, 61, 62, 45, 59, 119],  # P4/16
            [116, 90, 156, 198, 373, 326]  # P5/32
        ]).view(3, 3, 2) / torch.FloatTensor(self.strides).view(3, 1, 1)
        # Neighbour-cell offsets used by the sample expansion in forward(); after
        # scaling by 0.5 and being subtracted from the centre they select the
        # left, top, right and bottom neighbour respectively.
        self.offset_boundary = torch.FloatTensor([
            [+1, 0],
            [0, +1],
            [-1, 0],
            [0, -1]
        ])
        self.num_anchor_per_level = self.anchors.size(1)
        self.num_classes = num_classes
        self.anchor_t = 4.0
        self.BCEClassification = nn.BCEWithLogitsLoss(reduction="mean")
        self.BCEObjectness = nn.BCEWithLogitsLoss(reduction="mean")
        self.balance = [4.0, 1.0, 0.4]  # per-level objectness weights for strides 8, 16, 32
        self.box_weight = 0.05
        self.objectness_weight = 1.0
        self.classification_weight = 0.5 * self.num_classes / 80  # 80 = number of COCO classes

    def to(self, device):
        """Move the module and its plain-tensor constants (anchors, offsets) to ``device``."""
        self.anchors = self.anchors.to(device)
        self.offset_boundary = self.offset_boundary.to(device)
        return super().to(device)

    def giou(self, a, b):
        """Compute the element-wise GIoU of two aligned sets of boxes.

        Args:
            a: ``[N, 4]`` tensor of boxes as ``[cx, cy, width, height]``.
            b: ``[N, 4]`` tensor of boxes as ``[cx, cy, width, height]``.

        Returns:
            ``[N]`` tensor holding the GIoU of ``a[i]`` and ``b[i]``.

        Box edges are taken as ``cx -/+ width / 2``. The alternative convention
        ``cx - (width - 1) / 2`` can give slightly different numbers; use the
        official formulation when reproducing reference results exactly.
        """
        eps = 1e-16

        a_xmin, a_xmax = a[:, 0] - a[:, 2] / 2, a[:, 0] + a[:, 2] / 2
        a_ymin, a_ymax = a[:, 1] - a[:, 3] / 2, a[:, 1] + a[:, 3] / 2
        b_xmin, b_xmax = b[:, 0] - b[:, 2] / 2, b[:, 0] + b[:, 2] / 2
        b_ymin, b_ymax = b[:, 1] - b[:, 3] / 2, b[:, 1] + b[:, 3] / 2

        inter_xmin = torch.max(a_xmin, b_xmin)
        inter_xmax = torch.min(a_xmax, b_xmax)
        inter_ymin = torch.max(a_ymin, b_ymin)
        inter_ymax = torch.min(a_ymax, b_ymax)
        inter_width = (inter_xmax - inter_xmin).clamp(0)
        inter_height = (inter_ymax - inter_ymin).clamp(0)
        inter_area = inter_width * inter_height

        a_width, a_height = (a_xmax - a_xmin), (a_ymax - a_ymin)
        b_width, b_height = (b_xmax - b_xmin), (b_ymax - b_ymin)
        # eps keeps the IoU finite when both boxes have zero area (0 / 0 otherwise).
        union = (a_width * a_height + eps) + (b_width * b_height) - inter_area
        iou = inter_area / union

        # Smallest enclosing ("convex") box.
        convex_width = torch.max(a_xmax, b_xmax) - torch.min(a_xmin, b_xmin)
        convex_height = torch.max(a_ymax, b_ymax) - torch.min(a_ymin, b_ymin)
        convex_area = convex_width * convex_height + eps
        return iou - (convex_area - union) / convex_area

    def forward(self, predict, targets):
        """Compute the total detection loss.

        Args:
            predict: sequence with one raw output per level; each is
                ``[batch, 3 * (5 + num_classes), H, W]`` and is reshaped below to
                ``[batch, 3, H, W, 5 + num_classes]``.
            targets: ``[num_target, 6]`` tensor of
                ``[image_index, class_index, cx, cy, width, height]`` with the box
                coordinates normalised to the image size.

        Returns:
            1-element tensor: the weighted sum of box, objectness and
            classification losses multiplied by the batch size.
        """
        num_target = targets.size(0)
        device = targets.device
        loss_box_regression = torch.zeros(1, device=device)
        loss_classification = torch.zeros(1, device=device)
        loss_objectness = torch.zeros(1, device=device)

        # Follow the targets' device even if the head was moved with .cuda()/.cpu(),
        # which bypass the overridden to(); this is a no-op when already there.
        offset_boundary = self.offset_boundary.to(device)

        for ilayer, layer in enumerate(predict):
            # Tensor.to returns a new tensor; rebind so the move takes effect.
            layer = layer.to(device)
            layer_height, layer_width = layer.shape[-2:]
            # [b, 3*(5+C), H, W] -> [b, 3, H, W, 5+C]; the last axis is
            # [x, y, w, h, objectness, class logits...].
            layer = layer.view(-1, self.num_anchor_per_level, 5 + self.num_classes, layer_height, layer_width)
            layer = layer.permute(0, 1, 3, 4, 2).contiguous()

            # Targets are normalised, so scale cx/cy/w/h to this level's grid units.
            feature_size_gain = targets.new_tensor([1, 1, layer_width, layer_height, layer_width, layer_height])
            targets_feature_scale = targets * feature_size_gain  # [num_target, 6]

            # This level's anchors in grid units: [num_anchor, 2].
            anchors = self.anchors[ilayer].to(device)
            num_anchor = anchors.size(0)

            # Step 1: shape matching. A pair is kept when the largest of
            # w/w_a, w_a/w, h/h_a, h_a/h is below anchor_t.
            anchor_wh = anchors.view(num_anchor, 1, 2)
            targets_wh = targets_feature_scale[:, [4, 5]].view(1, num_target, 2)
            wh_ratio = targets_wh / anchor_wh  # [num_anchor, num_target, 2]
            max_wh_ratio, _ = torch.max(wh_ratio, 1 / wh_ratio).max(dim=2)
            select_mask = max_wh_ratio < self.anchor_t  # [num_anchor, num_target]
            # [num_select, 6] rows of [image_id, class_id, cx, cy, width, height] in grid units.
            select_targets = targets_feature_scale.repeat(num_anchor, 1, 1)[select_mask]
            num_select_target = len(select_targets)

            featuremap_object = layer[..., 4]  # objectness logits [b, 3, H, W]
            objectness_ground_true = torch.zeros_like(featuremap_object)

            if num_select_target > 0:
                # Anchor index of every selected pair, built on `device` so the
                # mask indexing and the gather from `layer` stay on one device.
                select_anchor_index = torch.arange(num_anchor, device=device).view(num_anchor, 1).repeat(
                    1, num_target)[select_mask]

                # Step 2: neighbour-cell expansion. A centre in the left half of
                # its cell (and not at the left border) is also assigned to the
                # left neighbour; likewise for top / right / bottom. Predicted xy
                # spans (-0.5, 1.5), so a neighbour can still reach the centre.
                select_targets_xy = select_targets[:, [2, 3]]
                xy_fraction = select_targets_xy % 1.0

                coord_cell_middle = 0.5
                feature_map_low_boundary = 1.0
                feature_map_high_boundary = feature_size_gain[[2, 3]] - 1.0

                less_x_matched, less_y_matched = (
                    (xy_fraction < coord_cell_middle) & (select_targets_xy > feature_map_low_boundary)).T
                greater_x_matched, greater_y_matched = (
                    (xy_fraction > (1 - coord_cell_middle)) & (select_targets_xy < feature_map_high_boundary)).T
                # Order must match the rows of offset_boundary: left, top, right, bottom.
                neighbour_masks = (less_x_matched, less_y_matched, greater_x_matched, greater_y_matched)

                select_anchor_index = torch.cat(
                    [select_anchor_index] + [select_anchor_index[mask] for mask in neighbour_masks], dim=0)
                select_targets = torch.cat(
                    [select_targets] + [select_targets[mask] for mask in neighbour_masks], dim=0)
                zero_offset = torch.zeros_like(select_targets_xy)
                xy_offset = torch.cat(
                    [zero_offset] + [zero_offset[mask] + offset_boundary[k] for k, mask in enumerate(neighbour_masks)],
                    dim=0) * coord_cell_middle

                matched_extend_num_target = len(select_targets)
                gt_image_id, gt_classes_id = select_targets[:, [0, 1]].long().T
                gt_xy = select_targets[:, [2, 3]]
                gt_wh = select_targets[:, [4, 5]]
                grid_xy = (gt_xy - xy_offset).long()
                # Clamp into the grid: a centre exactly on the right/bottom border
                # (normalised cx or cy == 1.0) would otherwise address cell W or H
                # and make the gather below index out of range.
                grid_x = grid_xy[:, 0].clamp(0, layer_width - 1)
                grid_y = grid_xy[:, 1].clamp(0, layer_height - 1)
                grid_xy = torch.stack((grid_x, grid_y), dim=1)

                # xy regression target: centre offset relative to the assigned cell.
                gt_xy = gt_xy - grid_xy

                select_anchors = anchors[select_anchor_index]

                # Step 3: gather the matched predictions -> [num_matched_extend, 5 + C],
                # indexed by (image, anchor, row = grid_y, column = grid_x).
                object_predict = layer[gt_image_id, select_anchor_index, grid_y, grid_x]

                # Decode: xy in (-0.5, 1.5), wh in (0, 4) * anchor.
                object_predict_xy = object_predict[:, [0, 1]].sigmoid() * 2.0 - 0.5
                object_predict_wh = torch.pow(object_predict[:, [2, 3]].sigmoid() * 2.0, 2.0) * select_anchors

                # Both boxes as N x 4 [cx, cy, width, height].
                object_predict_box = torch.cat((object_predict_xy, object_predict_wh), dim=1)
                object_ground_truth_box = torch.cat((gt_xy, gt_wh), dim=1)

                # Step 4: losses.
                gious = self.giou(object_predict_box, object_ground_truth_box)
                giou_loss = 1.0 - gious
                loss_box_regression += giou_loss.mean()

                # Objectness target at matched cells is the detached, non-negative GIoU.
                objectness_ground_true[gt_image_id, select_anchor_index, grid_y, grid_x] = gious.detach().clamp(0)
                if self.num_classes > 1:
                    # Multi-class via independent binary (one-vs-all) BCE on a one-hot target.
                    object_classification = object_predict[:, 5:]
                    classification_targets = torch.zeros_like(object_classification)
                    classification_targets[torch.arange(matched_extend_num_target, device=device), gt_classes_id] = 1.0
                    loss_classification += self.BCEClassification(object_classification, classification_targets)
            loss_objectness += self.BCEObjectness(featuremap_object, objectness_ground_true) * self.balance[ilayer]

        # Step 5: weighted sum.
        num_level = len(predict)
        scale = 3 / num_level

        batch_size = predict[0].shape[0]
        loss_box_regression *= self.box_weight * scale
        # With num_level == 4 objectness would need an extra factor of 1.4 (1.0
        # otherwise); only 3 anchor levels are defined in this head.
        loss_objectness *= self.objectness_weight * scale
        loss_classification *= self.classification_weight * scale

        loss = loss_box_regression + loss_objectness + loss_classification
        return loss * batch_size


def train():
    """Smoke-test training on one batch of VOC2007 data.

    Builds the VOC training loader, the YOLO model and the loss head. It runs a
    single forward / backward / SGD step on the first batch, prints the loss
    and stops. The dataset and model-config paths are hard-coded for the
    author's machine.
    """
    # Raw string: "\V" is not a valid escape sequence (DeprecationWarning /
    # SyntaxWarning on newer Pythons); the path value itself is unchanged.
    train_set = dataset.VOCDataSet(True, 640, r"E:\VOC2007\VOCdevkit\VOC2007")
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=2,
                                               num_workers=0,
                                               shuffle=True,
                                               pin_memory=(device == "cuda"),
                                               collate_fn=train_set.collate_fn)
    head = YoloHead(train_set.num_classes).to(device)
    model = models.Yolo(train_set.num_classes, "E:/杜老师课程/100_du/02.22/yolov5-2.0/models/yolov5s.yaml").to(device)
    optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

    for batch_index, (images, targets, visuals) in enumerate(train_loader):
        images = images.to(device)
        targets = targets.to(device)
        predict = model(images)
        loss = head(predict, targets)

        # Exercise the backward pass and one optimiser step, so the smoke test
        # also checks that gradients flow through the loss.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(loss)
        break


if __name__ == "__main__":
    # Seed through the project helper for repeatable runs
    # (presumably seeds Python/NumPy/torch RNGs; confirm in nn_utils).
    nn_utils.setup_seed(3)
    # Run the single-batch training smoke test.
    train()
View Code

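其中还不太理解的是"拓展样本"这一步。

补充说明:拓展样本就是除了目标中心所在的网格之外,再把离中心最近的相邻网格也当作正样本。若中心落在网格的左半部分(且不靠近左边界),左侧相邻网格也负责预测该目标;上、右、下三个方向同理。因此每个目标最多分配到 3 个网格:中心网格,加上一个水平邻格和一个竖直邻格。由于预测的 xy 偏移值域为 (-0.5, 1.5),相邻网格同样能回归到目标中心,这样就增加了正样本的数量。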

标签:YOLOV5,num,self,torch,anchor,对错,train,targets,select
From: https://www.cnblogs.com/xiaoruirui/p/16906747.html

相关文章