首页 > 其他分享 >【CV基础】语义分割任务计算类别权重

【CV基础】语义分割任务计算类别权重

时间:2024-12-09 16:47:56浏览次数:5  
标签:valid 语义 class np classes weights 类别 freq CV

前言

 语义分割任务一般都存在样本类别不平衡的问题,采用类别权重来解决这个问题,本文记录类别权重的计算过程。

类别权重计算的基本思路

 

code

# 20240620: calculate class weights with semantic segmentation gt images.
import os
import numpy as np
import cv2 as cv

void_classes = [2, 4, 10, 12, 16, 17, 19, 21, 25, 30, 32, 33, 35]
valid_classes = [0, 1, 3, 5, 6, 7, 8, 9, 11, 13, 14, 15, 18, 20, 22, 23, 24, 26, 27, 28, 29, 31, 34]

# edgeai-torchvision
def calc_median_frequency(classes, present_num):
    """
    Class balancing by median frequency balancing method.
    Reference: https://arxiv.org/pdf/1411.4734.pdf
       'a = median_freq / freq(c) where freq(c) is the number of pixels
        of class c divided by the total number of pixels in images where
        c is present, and median_freq is the median of these frequencies.'
    """
    class_freq = classes / present_num
    median_freq = np.median(class_freq)
    return median_freq / class_freq

# edgeai-torchvision
def calc_log_frequency(classes, value=1.02):
    """Class balancing by ERFNet method.
       prob = each_sum_pixel / each_sum_pixel.max()
       a = 1 / (log(1.02 + prob)).
    """
    class_freq = classes / classes.sum()  # ERFNet is max, but ERFNet is sum
    # print(class_freq)
    # print(np.log(value + class_freq))
    return 1 / np.log(value + class_freq)

def calculate_class_weight_present(path):
    # edgeai-torchvision
    gtpath = os.path.join(path, 'gt')
    class_counts = np.zeros(len(valid_classes), dtype="f")
    class_freq = np.zeros(len(valid_classes), dtype="f")
    class_weights = np.zeros(len(valid_classes))
    present_num = np.zeros(len(valid_classes), dtype="f")
    for filename in os.listdir(gtpath):  
        # print('filename: ', filename)
        if filename.endswith('.png') or filename.endswith('.jpg'): 
            filepath = os.path.join(gtpath, filename)  
            gtimg = cv.imread(filepath, cv.IMREAD_GRAYSCALE)  
            if gtimg is not None:
                # for i in range(len(valid_classes)):
                #    class_counts[i] += np.sum(gtimg == valid_classes[i])
                for i, classid in enumerate(valid_classes): 
                    num_pixel = np.sum(gtimg == classid)
                    if num_pixel:
                        class_counts[i] += np.sum(gtimg == classid)
                        present_num[i] += 1

    for i, count in enumerate(class_counts):  
        class_freq[i] = count / present_num[i] if present_num[i] > 0 else 0
    # print('class_freq: ', class_freq)
    medval = np.median(class_freq)
    # print('medval: ', medval)
    for i, freq in enumerate(class_freq):  
        # class_weights[i] = medval / freq
        class_weights[i] = medval / freq if freq > 0 else 0
    print(class_weights)
    # for i, weight in enumerate(class_weights):  
    #     print(f"类别 {valid_classes[i]}: 权重 = {weight}")  

    # Normalization
    # # 对权重进行归一化,使它们的和为1(可选步骤,取决于你的应用)
    # class_weights = class_weights / class_weights.sum()
    # for i, weight in enumerate(class_weights):  
    #     print(f"类别 {valid_classes[i]}: 权重 = {weight}")  



def calculate_class_weight_all(path):
    gtpath = os.path.join(path, 'gt')
    class_counts = np.zeros(len(valid_classes), dtype=np.int64)
    class_weights = np.zeros(len(valid_classes))
    for filename in os.listdir(gtpath):  
        # print('filename: ', filename)
        if filename.endswith('.png') or filename.endswith('.jpg'): 
            filepath = os.path.join(gtpath, filename)  
            gtimg = cv.imread(filepath, cv.IMREAD_GRAYSCALE)  
            if gtimg is not None:
                # for i in range(len(valid_classes)):
                #    class_counts[i] += np.sum(gtimg == valid_classes[i])
                for i, classid in enumerate(valid_classes): 
                    class_counts[i] += np.sum(gtimg == classid)

    total_pixels = class_counts.sum()
    # print('class_counts: \n', class_counts)
    # print('totalpixel: ', total_pixels)
    medval = np.median(class_counts)
    # print('medval: ', medval)
    for i, count in enumerate(class_counts):  
        class_weights[i] = medval / count if count > 0 else 0
    print(class_weights)
    # for i, weight in enumerate(class_weights):  
    #     print(f"类别 {valid_classes[i]}: 权重 = {weight}")  
    # Normalization

    # # 对权重进行归一化,使它们的和为1(可选步骤,取决于你的应用)
    # class_weights = class_weights / class_weights.sum()
    # for i, weight in enumerate(class_weights):  
    #     print(f"类别 {valid_classes[i]}: 权重 = {weight}")  


if __name__ == "__main__":
    path = os.path.dirname(os.path.realpath(__file__))
    # calculate_class_weight_all(path)
    # print("\n\n\n  start present \n\n\n")
    calculate_class_weight_present(path)
