首页 > 其他分享 >yolov5-采用k-means进行锚框的聚类

yolov5-采用k-means进行锚框的聚类

时间:2024-07-25 16:53:28浏览次数:19  
标签:yolov5 clss img means self label props print 锚框

K-means算法是一种无监督学习方法,主要用于数据聚类,即将相似的数据点分组到同一类别中。其基本思想是通过迭代过程,将数据集划分为K个簇(cluster),每个簇由一个中心点(centroid)表示,而簇内的数据点与该簇中心点的距离最小。在计算机视觉中,它常被用于找到图像中物体的锚框。

K-means算法步骤:

  1. 初始化:首先随机选择K个数据点作为初始的簇中心点。

  2. 分配数据点:将每个数据点分配给最近的簇中心点,形成K个簇。这里通常使用欧几里得距离作为相似度度量。

  3. 更新中心点:重新计算每个簇的中心点,即簇内所有数据点的平均值。这一步使得簇中心点向簇内数据点的集中位置移动。

  4. 重复步骤2和3:不断重复分配数据点和更新中心点的过程,直到簇中心点不再发生显著变化,或者达到预设的最大迭代次数。

算法终止条件:

  • 簇中心点的变化小于某个阈值。
  • 达到预设的最大迭代次数。

以下是使用K-means算法找到YOLO锚框的Python代码示例:

# -*- coding:utf-8 -*-
import os
import json
import cv2
import pprint
import multiprocessing
import numpy as np



mmp = {
    "RoadDamage": {"subtype": {"Pothole": "144", "Crazing": "145", "Repair": "146", "Crack": "147", "Pulverization": "148", "Others": "214",  "ConstructionJoint": "225"},"Direction":{"Lateral": "223", "Longitudinal": "224"}},
    # 道路破损 城管类 坑洼、龟裂、修补、裂痕、粉化、施工接缝 方向:横向、竖向
}


def name_2_classid(name):
    clss = name.strip().split('/')
    clss0 = mmp[clss[0]]
    if len(clss) > 1:
        for clss_ in clss[1:]:
            p, v = clss_.strip().split('-')
            if clss0 == 'NonMotorVehicle' and p == 'WithHuman':
                assert v == 'NoHuman'
                continue
            return int(clss0[p][v])
            
    else:
        if isinstance(clss0, str):
            return int(clss0)
        else:
            if 'none' in clss0:
                return int(clss0['none'])
            elif 'None' in clss0:
                return int(clss0['None'])
            else:
                raise
    
    raise
    

all_type = list(mmp.keys())

img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng']  # acceptable image suffixes

