前言
语义分割任务普遍存在样本类别不平衡的问题,通常通过在损失函数中引入类别权重来缓解。本文记录类别权重的计算过程。
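类别权重计算的基本思路
采用 median frequency balancing:类别 $c$ 的权重为 $w_c = \mathrm{median\_freq} / \mathrm{freq}(c)$,其中 $\mathrm{freq}(c)$ 是类别 $c$ 的像素数除以出现该类别的图像的总像素数,$\mathrm{median\_freq}$ 是所有类别频率的中位数;像素越少的类别权重越大。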
code
# 20240620: calculate class weights with semantic segmentation gt images.
import os

import numpy as np

# Label ids found in the ground-truth images. Only valid_classes are counted
# by the functions below; void_classes is not used in this script.
void_classes = [2, 4, 10, 12, 16, 17, 19, 21, 25, 30, 32, 33, 35]
valid_classes = [0, 1, 3, 5, 6, 7, 8, 9, 11, 13, 14, 15, 18, 20, 22, 23, 24, 26, 27, 28, 29, 31, 34]


# edgeai-torchvision
def calc_median_frequency(classes, present_num):
    """Return class weights by median frequency balancing.

    Reference: https://arxiv.org/pdf/1411.4734.pdf
    'a = median_freq / freq(c) where freq(c) is the number of pixels of
    class c divided by the total number of pixels in images where c is
    present, and median_freq is the median of these frequencies.'

    Args:
        classes: per-class pixel counts (array-like).
        present_num: per-class normalizer, e.g. the number of images in
            which the class is present (array-like, broadcastable to
            ``classes``).

    Returns:
        A float64 array of weights. Classes that never occur (zero count or
        zero normalizer) get weight 0 and are excluded from the median, so
        they can neither cause a division by zero nor drag the median
        towards 0.
    """
    classes = np.asarray(classes, dtype=np.float64)
    present_num = np.asarray(present_num, dtype=np.float64)

    # Entries whose normalizer is 0 keep the initial frequency of 0.
    class_freq = np.zeros_like(classes)
    np.divide(classes, present_num, out=class_freq, where=present_num > 0)

    weights = np.zeros_like(class_freq)
    present = class_freq > 0
    if present.any():
        median_freq = np.median(class_freq[present])
        weights[present] = median_freq / class_freq[present]
    return weights


# edgeai-torchvision
def calc_log_frequency(classes, value=1.02):
    """Return class weights by the ERFNet-style log-frequency method.

    prob = class_pixels / sum(class_pixels)
    a = 1 / log(value + prob)

    Note: frequencies are normalized by the sum of all class counts (not by
    the maximum count).

    Args:
        classes: per-class pixel counts (array-like).
        value: offset added to the frequency before taking the log.

    Returns:
        A float64 array of weights. If all counts are 0 every frequency is
        treated as 0, giving the uniform weight 1 / log(value).
    """
    classes = np.asarray(classes, dtype=np.float64)
    total = classes.sum()
    class_freq = classes / total if total > 0 else np.zeros_like(classes)
    return 1 / np.log(value + class_freq)


def _count_class_pixels(gtpath):
    """Scan the label images in *gtpath* and count pixels per valid class.

    Only ``.png`` / ``.jpg`` files are read (as grayscale); files OpenCV
    cannot decode are skipped.

    Returns:
        (pixel_counts, image_counts): two int64 arrays aligned with
        ``valid_classes`` holding the total pixel count of each class and
        the number of images in which the class occurs at least once.
    """
    # Imported here so the NumPy-only weight helpers stay importable on
    # machines without OpenCV.
    import cv2 as cv

    # int64 keeps the counts exact; float32 loses precision past 2**24 pixels.
    pixel_counts = np.zeros(len(valid_classes), dtype=np.int64)
    image_counts = np.zeros(len(valid_classes), dtype=np.int64)

    for filename in os.listdir(gtpath):
        if not filename.endswith(('.png', '.jpg')):
            continue
        gtimg = cv.imread(os.path.join(gtpath, filename), cv.IMREAD_GRAYSCALE)
        if gtimg is None:
            continue
        per_class = np.array(
            [np.count_nonzero(gtimg == classid) for classid in valid_classes],
            dtype=np.int64,
        )
        pixel_counts += per_class
        image_counts += per_class > 0
    return pixel_counts, image_counts


def calculate_class_weight_present(path):
    """Print and return median-frequency weights, counting only images where a class is present.

    The frequency of a class is its total pixel count divided by the number
    of images that contain it (edgeai-torchvision style). This matches the
    paper's definition up to a constant factor only if all label images have
    the same size -- NOTE(review): confirm for your dataset.

    Args:
        path: directory that contains a ``gt`` sub-directory of label images.

    Returns:
        Float64 array of weights aligned with ``valid_classes``; classes
        that never occur get weight 0. Weights are not normalized; divide by
        ``weights.sum()`` if they should sum to 1.
    """
    gtpath = os.path.join(path, 'gt')
    pixel_counts, image_counts = _count_class_pixels(gtpath)
    class_weights = calc_median_frequency(pixel_counts, image_counts)
    print(class_weights)
    return class_weights


def calculate_class_weight_all(path):
    """Print and return median-frequency weights computed over all images.

    The weight of a class is the median pixel count of the classes that
    occur, divided by the pixel count of that class.

    Args:
        path: directory that contains a ``gt`` sub-directory of label images.

    Returns:
        Float64 array of weights aligned with ``valid_classes``; classes
        that never occur get weight 0. Weights are not normalized; divide by
        ``weights.sum()`` if they should sum to 1.
    """
    gtpath = os.path.join(path, 'gt')
    pixel_counts, _ = _count_class_pixels(gtpath)
    # A shared normalizer cancels in median / freq, so 1 is used for every class.
    class_weights = calc_median_frequency(pixel_counts, np.ones_like(pixel_counts))
    print(class_weights)
    return class_weights


if __name__ == "__main__":
    path = os.path.dirname(os.path.realpath(__file__))
    # Alternative: calculate_class_weight_all(path)
    calculate_class_weight_present(path)
参考
1. Median Frequency Balancing 理解; 2. median frequency balancing-CSDN博客; 3. edgeai-torchvision/torchvision/edgeailite/xvision/datasets/calculate_class_weights.py at r8.1 · TexasInstruments/edgeai-torchvision; 4. 语义分割中的类别不平衡的权重计算_语义分割下载的权重-CSDN博客; 完 标签:valid,语义,class,np,classes,weights,类别,freq,CV From: https://www.cnblogs.com/happyamyhope/p/18414418