View Code

 

参考

1. Median Frequency Balancing 理解; 2. median frequency balancing-CSDN博客; 3. edgeai-torchvision/torchvision/edgeailite/xvision/datasets/calculate_class_weights.py at r8.1 · Texa; 4. 语义分割中的类别不平衡的权重计算_语义分割下载的权重-CSDN博客; 完

标签:valid,语义,class,np,classes,weights,类别,freq,CV
From: https://www.cnblogs.com/happyamyhope/p/18414418

相关文章

  • 【opencv基础】resize使用的问题
    前言最近语义分割任务的gt文件resize前后标签数值发生了错误,最后发现是resize函数调用过程中参数调用出现错误,主要是参数顺序,记录之。问题分析源码 结果: 虽然使用最近邻插值,但是resize后和预想的数值不一致,多方分析、调试,最后小伙伴发现是调用函数参数不正确。opencv官......
  • Mitel MiCollab企业协作平台存在任意文件读取漏洞(CVE-2024-41713)
    免责声明:本文旨在提供有关特定漏洞的深入信息,帮助用户充分了解潜在的安全风险。发布此信息的目的在于提升网络安全意识和推动技术进步,未经授权访问系统、网络或应用程序,可能会导致法律责任或严重后果。因此,作者不对读者基于本文内容所采取的任何行为承担责任。读者在使用本......
  • MitelMiCollab 身份绕过导致任意文件读取漏洞复现(CVE-2024-41713)
    0x01产品描述:        MitelMiCollab是一个企业协作平台,它将各种通信工具整合到一个应用程序中,提供语音和视频通话、消息传递、状态信息、音频会议、移动支持和团队协作功能。0x02漏洞描述:        MitelMiCollab的NuPoint统一消息(NPM)组件中存在身......
  • GA/T1400视图库平台EasyCVR宇视设备视频平台:RTSP视频流不能在网页端播放的问题与解决
    在现代视频监控系统中,RTSP(实时流协议)是一种广泛应用于网络摄像机的协议,用于控制和传输音视频数据。然而,当尝试在网页端播放RTSP视频流时,我们可能会遇到一系列挑战。本文将探讨这些常见问题及其解决方案,并介绍如何使用GA/T1400视图库平台EasyCVR来有效地处理和播放RTSP视频流。通过......
  • ISUP协议视频平台EasyCVR宇视设备视频平台:天地伟业安防摄像头忘记密码如何处理
    在数字化时代,安防设备已成为保护个人和企业安全的重要工具。然而,随着技术的进步和设备的智能化,我们可能会遇到一些常见的问题,比如忘记密码。这不仅影响设备的使用,还可能带来安全隐患。本文将为您提供关于如何处理天地伟业安防摄像头忘记密码的问题,以及如何重置密码的详细步骤。无......
  • EHOME视频平台EasyCVR私有化视频平台:安防监控网络摄像机根据接口类型可以分为哪几类?
    在现代监控系统中,摄像机作为捕捉视频信息的核心设备,其输出接口的多样性对于视频信号的传输和应用至关重要。随着技术的发展,摄像机的输出接口已经从单一的模拟信号发展到了多种数字接口,以适应不同的监控环境和需求。本文将为您详细介绍各种摄像机输出接口的类型及其应用场景,帮助您......
  • 国标GB28181视频平台EasyCVR视频融合平台:什么是电梯五方通话?怎样施工安装?
    在现代城市建筑中,电梯已成为不可或缺的垂直交通工具,其安全性和可靠性对人们的日常生活和工作至关重要。随着技术的发展,电梯五方通话系统作为电梯安全的重要组成部分,已经广泛应用于各大楼宇和住宅区。本文将详细介绍电梯五方通话系统的概念、功能、安装说明以及其在现代视频监控管......
  • HTML为什么要语义化?语义化有什么好处?
    HTML语义化是指使用合适的HTML标签来清晰地表达网页内容的结构和含义,而不是仅仅关注网页的外观呈现。例如,使用<article>表示文章内容,<h1>到<h6>表示不同级别的标题,<nav>表示导航菜单,而不是用<div>和<span>等通用标签来随意包裹内容。HTML语义化带来的好处有很多,主要体现在以下几......
  • 使用深度学习框架进行街景语义分割-数据准备、模型选择、模型训练、模型评估及如何使
    使用深度学习框架进行街景语义分割-数据准备、模型选择、模型训练、模型评估以及如何使用PyQt5构建一个简单的应用来展示分割结果街景语义分割数据集数据集:jingjingji,长三角,珠三角共49个城市群百度街景(全景)数据,50m采样。包含街景图像、shp、csv等数据处理结果文件。......
  • 如何使用yolov8纽扣电池缺陷检测数据集进行训练,并提供详细的步骤和代码示例 纽扣电池
    纽扣电池缺陷检测数据集包含3种缺陷类别(脏污、凹陷、划痕),已经划分为训练集和验证集,有xml和txt标签,yolo可用,共1110张类别名称names:0:dirty1:depression2:scratch数据集包含3种缺陷类别(脏污、凹陷、划痕),已经划分为训练集和验证集,并且标注为YOLO格式和VOC格式,亲......