def json2label(js): # x y w h (0.0~1.0)

    ignore = False

    if not os.path.exists(js):
        print(js,'not exists')
        return None

    try:
        data = json.load(open(js, 'r'))  # 将JSON格式的文件对象读取并转换为Python的数据结构(如字典或列表)
    except Exception:
        print(js, 'read error')
        return None

    out_data = []
    schemaVer = data.get('schemaVer', 1)
    if  schemaVer == 2: # jinn platform
        objects = data['objects']
        image_w, image_h = data['imgSize']
        image_w, image_h = int(image_w), int(image_h)
    
        for obj in objects:
            clss = obj['class']
            one_label = None
            if clss not in mmp:
                continue
            
            coord = obj['coord']
            props = obj['props']
            if clss == 'DiscardedThing': # 多边形
                pt = []
                for p in coord:
                    pt.append([float(p[0]), float(p[1])])
                pts = np.array(pt, np.int32)
                # xmax = max(pts[:,0])
                # xmin = min(pts[:,0])
                # ymax = max(pts[:,1])
                # ymin = min(pts[:,1])
                # w = float(xmax - xmin) / image_w
                # h = float(ymax - ymin) / image_h
                # x = float(xmax + xmin) / 2.0 / image_w
                # y = float(ymax + ymin) / 2.0 / image_h
                w,h,x,y = 0,0,0,0
                

            else:
                if len(coord) > 2:
                    continue
                x1, y1 = coord[0]
                x2, y2 = coord[1]
                x1, x2, y1, y2 = int(x1),int(x2), int(y1), int(y2)
                x = (x1 + x2) / 2.0 / image_w
                y = (y1 + y2) / 2.0 / image_h
                w = 1.0 * abs(x2 - x1) / image_w
                h = 1.0 * abs(y2 - y1) / image_h

            if props:

                if clss == 'NonMotorVehicle':
                    if 'WithHuman' in props and props['WithHuman'][0].strip() == "HasHuman":
                        continue
                if clss in ['HumanBody', 'Passerby', 'NonMotorVehicle', 'MotorVehicle']:
                    if 'EstimatedStatus' in props and props['EstimatedStatus'][0].strip() == 'Completion':
                        continue
                if clss == 'Passerby' and 'Reality' in props:
                    continue    
                if isinstance(mmp[clss], dict):
                    props_keys = list(mmp[clss].keys())
                    if 'None' in props_keys:
                        props_keys.remove('None')
                    if 'none' in props_keys:
                        props_keys.remove('none')

                    props_key = props_keys[0]

                    get_prop = False
                    for props_key in props_keys:
                        if props_key in props:
                            get_prop = True
                            
                            props_value = props[props_key][0].strip()
                            try:
                                label = int(mmp[clss][props_key][props_value])
                            except:
                                print('######################################')
                                print(js)
                                print(clss, props_key, props_value)
                                continue
                            one_label = [label, x, y, w, h]
                            # out_data.append([label, x, y, w, h])
                    
                    if not get_prop:
                        ignore = True
                        print('######################################')
                        print(props_key)
                        print(js, 'props error')
                        break

                else:
                    print("{}:{}".format(clss,mmp[clss]))
                    label = int(mmp[clss])
                    one_label = [label, x, y, w, h]
                    # out_data.append([label, x, y, w, h])

            else:
                # print("{}:{}".format(clss,mmp[clss]))
                if isinstance(mmp[clss], str):
                    label = int(mmp[clss])
                else:
                    # print(clss)
                    props_keys = list(mmp[clss].keys())
                    if 'None' in props_keys:
                        label = int(mmp[clss]['None'])
                    elif 'none' in props_keys:
                        label = int(mmp[clss]['none'])
                    else:
                        print('######################################')
                        print(js, clss, mmp[clss])

                        ignore = True
                        break
                one_label = [label, x, y, w, h]
                # out_data.append([label, x, y, w, h])
            
            if clss == 'DiscardedThing':
                one_label = [label] + list(pts)
            
            out_data.append(one_label)

    else:
        print(schemaVer)
        ignore = True

    if ignore:
        return None

    return out_data

if __name__ == '__main__':
  
    print(json2label('0001024.jpg.json'))
import os
import cv2
import math
import glob
import time
import random
import argparse
import numpy as np
from tqdm import tqdm
import multiprocessing
from PIL import Image, ExifTags
from scipy.cluster.vq import kmeans

import torch
from torch.utils.data import Dataset

# import type_map
import type_map


help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        # img._getexif().items()的作用是从图像的EXIF信息中获取所有的键值对。EXIF信息通常是一个字典,包含了关于图像的元数据,如拍摄时间、相机型号、图像方向等。
        # 通过将_getexif()的结果转换为字典并调用.items(),你可以遍历这些元数据,查找特定的标签(例如orientation),并获取其对应的值。
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0]) # 调整尺寸,将宽度和高度互换,因为旋转后高度变成了宽度,宽度变成了高度
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except:
        pass

    return s


