构建图像分类数据集
视频链接:https://www.bilibili.com/video/BV1Jd4y1T7rw/?vd_source=ec0dfe3d40081b44c0160eacc0f39d0f
脚本文件:https://github.com/TommyZihao/Train_Custom_Dataset/tree/main/图像分类
一、安装配置环境
之前已跑过相关深度学习流程,已配好
二、图像采集与分类
1、图像采集与分类
子豪兄用的是爬虫爬取的(采集的同时其实就已在做分类了),由于我缺乏相关爬虫经验,后续也没有类似需求,所以选择直接下载子豪兄提供好的81种水果数据集
采集图像的几个原则:要尽可能的涵盖这一类别的所有形态,以避免OOD(Out-Of-Distribution)问题
不同尺寸、比例的图像
不同拍摄环境(光照、设备、拍摄角度、遮挡、远近、大小)
不同形态(完整西瓜、切瓣西瓜、切块西瓜)
不同部位(全瓜、瓜皮、瓜瓤、瓜子)
不同时期(瓜秧、小瓜、大瓜)
不同背景(人物、菜地、抠图)
不同图像域(照片、漫画、剪贴画、油画)
2、删除多余文件
- 垃圾文件
在win、mac系统下,很可能会出现'__MACOSX'、'.DS_Store'、'.ipynb_checkpoints'这些多余文件,这些文件我们都要进行删除才方便进行深度学习。
!for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
- gif文件
dataset_path = 'dataset_delete_test'
for fruit in tqdm(os.listdir(dataset_path)):
for file in os.listdir(os.path.join(dataset_path, fruit)):
file_path = os.path.join(dataset_path, fruit, file)
img = cv2.imread(file_path)
if img is None:
print(file_path, '读取错误,删除')
os.remove(file_path)
- 非三通道图像
import numpy as np
from PIL import Image
for fruit in tqdm(os.listdir(dataset_path)):
for file in os.listdir(os.path.join(dataset_path, fruit)):
file_path = os.path.join(dataset_path, fruit, file)
img = np.array(Image.open(file_path))
try:
channel = img.shape[2]
if channel != 3:
print(file_path, '非三通道,删除')
os.remove(file_path)
except:
print(file_path, '非三通道,删除')
os.remove(file_path)
- 处理完后要查看是否又生成了多余的.ipynb_checkpoints目录
三、下载数据集并简单进行样本信息统计
1、下载数据集
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit81/fruit81_full.zip
2、统计图像尺寸、比例
- 图像尺寸
df = pd.DataFrame()
for fruit in tqdm(os.listdir()): # 遍历每个类别
os.chdir(fruit)
for file in os.listdir(): # 遍历每张图像
try:
img = cv2.imread(file)
df = df.append({'类别':fruit, '文件名':file, '图像宽':img.shape[1], '图像高':img.shape[0]}, ignore_index=True)
except:
print(os.path.join(fruit, file), '读取错误')
os.chdir('../')
os.chdir('../')
- 比例情况
from scipy.stats import gaussian_kde
from matplotlib.colors import LogNorm
x = df['图像宽']
y = df['图像高']
xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)
# Sort the points by density, so that the densest points are plotted last
idx = z.argsort()
x, y, z = x[idx], y[idx], z[idx]
plt.figure(figsize=(10,10))
# plt.figure(figsize=(12,12))
plt.scatter(x, y, c=z, s=5, cmap='Spectral_r')
# plt.colorbar()
# plt.xticks([])
# plt.yticks([])
plt.tick_params(labelsize=15)
xy_max = max(max(df['图像宽']), max(df['图像高']))
plt.xlim(xmin=0, xmax=xy_max)
plt.ylim(ymin=0, ymax=xy_max)
plt.ylabel('height', fontsize=25)
plt.xlabel('width', fontsize=25)
plt.savefig('图像尺寸分布.pdf', dpi=120, bbox_inches='tight')
plt.show()
3、拍摄地点可视化
未有相关需求,后续有需要再查看视频:https://www.bilibili.com/video/BV1m3411A786
四、划分训练集测试集
我下载的是已经划分好训练集和测试集的数据集,先前也有划分经验,作为简单回顾
1、创建文件夹:
类别较少的情况下可以手动创建,但是对于81个水果样本手动创建显然不理想,这里使用os.mkdir创建
# 创建 train 文件夹
os.mkdir(os.path.join(dataset_path, 'train'))
# 创建 test 文件夹
os.mkdir(os.path.join(dataset_path, 'val'))
# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:
os.mkdir(os.path.join(dataset_path, 'train', fruit))
os.mkdir(os.path.join(dataset_path, 'val', fruit))
2、将数据集进行划分
一般划分的比例都是2:8,设置的随机种子我习惯设置为42
test_frac = 0.2 # 测试集比例
random.seed(123) # 随机数种子,便于复现
df = pd.DataFrame()
print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))
for fruit in classes: # 遍历每个类别
# 读取该类别的所有图像文件名
old_dir = os.path.join(dataset_path, fruit)
images_filename = os.listdir(old_dir)
random.shuffle(images_filename) # 随机打乱
# 划分训练集和测试集
testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
testset_images = images_filename[:testset_numer] # 获取拟移动至 test 目录的测试集图像文件名
trainset_images = images_filename[testset_numer:] # 获取拟移动至 train 目录的训练集图像文件名
# 移动图像至 test 目录
for image in testset_images:
old_img_path = os.path.join(dataset_path, fruit, image) # 获取原始文件路径
new_test_path = os.path.join(dataset_path, 'val', fruit, image) # 获取 test 目录的新文件路径
shutil.move(old_img_path, new_test_path) # 移动文件
# 移动图像至 train 目录
for image in trainset_images:
old_img_path = os.path.join(dataset_path, fruit, image) # 获取原始文件路径
new_train_path = os.path.join(dataset_path, 'train', fruit, image) # 获取 train 目录的新文件路径
shutil.move(old_img_path, new_train_path) # 移动文件
# 删除旧文件夹
assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
shutil.rmtree(old_dir) # 删除文件夹
# 工整地输出每一类别的数据个数
print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
# 保存到表格中
df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)
# 重命名数据集文件夹
shutil.move(dataset_path, dataset_name+'_split')
# 数据集各类别数量统计表格
df['total'] = df['trainset'] + df['testset']
五、图像文件夹的可视化
1、图像文件夹中图像的显示
这一步就是便于简单查看图片是否符合我们最初的图像采集要求,比一张的查看效率更高
# 指定要可视化图像的文件夹
folder_path = 'fruit81_split/train/西瓜'
# 可视化图像的个数
N = 36
# n 行 n 列
n = math.floor(np.sqrt(N))
n
# 读取文件夹中的所有图像
images = []
for each_img in os.listdir(folder_path)[:N]:
img_path = os.path.join(folder_path, each_img)
img_bgr = cv2.imread(img_path)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
images.append(img_rgb)
len(images)
# 画图展示
fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(fig, 111, # 类似绘制子图 subplot(111)
nrows_ncols=(n, n), # 创建 n 行 m 列的 axes 网格
axes_pad=0.02, # 网格间距
share_all=True
)
# 遍历每张图像
for ax, im in zip(grid, images):
ax.imshow(im)
ax.axis('off')
plt.tight_layout()
plt.show()
2、图像数量的可视化
之前在进行数据集划分时候就已经知道了图像的分布情况,这一步其实就是将数据进行绘图展示
图像数量柱状图可视化
- 全部样本的
# 指定可视化的特征
feature = 'total'
## feature = 'trainset'
## feature = 'testset'
df = df.sort_values(by=feature, ascending=False)
plt.figure(figsize=(22, 7))
x = df['class']
y = df[feature]
plt.bar(x, y, facecolor='#1f77b4', edgecolor='k')
plt.xticks(rotation=90)
plt.tick_params(labelsize=15)
plt.xlabel('类别', fontsize=20)
plt.ylabel('图像数量', fontsize=20)
# plt.savefig('各类别图片数量.pdf', dpi=120, bbox_inches='tight')
plt.show()
- 分好train和test的柱状图
plt.figure(figsize=(22, 7))
x = df['class']
y1 = df['testset']
y2 = df['trainset']
width = 0.55 # 柱状图宽度
plt.xticks(rotation=90) # 横轴文字旋转
plt.bar(x, y1, width, label='测试集')
plt.bar(x, y2, width, label='训练集', bottom=y1)
plt.xlabel('类别', fontsize=20)
plt.ylabel('图像数量', fontsize=20)
plt.tick_params(labelsize=13) # 设置坐标文字大小
plt.legend(fontsize=16) # 图例
# 保存为高清的 pdf 文件
plt.savefig('各类别图像数量.pdf', dpi=120, bbox_inches='tight')
plt.show()
还有很多其他的展现形式,这部分内容很简单,用excel都能做,不过多了解先。
标签:task1,plt,df,分类,fruit,图像,path,os From: https://www.cnblogs.com/cauwj/p/17058433.html