
YOLOv5 Source Code Walkthrough: general.py and detect.py


This walkthrough is based on YOLOv5 v4.0; details may differ from earlier releases, but the overall structure is much the same.

general.py

  1 # YOLOv5 general utils
  2 
  3 import glob
  4 import logging
  5 import math
  6 import os
  7 import platform
  8 import random
  9 import re
 10 import subprocess
 11 import time
 12 from pathlib import Path
 13 
 14 import cv2
 15 import numpy as np
 16 import torch
 17 import torchvision
 18 import yaml
 19 
 20 from utils.google_utils import gsutil_getsize
 21 from utils.metrics import fitness
 22 from utils.torch_utils import init_torch_seeds
 23 
 24 # Settings
 25 torch.set_printoptions(linewidth=320, precision=5, profile='long')
 26 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
 27 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 28 os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
 29 
 30 
 31 def set_logging(rank=-1):
 32     logging.basicConfig(
 33         format="%(message)s",
 34         level=logging.INFO if rank in [-1, 0] else logging.WARN)
 35 
 36 
 37 def init_seeds(seed=0):
 38     # Initialize random number generator (RNG) seeds
 39     random.seed(seed)
 40     np.random.seed(seed)
 41     init_torch_seeds(seed)
 42 
 43 
 44 def get_latest_run(search_dir='.'):
 45     # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
 46     last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
 47     return max(last_list, key=os.path.getctime) if last_list else ''
 48 
 49 
 50 def isdocker():
 51     # Is environment a Docker container
 52     return Path('/workspace').exists()  # or Path('/.dockerenv').exists()
 53 
 54 
 55 def emojis(str=''):
 56     # Return platform-dependent emoji-safe version of string
 57     return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
 58 
 59 
 60 def check_online():
 61     # Check internet connectivity
 62     import socket
 63     try:
 64         socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
 65         return True
 66     except OSError:
 67         return False
 68 
 69 
 70 def check_git_status():
 71     # Recommend 'git pull' if code is out of date
 72     print(colorstr('github: '), end='')
 73     try:
 74         assert Path('.git').exists(), 'skipping check (not a git repository)'
 75         assert not isdocker(), 'skipping check (Docker image)'
 76         assert check_online(), 'skipping check (offline)'
 77 
 78         cmd = 'git fetch && git config --get remote.origin.url'
 79         url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git')  # github repo url
 80         branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
 81         n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
 82         if n > 0:
 83             s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
 84                 f"Use 'git pull' to update or 'git clone {url}' to download latest."
 85         else:
 86             s = f'up to date with {url} ✅'
 87         print(emojis(s))  # emoji-safe
 88     except Exception as e:
 89         print(e)
 90 
 91 
 92 def check_requirements(file='requirements.txt', exclude=()):
 93     # Check installed dependencies meet requirements
 94     import pkg_resources as pkg
 95     prefix = colorstr('red', 'bold', 'requirements:')
 96     file = Path(file)
 97     if not file.exists():
 98         print(f"{prefix} {file.resolve()} not found, check failed.")
 99         return
100 
101     n = 0  # number of package updates
102     requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
103     for r in requirements:
104         try:
105             pkg.require(r)
106         except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
107             n += 1
108             print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
109             print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())
110 
111     if n:  # if packages updated
112         s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file.resolve()}\n" \
113             f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
114         print(emojis(s))  # emoji-safe
115 
116 
117 def check_img_size(img_size, s=32):
118     # Verify img_size is a multiple of stride s
119     new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
120     if new_size != img_size:
121         print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
122     return new_size
123 
124 
125 def check_imshow():
126     # Check if environment supports image displays
127     try:
128         assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
129         cv2.imshow('test', np.zeros((1, 1, 3)))
130         cv2.waitKey(1)
131         cv2.destroyAllWindows()
132         cv2.waitKey(1)
133         return True
134     except Exception as e:
135         print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
136         return False
137 
138 
139 def check_file(file):
140     # Search for file if not found
141     if os.path.isfile(file) or file == '':
142         return file
143     else:
144         files = glob.glob('./**/' + file, recursive=True)  # find file
145         assert len(files), 'File Not Found: %s' % file  # assert file was found
146         assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files)  # assert unique
147         return files[0]  # return file
148 
149 
150 def check_dataset(dict):
151     # Download dataset if not found locally
152     val, s = dict.get('val'), dict.get('download')
153     if val and len(val):
154         val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
155         if not all(x.exists() for x in val):
156             print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
157             if s and len(s):  # download script
158                 print('Downloading %s ...' % s)
159                 if s.startswith('http') and s.endswith('.zip'):  # URL
160                     f = Path(s).name  # filename
161                     torch.hub.download_url_to_file(s, f)
162                     r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
163                 else:  # bash script
164                     r = os.system(s)
165                 print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
166             else:
167                 raise Exception('Dataset not found.')
168 
169 
170 def make_divisible(x, divisor):
171     # Returns x evenly divisible by divisor
172     return math.ceil(x / divisor) * divisor
173 
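make_divisible rounds x up to the nearest multiple of divisor, which is what check_img_size above relies on to snap --img-size to the model stride. A quick sketch of the behavior (values chosen for illustration):

    import math

    def make_divisible(x, divisor):
        return math.ceil(x / divisor) * divisor

    print(make_divisible(640, 32))  # 640: already a multiple of 32
    print(make_divisible(641, 32))  # 672: rounded up to the next multiple
    # so check_img_size(641, s=32) would warn and return 672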
174 
175 def clean_str(s):
176     # Cleans a string by replacing special characters with underscore _
177     return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
178 
179 
180 def one_cycle(y1=0.0, y2=1.0, steps=100):
181     # lambda function for sinusoidal ramp from y1 to y2
182     return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
183 
184 
185 def colorstr(*input):
186     # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')
187     *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
188     colors = {'black': '\033[30m',  # basic colors
189               'red': '\033[31m',
190               'green': '\033[32m',
191               'yellow': '\033[33m',
192               'blue': '\033[34m',
193               'magenta': '\033[35m',
194               'cyan': '\033[36m',
195               'white': '\033[37m',
196               'bright_black': '\033[90m',  # bright colors
197               'bright_red': '\033[91m',
198               'bright_green': '\033[92m',
199               'bright_yellow': '\033[93m',
200               'bright_blue': '\033[94m',
201               'bright_magenta': '\033[95m',
202               'bright_cyan': '\033[96m',
203               'bright_white': '\033[97m',
204               'end': '\033[0m',  # misc
205               'bold': '\033[1m',
206               'underline': '\033[4m'}
207     return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
208 
209 
210 def labels_to_class_weights(labels, nc=80):
211     # Get class weights (inverse frequency) from training labels
212     if labels[0] is None:  # no labels loaded
213         return torch.Tensor()
214 
215     labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
216     classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
217     weights = np.bincount(classes, minlength=nc)  # occurrences per class
218 
219     # Prepend gridpoint count (for uCE training)
220     # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
221     # weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start
222 
223     weights[weights == 0] = 1  # replace empty bins with 1
224     weights = 1 / weights  # number of targets per class
225     weights /= weights.sum()  # normalize
226     return torch.from_numpy(weights)
227 
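To see what the inverse-frequency weighting does, here is a minimal sketch of the same logic with made-up labels (rows are [class, x, y, w, h]): class 0 appears three times and class 1 once, so class 1 ends up with three times the weight.

    import numpy as np
    import torch

    labels = [np.array([[0, .50, .50, .1, .1],
                        [0, .20, .20, .1, .1],
                        [0, .80, .80, .1, .1],
                        [1, .40, .40, .2, .2]])]  # one image, 4 boxes
    classes = np.concatenate(labels, 0)[:, 0].astype(int)
    weights = np.bincount(classes, minlength=2)  # occurrences: [3, 1]
    weights = 1 / np.maximum(weights, 1)         # inverse frequency (empty bins guarded)
    weights /= weights.sum()                     # normalized: [0.25, 0.75]
    print(torch.from_numpy(weights))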
228 
229 def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
230     # Produces image weights based on class_weights and image contents
231     class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
232     image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
233     # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
234     return image_weights
235 
236 
237 def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
238     # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
239     # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
240     # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
241     # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
242     # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
243     x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
244          35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
245          64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
246     return x
247 
248 
249 def xyxy2xywh(x):
250     # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
251     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
252     y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
253     y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
254     y[:, 2] = x[:, 2] - x[:, 0]  # width
255     y[:, 3] = x[:, 3] - x[:, 1]  # height
256     return y
257 
258 
259 def xywh2xyxy(x):
260     # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
261     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
262     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
263     y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
264     y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
265     y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
266     return y
267 
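The two conversions above are exact inverses; a quick round trip on an arbitrary box:

    import torch
    from utils.general import xyxy2xywh, xywh2xyxy

    box = torch.tensor([[100., 150., 300., 350.]])  # x1, y1, x2, y2
    xywh = xyxy2xywh(box)   # tensor([[200., 250., 200., 200.]]): center x/y, width, height
    back = xywh2xyxy(xywh)  # recovers tensor([[100., 150., 300., 350.]])
    print(xywh, back)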
268 
269 def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
270     # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
271     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
272     y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
273     y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
274     y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
275     y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
276     return y
277 
278 
279 def xyn2xy(x, w=640, h=640, padw=0, padh=0):
280     # Convert normalized segments into pixel segments, shape (n,2)
281     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
282     y[:, 0] = w * x[:, 0] + padw  # top left x
283     y[:, 1] = h * x[:, 1] + padh  # top left y
284     return y
285 
286 
287 def segment2box(segment, width=640, height=640):
288     # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
289     x, y = segment.T  # segment xy
290     inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
291     x, y = x[inside], y[inside]
292     return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy
293 
294 
295 def segments2boxes(segments):
296     # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
297     boxes = []
298     for s in segments:
299         x, y = s.T  # segment xy
300         boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
301     return xyxy2xywh(np.array(boxes))  # cls, xywh
302 
303 
304 def resample_segments(segments, n=1000):
305     # Up-sample an (n,2) segment
306     for i, s in enumerate(segments):
307         x = np.linspace(0, len(s) - 1, n)
308         xp = np.arange(len(s))
309         segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
310     return segments
311 
312 
313 def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
314     # Rescale coords (xyxy) from img1_shape to img0_shape
315     if ratio_pad is None:  # calculate from img0_shape
316         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
317         pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
318     else:
319         gain = ratio_pad[0][0]
320         pad = ratio_pad[1]
321 
322     coords[:, [0, 2]] -= pad[0]  # x padding
323     coords[:, [1, 3]] -= pad[1]  # y padding
324     coords[:, :4] /= gain
325     clip_coords(coords, img0_shape)
326     return coords
327 
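scale_coords undoes the letterbox transform: subtract the padding, divide by the resize gain, then clip. A worked example, assuming a 360x900 source image letterboxed into a 320x640 network input (gain = 640/900 ≈ 0.711, so the resized image is 256x640 with 32 px of padding top and bottom):

    import torch
    from utils.general import scale_coords

    img1_shape = (320, 640)  # letterboxed (h, w) fed to the network
    img0_shape = (360, 900)  # original image (h, w)
    coords = torch.tensor([[100., 82., 200., 182.]])  # xyxy in network space
    scale_coords(img1_shape, coords, img0_shape)      # modifies coords in place
    print(coords)  # tensor([[140.6250,  70.3125, 281.2500, 210.9375]])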
328 
329 def clip_coords(boxes, img_shape):
330     # Clip xyxy bounding boxes to image shape (height, width)
331     boxes[:, 0].clamp_(0, img_shape[1])  # x1
332     boxes[:, 1].clamp_(0, img_shape[0])  # y1
333     boxes[:, 2].clamp_(0, img_shape[1])  # x2
334     boxes[:, 3].clamp_(0, img_shape[0])  # y2
335 
336 
337 def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
338     # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
339     box2 = box2.T
340 
341     # Get the coordinates of bounding boxes
342     if x1y1x2y2:  # x1, y1, x2, y2 = box1
343         b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
344         b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
345     else:  # transform from xywh to xyxy
346         b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
347         b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
348         b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
349         b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
350 
351     # Intersection area
352     inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
353             (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
354 
355     # Union Area
356     w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
357     w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
358     union = w1 * h1 + w2 * h2 - inter + eps
359 
360     iou = inter / union
361     if GIoU or DIoU or CIoU:
362         cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
363         ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
364         if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
365             c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
366             rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
367                     (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
368             if DIoU:
369                 return iou - rho2 / c2  # DIoU
370             elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
371                 v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
372                 with torch.no_grad():
373                     alpha = v / (v - iou + (1 + eps))
374                 return iou - (rho2 / c2 + v * alpha)  # CIoU
375         else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
376             c_area = cw * ch + eps  # convex area
377             return iou - (c_area - union) / c_area  # GIoU
378     else:
379         return iou  # IoU
380 
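A quick check of bbox_iou on two hand-picked boxes: the intersection is 50 x 50 = 2500 and the union is 2 x 10000 - 2500 = 17500, so plain IoU is 2500/17500 ≈ 0.143; GIoU additionally penalizes the empty part of the enclosing box, so it comes out lower.

    import torch
    from utils.general import bbox_iou

    box1 = torch.tensor([0., 0., 100., 100.])
    box2 = torch.tensor([[50., 50., 150., 150.]])
    print(bbox_iou(box1, box2))             # IoU  ≈ 0.1429
    print(bbox_iou(box1, box2, GIoU=True))  # GIoU ≈ -0.0794 (enclosing box is 150 x 150)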
381 
382 def box_iou(box1, box2):
383     # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
384     """
385     Return intersection-over-union (Jaccard index) of boxes.
386     Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
387     Arguments:
388         box1 (Tensor[N, 4])
389         box2 (Tensor[M, 4])
390     Returns:
391         iou (Tensor[N, M]): the NxM matrix containing the pairwise
392             IoU values for every element in boxes1 and boxes2
393     """
394 
395     def box_area(box):
396         # box = 4xn
397         return (box[2] - box[0]) * (box[3] - box[1])
398 
399     area1 = box_area(box1.T)
400     area2 = box_area(box2.T)
401 
402     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
403     inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
404     return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
405 
406 
407 def wh_iou(wh1, wh2):
408     # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
409     wh1 = wh1[:, None]  # [N,1,2]
410     wh2 = wh2[None]  # [1,M,2]
411     inter = torch.min(wh1, wh2).prod(2)  # [N,M]
412     return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)
413 
414 
415 def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
416                         labels=()):
417     """Runs Non-Maximum Suppression (NMS) on inference results
418 
419     Returns:
420          list of detections, on (n,6) tensor per image [xyxy, conf, cls]
421     """
422     # prediction.shape
423     # torch.Size([1, 10080, 7])
424     nc = prediction.shape[2] - 5  # number of classes. Why minus 5? The first 5 channels are the 4 box coordinates (xywh) plus one objectness score; the confidence formula is explained below
425     # prediction[..., 4]: the objectness score of every candidate box (index 4 of the last dimension)
426     # xc is a boolean mask, True wherever the objectness score exceeds conf_thres
427     # xc.shape
428     # torch.Size([1, 10080])
429     xc = prediction[..., 4] > conf_thres  # candidates
430 
431     # Settings
432     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
433     max_det = 300  # maximum number of detections per image
434     max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
435     time_limit = 10.0  # seconds to quit after
436     redundant = True  # require redundant detections
437     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
438     merge = False  # use merge-NMS
439 
440     t = time.time()
441     # output: one (n, 6) tensor per image, initialized empty
442     # columns are [x1, y1, x2, y2, confidence, class]
443     output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
444     for xi, x in enumerate(prediction):  # image index, image inference
445         # Apply constraints
446         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
447         # x.shape
448         # torch.Size([2387, 7])
449         x = x[xc[xi]]  # keep only the candidates whose objectness score passed conf_thres
450 
451         # Cat apriori labels if autolabelling
452         if labels and len(labels[xi]):
453             l = labels[xi]
454             v = torch.zeros((len(l), nc + 5), device=x.device)
455             v[:, :4] = l[:, 1:5]  # box
456             v[:, 4] = 1.0  # conf
457             v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
458             x = torch.cat((x, v), 0)
459 
460         # If none remain process next image
461         if not x.shape[0]:
462             continue
463         # broadcasted row-wise multiplication below
464         # x.shape
465         # torch.Size([2387, 7]): 2387 candidate boxes survived the objectness filter
466 
467         # Compute conf:
468         # conf = cls_conf * obj_conf, i.e. Pr(Class_i) = Pr(Class_i|Object) * Pr(Object); note this differs from the confidence formula in YOLOv1/v2/v3
469         # where obj_conf = probability that the bounding box contains an object,
470         # and cls_conf = probability that, given an object, it belongs to class i (i = 1..N for N classes)
471 
472         # columns 5, 6, ... are each multiplied by column 4
473         # columns 5, 6, ... hold Pr(Class_i|Object), the per-class conditional probabilities (two classes in this example)
474         # column 4 holds Pr(Object), the probability that the candidate box contains an object
475         # columns 0-3 hold the box geometry
476         # so each row of x reads: [x y w h Pr(Object) Pr(Class1|Object) Pr(Class2|Object)] (a small numeric sketch follows this function)
477         x[:, 5:] *= x[:, 4:5]
478 
479         # pull out the predicted box geometry
480         # Box (center x, center y, width, height) to (x1, y1, x2, y2)
481         # box.shape
482         # torch.Size([2387, 4])
483         box = xywh2xyxy(x[:, :4])
484 
485         # Detections matrix nx6 (xyxy, conf, cls)
486         if multi_label:
487             i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
488             x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
489         else:  # best class only
490             # compare the class-confidence columns of x (columns 5, 6, ...); note: these were already multiplied by the objectness above, so they are no longer conditional probabilities
491             # conf.shape
492             # torch.Size([2387, 1])
493             # j.shape
494             # torch.Size([2387, 1])
495             # each element of j is the index of the larger value, i.e. the winning class (0 or 1 here)
496             conf, j = x[:, 5:].max(1, keepdim=True)
497             # recombine x as (box, conf, j) and drop rows below conf_thres; just before this line x still has its 7 original columns:
498             # x.shape
499             # torch.Size([2387, 7])
500             x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
501             # x.shape
502             # torch.Size([2387, 6])
503 
504         # Filter by class
505         if classes is not None:
506             x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
507 
508         # Apply finite constraint
509         # if not torch.isfinite(x).all():
510         #     x = x[torch.isfinite(x).all(1)]
511 
512         # Check shape
513         n = x.shape[0]  # number of boxes
514         if not n:  # no boxes
515             continue
516         elif n > max_nms:  # excess boxes
517             # sort x by column 4 (confidence) in descending order and keep at most max_nms boxes; this branch rarely executes
518             x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
519 
520         # Batched NMS
521         # the per-box offset below is class_index * max_wh: boxes of different classes are pushed into disjoint coordinate ranges, so one NMS call never suppresses across classes (see the sketch after this function)
522         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
523         # scores: the confidences computed above
524         boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
525         # i: indices of the boxes kept after NMS
526         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
527         if i.shape[0] > max_det:  # limit detections
528             i = i[:max_det]
529         if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
530             # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
531             iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
532             weights = iou * scores[None]  # box weights
533             x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
534             if redundant:
535                 i = i[iou.sum(1) > 1]  # require redundancy
536         # output[0].shape
537         # torch.Size([225, 6])
538         output[xi] = x[i]
539         if (time.time() - t) > time_limit:
540             print(f'WARNING: NMS time limit {time_limit}s exceeded')
541             break  # time limit exceeded
542 
543     return output
544 
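Two details above are worth making concrete. First, x[:, 5:] *= x[:, 4:5] turns the per-class conditional probabilities into absolute class confidences, and the best class then wins. Second, the batched-NMS offset: adding class_index * max_wh to the coordinates shifts each class into its own region of coordinate space, so a single torchvision.ops.nms call over all boxes never suppresses across class boundaries (unless agnostic=True, where the offset is zero). A minimal sketch with made-up numbers:

    import torch
    import torchvision

    # one candidate row: [x, y, w, h, Pr(Object), Pr(Class1|Object), Pr(Class2|Object)]
    x = torch.tensor([[320., 240., 100., 80., 0.9, 0.8, 0.2]])
    x[:, 5:] *= x[:, 4:5]                    # class confidences become 0.72 and 0.18
    conf, j = x[:, 5:].max(1, keepdim=True)
    print(conf, j)                           # tensor([[0.7200]]) tensor([[0]]): class 0 wins

    # class-offset trick: identical boxes of different classes both survive NMS
    boxes = torch.tensor([[0., 0., 100., 100.],
                          [0., 0., 100., 100.]])
    cls = torch.tensor([[0.], [1.]])
    scores = torch.tensor([0.90, 0.80])
    max_wh = 4096
    keep = torchvision.ops.nms(boxes + cls * max_wh, scores, 0.45)
    print(keep)  # tensor([0, 1]): after the offset the two boxes no longer overlap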
545 
546 def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
547     # Strip optimizer from 'f' to finalize training, optionally save as 's'
548     x = torch.load(f, map_location=torch.device('cpu'))
549     if x.get('ema'):
550         x['model'] = x['ema']  # replace model with ema
551     for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
552         x[k] = None
553     x['epoch'] = -1
554     x['model'].half()  # to FP16
555     for p in x['model'].parameters():
556         p.requires_grad = False
557     torch.save(x, s or f)
558     mb = os.path.getsize(s or f) / 1E6  # filesize
559     print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
560 
561 
562 def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
563     # Print mutation results to evolve.txt (for use with train.py --evolve)
564     a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
565     b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
566     c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
567     print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
568 
569     if bucket:
570         url = 'gs://%s/evolve.txt' % bucket
571         if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
572             os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local
573 
574     with open('evolve.txt', 'a') as f:  # append result
575         f.write(c + b + '\n')
576     x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
577     x = x[np.argsort(-fitness(x))]  # sort
578     np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness
579 
580     # Save yaml
581     for i, k in enumerate(hyp.keys()):
582         hyp[k] = float(x[0, i + 7])
583     with open(yaml_file, 'w') as f:
584         results = tuple(x[0, :7])
585         c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
586         f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
587         yaml.dump(hyp, f, sort_keys=False)
588 
589     if bucket:
590         os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload
591 
592 
593 def apply_classifier(x, model, img, im0):
594     # applies a second stage classifier to yolo outputs
595     im0 = [im0] if isinstance(im0, np.ndarray) else im0
596     for i, d in enumerate(x):  # per image
597         if d is not None and len(d):
598             d = d.clone()
599 
600             # Reshape and pad cutouts
601             b = xyxy2xywh(d[:, :4])  # boxes
602             b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
603             b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
604             d[:, :4] = xywh2xyxy(b).long()
605 
606             # Rescale boxes from img_size to im0 size
607             scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
608 
609             # Classes
610             pred_cls1 = d[:, 5].long()
611             ims = []
612             for j, a in enumerate(d):  # per item
613                 cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
614                 im = cv2.resize(cutout, (224, 224))  # BGR
615                 # cv2.imwrite('test%i.jpg' % j, cutout)
616 
617                 im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
618                 im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
619                 im /= 255.0  # 0 - 255 to 0.0 - 1.0
620                 ims.append(im)
621 
622             pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
623             x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections
624 
625     return x
626 
627 
628 def increment_path(path, exist_ok=True, sep=''):
629     # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
630     path = Path(path)  # os-agnostic
631     if (path.exists() and exist_ok) or (not path.exists()):
632         return str(path)
633     else:
634         dirs = glob.glob(f"{path}{sep}*")  # similar paths
635         matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
636         i = [int(m.groups()[0]) for m in matches if m]  # indices
637         n = max(i) + 1 if i else 2  # increment number
638         return f"{path}{sep}{n}"  # update path
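increment_path is what gives each run its own runs/exp, runs/exp2, runs/exp3, ... directory. For example, with runs/exp already on disk and exist_ok=False:

    from utils.general import increment_path

    print(increment_path('runs/exp', exist_ok=False))  # 'runs/exp2'
    # with runs/exp and runs/exp2 both present, the next call returns 'runs/exp3'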

 

detect.py

  1 import argparse
  2 import time
  3 from pathlib import Path
  4 
  5 import cv2
  6 import torch
  7 import torch.backends.cudnn as cudnn
  8 from numpy import random
  9 
 10 from models.experimental import attempt_load
 11 from utils.datasets import LoadStreams, LoadImages
 12 from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
 13     scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
 14 from utils.plots import plot_one_box
 15 from utils.torch_utils import select_device, load_classifier, time_synchronized
 16 
 17 
 18 def detect(save_img=False):
 19     source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
 20     save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
 21     webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
 22         ('rtsp://', 'rtmp://', 'http://'))
 23 
 24     # Directories
 25     save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
 26     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
 27 
 28     # Initialize
 29     set_logging()
 30     device = select_device(opt.device)
 31     half = device.type != 'cpu'  # half precision only supported on CUDA
 32 
 33     # Load model
 34     model = attempt_load(weights, map_location=device)  # load FP32 model
 35     stride = int(model.stride.max())  # model stride
 36     imgsz = check_img_size(imgsz, s=stride)  # check img_size
 37     if half:
 38         model.half()  # to FP16
 39 
 40     # Second-stage classifier
 41     classify = False
 42     if classify:
 43         modelc = load_classifier(name='resnet101', n=2)  # initialize
 44         modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']); modelc.to(device).eval()  # load_state_dict does not return the module, so chaining .to() on its result would fail
 45 
 46     # Set Dataloader
 47     vid_path, vid_writer = None, None
 48     if webcam:
 49         view_img = check_imshow()
 50         cudnn.benchmark = True  # set True to speed up constant image size inference
 51         dataset = LoadStreams(source, img_size=imgsz, stride=stride)
 52     else:
 53         # under the hood this reads images with OpenCV, then applies the letterbox function for a padded resize
 54         dataset = LoadImages(source, img_size=imgsz, stride=stride)
 55 
 56     # Get names and colors
 57     names = model.module.names if hasattr(model, 'module') else model.names
 58     colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
 59 
 60     # Run inference
 61     if device.type != 'cpu':
 62         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
 63     t0 = time.time()
 64     # shape after letterbox resize   original shape (a letterbox sketch at the end of this post reproduces these)
 65     # img.shape (3, 256, 640)  im0s.shape (360, 900, 3)
 66     # img.shape (3, 480, 640)  im0s.shape (375, 500, 3)
 67     for path, img, im0s, vid_cap in dataset:
 68 
 69         #img_ = cv2.cvtColor(img , cv2.COLOR_RGB2BGR)
 70         #cv2.imshow("img", img_)
 71         #cv2.imshow("im0s", im0s)
 72         #cv2.waitKey(0)
 73 
 74         img = torch.from_numpy(img).to(device)
 75         img = img.half() if half else img.float()  # uint8 to fp16/32
 76         img /= 255.0  # 0 - 255 to 0.0 - 1.0
 77         if img.ndimension() == 3:
 78             # img.shape torch.Size([3, 256, 640])
 79             img = img.unsqueeze(0)  # add a batch dimension
 80             # img.shape torch.Size([1, 3, 256, 640])
 81         # Inference
 82         t1 = time_synchronized()
 83         """
 84         Forward pass. pred has shape (1, num_boxes, 5 + num_classes).
 85         h and w are the height and width fed to the network; note the dataset uses rectangular inference, so h is not necessarily equal to w
 86         # rectangular inference: https://blog.csdn.net/songwsx/article/details/102639770
 87         num_boxes = 3 * (h/32 * w/32 + h/16 * w/16 + h/8 * w/8): 3 anchors per cell at strides 8/16/32; e.g. a 256x640 input gives 3 * (160 + 640 + 2560) = 10080
 88         pred[..., 0:4] are the predicted box coordinates,
 89         in xywh format (center point plus width and height)
 90         pred[..., 4] is the objectness confidence
 91         pred[..., 5:] are the per-class scores
 92         """
 93         pred = model(img, augment=opt.augment)[0]
 94 
 95         # Apply NMS
 96         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
 97         # pred[0].shape
 98         # torch.Size([225, 6])
 99         t2 = time_synchronized()
100 
101         # Apply Classifier
102         if classify:
103             pred = apply_classifier(pred, modelc, img, im0s)
104 
105         # Process detections
106         # det: detections for one image
107         for i, det in enumerate(pred):  # detections per image
108             if webcam:  # batch_size >= 1
109                 p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
110             else:
111                 p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
112 
113             p = Path(p)  # to Path
114             save_path = str(save_dir / p.name)  # img.jpg
115             txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
116             s += '%gx%g ' % img.shape[2:]  # print string
117             # im0(source image) (360, 900, 3)
118             # gn tensor([900, 360, 900, 360])
119             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
120             if len(det):
121                 # Rescale boxes from img_size to im0 size
122                 # rescale the boxes from the network input size back onto the original image
123                 det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
124 
125                 # Print results
126                 for c in det[:, -1].unique():
127                     n = (det[:, -1] == c).sum()  # detections per class
128                     s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
129 
130                 # Write results
131                 # conf: confidence score, cls: class index
132                 for *xyxy, conf, cls in reversed(det):
133                     if save_txt:  # Write to file
134                         # convert the box geometry: xyxy -> normalized xywh
135                         xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
136                         line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
137                         with open(txt_path + '.txt', 'a') as f:
138                             f.write(('%g ' * len(line)).rstrip() % line + '\n')
139 
140                     if save_img or view_img:  # Add bbox to image
141                         label = f'{names[int(cls)]} {conf:.2f}'
142                         #plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
143                         plot_one_box(xyxy, im0, label='', color=colors[int(cls)], line_thickness=2)
144 
145             # Print time (inference + NMS)
146             print(f'{s}Done. ({t2 - t1:.3f}s)')
147 
148             # Stream results
149             if view_img:
150                 cv2.imshow(str(p), im0)
151                 cv2.waitKey(1)  # 1 millisecond
152 
153             # Save results (image with detections)
154             if save_img:
155                 if dataset.mode == 'image':
156                     cv2.imwrite(save_path, im0)
157                 else:  # 'video' or 'stream'
158                     if vid_path != save_path:  # new video
159                         vid_path = save_path
160                         if isinstance(vid_writer, cv2.VideoWriter):
161                             vid_writer.release()  # release previous video writer
162                         if vid_cap:  # video
163                             fps = vid_cap.get(cv2.CAP_PROP_FPS)
164                             w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
165                             h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
166                         else:  # stream
167                             fps, w, h = 30, im0.shape[1], im0.shape[0]
168                             save_path += '.mp4'
169                         vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
170                     vid_writer.write(im0)
171 
172     if save_txt or save_img:
173         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
174         print(f"Results saved to {save_dir}{s}")
175 
176     print(f'Done. ({time.time() - t0:.3f}s)')
177 
178 #  --source D:/Data/yolo  --weights weights/best.pt --device cpu
179 if __name__ == '__main__':
180     parser = argparse.ArgumentParser()
181     parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
182     parser.add_argument('--source', type=str, default='data/images', help='source')  # file/folder, 0 for webcam
183     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
184     parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
185     parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
186     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
187     parser.add_argument('--view-img', action='store_true', help='display results')
188     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
189     parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
190     parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
191     parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
192     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
193     parser.add_argument('--augment', action='store_true', help='augmented inference')
194     parser.add_argument('--update', action='store_true', help='update all models')
195     parser.add_argument('--project', default='runs/detect', help='save results to project/name')
196     parser.add_argument('--name', default='exp', help='save results to project/name')
197     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
198     opt = parser.parse_args()
199     print(opt)
200     check_requirements(exclude=('pycocotools', 'thop'))
201 
202     with torch.no_grad():
203         if opt.update:  # update all models (to fix SourceChangeWarning)
204             for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
205                 detect()
206                 strip_optimizer(opt.weights)
207         else:
208             detect()
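The img/im0s shape pairs in the comments above come from the letterbox preprocessing in utils/datasets.py: the image is scaled so its long side fits the inference size, and the short side is padded up to the next multiple of the stride. A minimal re-implementation sketch of that idea (not the actual letterbox function, which also handles scale-up flags, auto padding and fill options):

    import cv2
    import numpy as np

    def letterbox_sketch(im, new_size=640, stride=32, color=(114, 114, 114)):
        # scale so the long side equals new_size, keeping the aspect ratio
        h, w = im.shape[:2]
        r = new_size / max(h, w)
        nh, nw = round(h * r), round(w * r)
        im = cv2.resize(im, (nw, nh))
        # pad the short side up to the next stride multiple (rectangular inference)
        ph, pw = (stride - nh % stride) % stride, (stride - nw % stride) % stride
        return cv2.copyMakeBorder(im, ph // 2, ph - ph // 2, pw // 2, pw - pw // 2,
                                  cv2.BORDER_CONSTANT, value=color)

    im0 = np.zeros((360, 900, 3), np.uint8)  # original HWC image
    print(letterbox_sketch(im0).shape)       # (256, 640, 3), matching the shapes noted above

A typical invocation, in line with the commented example above the argument parser:

    python detect.py --source data/images --weights yolov5s.pt --conf-thres 0.25 --device cpu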

 

From: https://www.cnblogs.com/feiyull/p/14589991.html