def colorstr(*input):
    # 定义一个函数colorstr,接受任意数量的参数。这个函数的目的是给字符串添加颜色。
    # 例如,调用colorstr('blue', 'hello world')将使'hello world'显示为蓝色。
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')

    #  例如,如果调用colorstr('red', 'green', 'hello world'),那么prefix将是一个包含('red', 'green')的元组,str将等于'hello world'。
    *prefix, str = input  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors  # 基本颜色定义,使用ANSI转义码。
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'undelrine': '\033[4m'}  # 这个字典将颜色名称映射到相应的ANSI转义码。例如,'red'映射到'\033[31m',这是一段用于设置红色文本的ANSI转义码。
    # 根据prefix中的颜色名称,使用colors字典中的ANSI转义码给str着色,并在字符串末尾添加一个结束颜色的转义码,以防止影响后续的输出。
    return ''.join(colors[x] for x in prefix) + str + colors['end']   #构建一个彩色字符串, colors['end']则是一个用于结束颜色设置的ANSI转义码,通常是\033[0m,它会将终端的文本颜色重置为默认颜色。
    # 如果prefix = ('red', 'bold'),str = 'Hello, world!',并且colors字典中'red'和'bold'分别对应\033[31m和\033[1m,
    # 那么这行代码将返回一个类似\033[31m\033[1mHello, world!\033[0m的字符串,这会在终端上显示为加粗的红色文本“Hello, world!”。

# Ancillary functions --------------------------------------------------------------------------------------------------
def augment_gaussianblur(img):
    ksize = int(random.choice((3, 5, 7)))
    return cv2.GaussianBlur(img, (ksize, ksize), 0)

def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path

        if self.augment and random.random() < self.hyp['gaussianblur']:
            img = augment_gaussianblur(img)

        h0, w0 = img.shape[:2]  # orig hw
        r = min(1.0 * self.img_size[0] / h0, 1.0 * self.img_size[1] / w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            # interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            interp = cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

def get_args():

    parser = argparse.ArgumentParser()  # 创建一个ArgumentParser对象,它是argparse模块的主要接口,用于定义和解析命令行参数。
    parser.add_argument("--src", type=str, default=
                        "./train_road_damage.txt"  # 路径下是jpg文件
                        )
    # 添加一个命令行参数-n或--nworks,这是一个整数类型的参数,用于指定工作的数量(例如,多线程或进程的数量)。
    # 短选项-n和长选项--nworks都指向同一个参数,这提供了灵活性,用户可以根据喜好使用任一选项。
    parser.add_argument('-n', "--nworks", type=int, default=20)  # 默认值是20,意味着如果不从命令行提供该参数,它将采用默认值20。
    # parser.add_argument('--width', type=int, default=1344)
    # parser.add_argument('--height', type=int, default=832)
    parser.add_argument('--width', type=int, default=960)
    parser.add_argument('--height', type=int, default=576)
    # 添加一个命令行参数-l或--layer_num,这是一个整数类型的参数,用于指定层数(神经网络的层数)。
    # 默认值是3,意味着如果不从命令行提供该参数,它将采用默认值3。
    parser.add_argument("-l", '--layer_num', type=int, default=3)  
    # 添加一个命令行参数-a或--anchor_num,这也是一个整数类型的参数,用于指定锚点的数量(在目标检测任务中,用于预测不同大小的目标)。
    # 默认值是3,意味着如果不从命令行提供该参数,它将采用默认值3。
    parser.add_argument("-a", '--anchor_num', type=int, default=3)
    
    opt = parser.parse_args()
    # 调用ArgumentParser对象的parse_args方法来解析命令行参数,并将结果存储在opt变量中。
    # parse_args方法会读取命令行输入,将匹配的参数转换为相应的类型,并填充到命名空间对象中。

    return opt   # 返回解析后的命令行参数,这个返回值是一个命名空间对象,可以通过属性访问的方式获取每个参数的值。

class LoadImagesAndLabels_List(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, cls_map=None, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, stride=32, pad=0.0, rank=-1, nworks=10):

        # cls map
        if cls_map:
            self.cls_map = {}
            for key in cls_map:
                for cls in cls_map[key]:
                    self.cls_map[cls] = key
        else:
            self.cls_map = None
        
        print('')
        print('dataset cls map:')
        print(cls_map)
        print(self.cls_map)


        self.img_files = []
        self.label_files = []
        
        try:
            lines = open(path,'r').readlines()
            for line in lines:
                # print(line.strip().split(';'))
                # img_path, label_path,_ = line.strip().split(';')
                
                img_path = line.strip()
                label_path = line.strip().replace(".jpg", ".jpg.json")

                
                self.img_files.append(img_path)
                self.label_files.append(label_path)
        except Exception as e:
            raise Exception('Error loading data from %s: %s\n' % (path, e))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s. See %s' % (path, help_url)
        bi = np.floor(np.arange(n) / batch_size).astype(np.int32)  # batch index
        nb = bi[-1] + 1  # number of batches

        self.debug_num = 10
        self.debug_count = 0
        if os.path.exists('./debug'):
            os.system('rm -rf ./debug')

        self.rank = rank
        self.n = n  # number of images
        self.batch = bi  # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        # self.rect = False if image_weights else rect
        self.rect = False
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size[0] // 2, -img_size[1] // 2]
        self.stride = stride

        # Check cache
        cache_path = path + '.cache'  # cached labels
        print('get cache_path: ' + cache_path + ' # ' + cache_path + '_' + str(self.n))
        cache = self.cache_labels(cache_path, nworks=nworks)  # re-cache
        print('cache_name: ' + cache['name'])

        # Get labels
        self.img_files = cache['imgs']
        # assert self.n == len(self.img_files), 'cache is wrong'
        print(f'ori_img_files nums: {self.n}, cache nums:{len(self.img_files)}')
        self.n = len(self.img_files)
        self.labels = list(cache['labels'])
        self.shapes = np.array(cache['shapes'], dtype=np.float64)

        self.indices = range(self.n)
        
        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        """
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]
            shapes = [shapes[0]*img_size[0], shapes[1]*img_size[1]]
            # self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
            self.batch_shapes = np.ceil(np.array(shapes) / stride + pad).astype(np.int) * stride
        """

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            pbar = tqdm(range(len(self.img_files)), desc='Caching images')
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            for i in pbar:  # max 10k images
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i)  # img, hw_original, hw_resized
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

    def cache_labels(self, path='labels.cache', nworks=10):
        # Cache dataset labels, check images and read shapes
        cache = {'imgs':[], 'labels':[], 'shapes':[], 'name':''}  # dict
        img_files, label_files = self.img_files, self.label_files
        one = int(len(img_files) / nworks) + 1

        def cache_labels_worker(return_dict, num, img_files_one, label_files_one):
            imgs_patch = []
            labels_patch = []
            shapes_patch = []

            pbar = tqdm(zip(img_files_one, label_files_one), desc='Scanning images', total=len(img_files_one))
            for (img, label) in pbar:
                l = []
                
                if not os.path.isfile(img):
                    print(f'has no {img}')
                    continue

                if not os.path.isfile(label):
                    print(f'has no {label}')
                    continue

                image = Image.open(img)
                image.verify()  # PIL verify
                # _ = io.imread(img)  # skimage verify (from skimage import io)
                shape = exif_size(image)  # image size

                if not ((shape[0] > 9) & (shape[1] > 9)):
                    print(f'image size <10 pixels, {img}')
                    continue

               
                ll = type_map.json2label(label)
                if ll:
                    for i in ll:
                        if i[0] in self.cls_map:
                            out = np.array(i, dtype=np.float32)
                            l.append(out)

                # ans = []
                # for iter in l:
                #     if iter[0] in self.cls_map:
                #         ans.append(iter)
                # l = np.array(ans, dtype=np.float32)
                
                # if len(l) == 0:
                #     l = np.zeros((0, 5), dtype=np.float32)
                # print("------", l)
                if l:
                    l = np.array(l, dtype=np.float32)
                    imgs_patch.append(img)
                    labels_patch.append(l)
                    shapes_patch.append(shape)

            return_dict[num] = {'imgs_patch':imgs_patch, 'labels_patch':labels_patch, 'shapes_patch':shapes_patch}
            # return_dict[num] = {'imgs_patch':imgs_patch, 'labels_patch':labels_patch}

        manager = multiprocessing.Manager()
        return_dict = manager.dict()
        job = []
        for i in range(nworks):
            if i == nworks - 1:
                p = multiprocessing.Process(target=cache_labels_worker, args=(return_dict, i, img_files[one * i:], label_files[one * i:]))
                job.append(p)
                p.start()
                continue
            p = multiprocessing.Process(target=cache_labels_worker, args=(return_dict, i, img_files[one * i:one * (i + 1)], label_files[one * i:one * (i + 1)]))
            job.append(p)
            p.start()

        for p in job:
            p.join()

        for i in return_dict:
            cache['imgs'].extend(return_dict[i]['imgs_patch'])
            cache['labels'].extend(return_dict[i]['labels_patch'])
            cache['shapes'].extend(return_dict[i]['shapes_patch'])

        cache['name'] = path + '_' + str(self.n)
        return cache


def kmean_anchors(path='./data/coco128.yaml', cls_map=None, n=9, img_size=640, thr=4.0, gen=1000, verbose=True, nworks=10):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            path: path to dataset *.yaml, or a loaded dataset   数据集的 YAML 配置文件路径或已加载的数据集。
            n: number of anchors                       锚点的数量,默认为 9。
            img_size: image size used for training      用于训练的图像尺寸,默认为 640。
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0  锚点与目标框宽高比阈值的倒数,用于训练的超参数,通常在配置文件中指定,这里默认为 4.0 的倒数,即 0.25。
            gen: generations to evolve anchors using genetic algorithm     用于进化锚点的遗传算法的代数,默认为 1000。
            verbose: print all results
            nworks:并行工作的线程数,默认为 10

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.autoanchor import *; _ = kmean_anchors()   
    """
    thr = 1. / thr
    prefix = colorstr('blue', 'bold', 'autoanchor') + ': '   # 定义一个前缀字符串,用于在打印信息时添加颜色和样式,使得输出更加醒目和易于区分。

    def metric(k, wh):  # compute metrics  k:一组锚框。  wh:目标框的宽度和高度的数组。  匹配度是通过计算比例度量(ratio metric)来实现的,而不是计算交并比
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
        print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
              f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=',  ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    if isinstance(path, str):  # *.yaml file
        dataset = LoadImagesAndLabels_List(path, img_size=img_size, nworks=nworks, cls_map = cls_map)
    else:
        dataset = path  # dataset
    
    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh


    # Filter  这部分对目标框的宽高进行过滤,移除那些小于3像素的目标框,并警告存在极小目标框的情况。然后,进一步过滤掉小于2像素的目标框。
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels
    # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

    # Kmeans calculation
    print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0],400)
    # ax[1].hist(wh[wh[:, 1]<100, 1],400)
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
            if verbose:
                print_results(k)

    return print_results(k)



if __name__ == '__main__':

    t0 = time.time()

    opt = get_args()
    opt.src = opt.src.strip()
    
    # 道路损坏
    cls_map = {
        0: [144],
        1: [145],
        2: [146],
        3: [147],
        }
 
    img_size = np.array([opt.width, opt.height])

    k = kmean_anchors(path=opt.src, cls_map = cls_map, n=opt.anchor_num*opt.layer_num, img_size=img_size, thr=4.0, gen=1000, verbose=False, nworks=20)
    #opt.nworks)
    print(colorstr('blue', 'bold', 'autoanchor') + ':', (k+0.5).astype(np.int).reshape(opt.layer_num, opt.anchor_num * 2).tolist())

    t1 = time.time()
    print(colorstr('blue', 'bold', 'time') + ':', t1-t0)

    c_w, c_h = opt.width//2, opt.height//2
    anchors = np.array(k).reshape(-1, 2).astype(int)
    img = np.ones((opt.height, opt.width, 3), dtype=np.uint8)
    for n, box in enumerate(anchors):
        cv2.rectangle(img, (c_w-box[0]//2, c_h-box[1]//2), (c_w+box[0]//2, c_h+box[1]//2), (0, 255, 0), 1)
        cv2.putText(img, str(n), (c_w-box[0]//2, c_h-box[1]//2), cv2.FONT_HERSHEY_TRIPLEX, 0.5, (0, 255, 0), 1)

    cv2.imshow('anchors', img)
    cv2.waitKey(0)

标签:yolov5,clss,img,means,self,label,props,print,锚框
From: https://blog.csdn.net/KIKI3666/article/details/140694010

相关文章

  • 我如何为 yolov5 制作 gui,从 pytorch 和 opencv 加载到 tkinker?
    请帮助我,我不明白如何使用yolo和tkinker作为gui来制作用于实时检测的gui。以及如何将边界框从pytorch渲染到tkinker?这里是代码:importtorchfrommatplotlibimportpyplotaspltimportnumpyasnpimportcv2model=torch.hub.load('ultralytics/yolov5......
  • __yolov5+deepsort+slowfast win部署
     运行程序报错:yolov5_trt_create...yolov5_trt_createcudaengine...yolov5_trt_createbuffer...yolov5_trt_createstream...yolov5_trt_createdone...createyolov5-trt,instance=000001AFB3B05EC0[07/19/2024-21:23:10][E][TRT]1:[stdArchiveRea......
  • 在2024年部署Yolov5到本地(包含部署以及训练全过程,绝对保姆)
    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档文章目录前言一、pandas是什么?二、使用步骤1.引入库2.读入数据总结前言        刚开始用yolo是用k210入门的,在那里学会了制作数据集以及进行训练,第一次了解到了目标检测,机器视觉,主要是因为电赛......
  • 深度学习印章检测(自动生成数据集+yolov5)
    目录1概述1.1简介1.2演示2软件安装3数据集3.1生成随机字符3.2生成印章图片3.3生成word文件3.4Word转PDF文件3.5PDF转图像4labelme标记5yolov5训练1概述本文将从代码层面的角度来剖析印章数据集如何自动生成,以及如何进行训练与测试,如果希望获取直......
  • 【YOLOv5/v7改进系列】引入SAConv——即插即用的卷积块
    一、导言《DetectoRS:使用递归特征金字塔和可切换空洞卷积进行物体检测》这篇文章提出了一种用于物体检测的新方法,结合了递归特征金字塔(RecursiveFeaturePyramid,RFP)和可切换空洞卷积(SwitchableAtrousConvolution,SAC)。以下是对该研究的优缺点分析:优点:机制灵感来源于人......
  • 深度学习第P9周:YOLOv5-Backbone模块实现
    >-**......
  • 计算机毕业设计Python+Tensorflow小说推荐系统 K-means聚类推荐算法 深度学习 Kears
    2、基于物品协同过滤推荐算法2.1、基于⽤户的协同过滤算法(UserCF)该算法利⽤⽤户之间的相似性来推荐⽤户感兴趣的信息,个⼈通过合作的机制给予信息相当程度的回应(如评分)并记录下来以达到过滤的⽬的进⽽帮助别⼈筛选信息,回应不⼀定局限于特别感兴趣的,特别不感兴趣信息的纪录也相......
  • 一次相对完整的K-means聚类流程
    数据结构(第一题数据)如下:nox1x2x3112520442121184331201742412420455122184361201944712117418122194391221742101211945首先是导入的一些准备工作:#科学计算,启动!importnumpyasnpimportpandasaspdimportseabornassnsimportmatplotlib.pyplotasplti......
  • yolov5 损失函数代码详解
    前言模型的损失计算包括3个方面,分别是:定位损失分类损失置信度损失损失的计算公式如下:损失计算的代码流程也是按照这三大块来计算的。本篇主要讲解yolov5中损失计算的实现,包括损失的逻辑实现,张量操作的细节等。准备工作初始化损失张量的值,获取正样本的信息。lcls=to......
  • yolov5 上手
    0介绍YOLO(YouOnlyLookOnce)是一种流行的物体检测和图像分割模型,由华盛顿大学的约瑟夫-雷德蒙(JosephRedmon)和阿里-法哈迪(AliFarhadi)开发。YOLO于2015年推出,因其高速度和高精确度而迅速受到欢迎。YOLOv5在YOLOv4的基础上进一步提高了模型的性能,并增加了超参数......