4.7 跨模态配对实战:基于深度学习的图文匹配系统
本项目旨在构建一个多模态学习系统,专注于处理图像和文本数据的配对任务,主要基于CUHK-PEDES数据集。本项目实现了多种深度学习模型,包括LSTM、MobileNetV1和ResNet,以分别处理文本和图像特征的提取与融合。通过这些模型的结合,系统能够有效地理解和匹配图像与对应的文本描述,旨在提高图像检索和描述生成的精度。整体架构包括数据处理、模型训练和评估环节,充分利用深度学习技术提升多模态任务的性能。
实例4-30:基于深度学习的图文匹配系统(源码路径:codes\4\Image-Text-Matching)
在本系统中,通过深度学习模型有效地将图像和其对应的文本描述进行匹配。具体而言,本项目采用了多种神经网络架构,包括LSTM用于文本特征提取,以及MobileNetV1和ResNet用于图像特征提取。这些模型的输出经过处理后进行联合嵌入,最终实现图像与文本之间的高效对应。另外,还结合了损失函数的设计,特别是约束损失,确保具有相同标签的图像和文本在特征空间中更接近,从而提高了匹配的准确性。通过在CUHK-PEDES数据集上的训练与评估,项目展示了多模态学习在图文配对任务中的有效性与潜力。
1. 工具类
在本项目的“utils”目录中提供了一系列实用工具,用于数据处理、统计分析和可视化,帮助用户对数据集进行深入分析,计算图像和标题的数量,并可视化训练过程中的损失和准确率。这些工具为项目的后续数据分析和结果展示奠定了基础,提升了工作效率。
(1)文件directory.py提供了与文件和目录操作相关的功能,主要用于确保在进行数据读写时所需的目录存在,并且能够将数据保存为 JSON 格式。
import os
import json
def makedir(root):
if not os.path.exists(root):
os.makedirs(root)
def write_json(data, root):
with open(dir, 'w') as f:
json.dump(data, f)
def check_exists(root):
if os.path.exists(root):
return True
return False
对上述代码的具体说明如下所示:
- makedir(root):检查指定的目录是否存在,如果不存在,则创建该目录。
- write_json(data, root):将数据以 JSON 格式写入指定的文件。注意,此处的 dir 应该更改为 root,以确保函数能正确运行。
- check_exists(root):检查指定的路径是否存在,如果存在,则返回 True,否则返回 False。
(2)文件metric.py实现了“图像-文本”匹配任务所需的度量和损失函数,包括计算成对距离、独热编码、约束损失,以及交叉模态投影分类和匹配损失。此外,文件中还定义了用于计算模型性能的 Top-K 准确率的函数和管理平均值的工具类。通过这些功能,文件metric.py 支持模型的训练、评估和性能监控,帮助优化“图像-文本”的匹配效果。
- 类EMA实现了指数移动平均操作,用于更新和存储参数的平滑值。
class EMA():
def __init__(self, decay=0.999):
self.decay = decay
self.shadow = {}
def register(self, name, val):
self.shadow[name] = val.cpu().detach()
def get(self, name):
return self.shadow[name]
def update(self, name, x):
assert name in self.shadow
new_average = (1.0 - self.decay) * x.cpu().detach() + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
- 方法pairwise_distance(A, B)的功能是计算两个点集之间的成对距离,返回距离矩阵。
def pairwise_distance(A, B):
A_square = torch.sum(A * A, dim=1, keepdim=True)
B_square = torch.sum(B * B, dim=1, keepdim=True)
distance = A_square + B_square.t() - 2 * torch.matmul(A, B.t())
return distance
- 方法one_hot_coding(index, k)的功能是将索引转换为独热编码格式。
def constraints_old(features, labels):
distance = pairwise_distance(features, features)
labels_reshape = torch.reshape(labels, (features.shape[0], 1))
labels_dist = labels_reshape - labels_reshape.t()
labels_mask = (labels_dist == 0).float()
num = torch.sum(labels_mask) - features.shape[0]
if num == 0:
con_loss = 0.0
else:
con_loss = torch.sum(distance * labels_mask) / num
return con_loss
- 方法constraints_old 的功能是计算约束损失,该损失用于衡量特征之间的距离。通过计算特征的成对距离,并根据标签构建匹配掩码,进而求出匹配对的平均距离,以评估模型的特征学习效果。
def constraints_old(features, labels):
distance = pairwise_distance(features, features)
labels_reshape = torch.reshape(labels, (features.shape[0], 1))
labels_dist = labels_reshape - labels_reshape.t()
labels_mask = (labels_dist == 0).float()
num = torch.sum(labels_mask) - features.shape[0]
if num == 0:
con_loss = 0.0
else:
con_loss = torch.sum(distance * labels_mask) / num
return con_loss
- 方法constraints 的功能是改进的约束损失计算方法,与 constraints_old 类似,但采用了不同的方式计算每个类别的损失。方法constraints通过遍历标签中的唯一值,选择与每个类别相关的特征,并计算这些特征之间的成对距离,从而得到更精确的约束损失。
def constraints(features, labels):
labels = torch.reshape(labels, (labels.shape[0],1))
con_loss = AverageMeter()
index_dict = {k.item() for k in labels}
for index in index_dict:
labels_mask = (labels == index)
feas = torch.masked_select(features, labels_mask)
feas = feas.view(-1, features.shape[1])
distance = pairwise_distance(feas, feas)
num = feas.shape[0] * (feas.shape[0] - 1)
loss = torch.sum(distance) / num
con_loss.update(loss, n = num / 2)
return con_loss.avg
- 方法constraints_loss 的功能是计算整个数据集的约束损失,首先收集所有图像和文本的嵌入特征,并根据给定的标签计算图像和文本的约束损失,这为后续的模型训练和评估提供了约束损失值。
def constraints_loss(data_loader, network, args):
network.eval()
max_size = args.batch_size * len(data_loader)
images_bank = torch.zeros((max_size, args.feature_size)).cuda()
text_bank = torch.zeros((max_size,args.feature_size)).cuda()
labels_bank = torch.zeros(max_size).cuda()
index = 0
con_images = 0.0
con_text = 0.0
with torch.no_grad():
for images, captions, labels, captions_length in data_loader:
images = images.cuda()
captions = captions.cuda()
interval = images.shape[0]
image_embeddings, text_embeddings = network(images, captions, captions_length)
images_bank[index: index + interval] = image_embeddings
text_bank[index: index + interval] = text_embeddings
labels_bank[index: index + interval] = labels
index = index + interval
images_bank = images_bank[:index]
text_bank = text_bank[:index]
labels_bank = labels_bank[:index]
if args.constraints_text:
con_text = constraints(text_bank, labels_bank)
if args.constraints_images:
con_images = constraints(images_bank, labels_bank)
return con_images, con_text
- 类Loss的功能是定义模型的损失函数,它根据输入的参数初始化权重,并实现了交叉模态投影分类损失(CMPC)和交叉模态投影匹配损失(CMPM)的计算。这些损失函数用于训练和优化图像和文本嵌入的对齐。
class Loss(nn.Module):
def __init__(self, args):
super(Loss, self).__init__()
self.CMPM = args.CMPM
self.CMPC = args.CMPC
self.epsilon = args.epsilon
self.num_classes = args.num_classes
if args.resume:
checkpoint = torch.load(args.model_path)
self.W = Parameter(checkpoint['W'])
print('=========> Loading in parameter W from pretrained models')
else:
self.W = Parameter(torch.randn(args.feature_size, args.num_classes))
self.init_weight()
def init_weight(self):
nn.init.xavier_uniform_(self.W.data, gain=1)
- 方法compute_cmpc_loss 的功能是计算交叉模态投影分类损失(CMPC),该损失用于评估图像和文本嵌入的分类能力。它通过对图像和文本嵌入进行归一化和投影,计算交叉熵损失,以确保模型在图像和文本之间的相互映射。
def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels):
"""
criterion = nn.CrossEntropyLoss(reduction='mean')
self.W_norm = self.W / self.W.norm(dim=0)
#labels_onehot = one_hot_coding(labels, self.num_classes).float()
image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm
text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm
image_logits = torch.matmul(image_proj_text, self.W_norm)
text_logits = torch.matmul(text_proj_image, self.W_norm)
cmpc_loss = criterion(image_logits, labels) + criterion(text_logits, labels)
image_pred = torch.argmax(image_logits, dim=1)
text_pred = torch.argmax(text_logits, dim=1)
image_precision = torch.mean((image_pred == labels).float())
text_precision = torch.mean((text_pred == labels).float())
return cmpc_loss, image_precision, text_precision
- 方法compute_cmpm_loss 的功能是计算交叉模态投影匹配损失(CMPM),用于评估图像和文本嵌入的匹配能力。它通过计算正负样本对之间的相似性,并利用归一化标签掩码来优化嵌入的匹配性能。
def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels):
batch_size = image_embeddings.shape[0]
labels_reshape = torch.reshape(labels, (batch_size, 1))
labels_dist = labels_reshape - labels_reshape.t()
labels_mask = (labels_dist == 0)
image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
image_proj_text = torch.matmul(image_embeddings, text_norm.t())
text_proj_image = torch.matmul(text_embeddings, image_norm.t())
# normalize the true matching distribution
labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1)
i2t_pred = F.softmax(image_proj_text, dim=1)
#i2t_loss = i2t_pred * torch.log((i2t_pred + self.epsilon)/ (labels_mask_norm + self.epsilon))
i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + self.epsilon))
t2i_pred = F.softmax(text_proj_image, dim=1)
#t2i_loss = t2i_pred * torch.log((t2i_pred + self.epsilon)/ (labels_mask_norm + self.epsilon))
t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + self.epsilon))
cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))
sim_cos = torch.matmul(image_norm, text_norm.t())
pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask))
neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0))
return cmpm_loss, pos_avg_sim, neg_avg_sim
- 方法forward 的功能是执行前向传播,计算总体损失。根据输入的图像和文本嵌入,分别调用 CMPC 和 CMPM 的损失计算方法,并返回各类损失、精度以及正负样本的平均相似度。
def forward(self, image_embeddings, text_embeddings, labels):
cmpm_loss = 0.0
cmpc_loss = 0.0
image_precision = 0.0
text_precision = 0.0
neg_avg_sim = 0.0
pos_avg_sim =0.0
if self.CMPM:
cmpm_loss, pos_avg_sim, neg_avg_sim = self.compute_cmpm_loss(image_embeddings, text_embeddings, labels)
if self.CMPC:
cmpc_loss, image_precision, text_precision = self.compute_cmpc_loss(image_embeddings, text_embeddings, labels)
loss = cmpm_loss + cmpc_loss
return cmpm_loss, cmpc_loss, loss, image_precision, text_precision, pos_avg_sim, neg_avg_sim
- 类AverageMeter 的功能是用于计算和存储当前值和平均值的工具。它提供了更新和重置的方法,以便在训练过程中跟踪损失和精度的变化,方便模型性能监控。
class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += n * val
self.count += n
self.avg = self.sum / self.count
- 方法compute_topk 的功能是计算给定查询和图库的 Top-K 精度,它通过计算查询和图库之间的余弦相似度,并获取前 K 个相似项,以评估模型的检索性能。
def compute_topk(query, gallery, target_query, target_gallery, k=[1,10], reverse=False):
result = []
query = query / query.norm(dim=1,keepdim=True)
gallery = gallery / gallery.norm(dim=1,keepdim=True)
sim_cosine = torch.matmul(query, gallery.t())
result.extend(topk(sim_cosine, target_gallery, target_query, k=[1,10]))
if reverse:
result.extend(topk(sim_cosine, target_query, target_gallery, k=[1,10], dim=0))
return result
- 方法topk 的功能是实现 Top-K 精度计算的具体逻辑,它通过对相似度进行排序,找到正确标签在前 K 个预测中的数量,并计算其在总样本中的比例,最终返回各个 K 值对应的精度结果。
def topk(sim, target_gallery, target_query, k=[1,10], dim=1):
result = []
maxk = max(k)
size_total = len(target_gallery)
_, pred_index = sim.topk(maxk, dim, True, True)
pred_labels = target_gallery[pred_index]
if dim == 1:
pred_labels = pred_labels.t()
correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels))
for topk in k:
#correct_k = torch.sum(correct[:topk]).float()
correct_k = torch.sum(correct[:topk], dim=0)
correct_k = torch.sum(correct_k > 0).float()
result.append(correct_k * 100 / size_total)
return result
(3)文件statistics.py的功能是对数据集进行统计分析和可视化操作,具体实现代码如下所示。
def count_ids(root, flag=0):
ids_dict = {}
captions = 0
with open(root,'r') as f:
info = json.load(f)
for data in info:
label = data['id'] - flag
ids_dict[label] = ids_dict.get(label,0) + 1
captions += len(data['captions'])
return ids_dict, captions
def count_images(root):
info = pickle.load(open(root, 'rb'))['label_range']
images_dict = {}
for label in info:
num_images = len(info[label]) - 1
images_dict[num_images] = images_dict.get(num_images, 0) + 1
return images_dict
def count_captions(root):
info = pickle.load(open(root, 'rb'))['label_range']
captions_dict = {}
for label in info:
for index in range(0, len(info[label]) - 1):
num_captions = info[label][index] - info[label][index - 1]
captions_dict[num_captions] = captions_dict.get(num_captions, 0) + 1
return captions_dict
def visualize(data):
keys = list(data.keys())
keys.sort()
values = []
for key in keys:
values.append(data[key])
plt.figure('#captions in each image')
a = plt.bar(keys, values)
#plt.yticks([1,5,1,100,200,500,1000,5000])
plt.xticks(list(range(min(keys), max(keys) + 1, 1)))
autolabel(a)
plt.xlim(min(keys) - 1, max(keys) + 1)
plt.show()
def autolabel(rects):
for rect in rects:
height = rect.get_height()
plt.text(rect.get_x() + rect.get_width() / 2 - 0.2, height + 2, '%s' % int(height))
if __name__ == "__main__":
root = 'data/processed_data/train_sort.pkl'
data = count_images(root)
print(data)
visualize(data)
对上述代码的具体说明如下所示:
- 函数count_ids:用于统计每个唯一标识符的出现次数以及总字幕数量。
- 函数count_images:用于计算每个标签下的图像数量并返回图像数量的分布。
- 函数count_captions:用于统计每个标签下的字幕数量并记录其频率。
- 函数visualize:负责生成柱状图,展示输入数据的可视化结果。
- 函数autolabel:用于在柱状图上自动标记每个条形的高度。
- 在主程序中,文件加载了处理过的数据集,计算图像数量,并可视化结果。
(4)文件visualize.py的功能是可视化训练过程中的损失和准确率曲线,该文件的核心是函数 visualize_curve,该函数接受一个日志文件路径作为输入,读取训练日志并提取损失和准确率信息(包括图像到文本和文本到图像的 top-1 和 top-10 准确率)。然后,它生成两个图形:一个用于显示损失曲线,另一个用于显示准确率曲线。在准确率图中,分别绘制了图像到文本和文本到图像的 top-1 和 top-10 准确率。最后,结果图像保存为 train.jpg 并展示。主程序部分设置了日志文件路径并调用可视化函数。
import matplotlib.pyplot as plot
import os
import cv2
from matplotlib import pyplot as plt
def visualize_curve(log_root):
log_file = open(log_root, 'r')
result_root = log_root[:log_root.rfind('/') + 1] + 'train.jpg'
loss = []
top1_i2t = []
top10_i2t = []
top1_t2i = []
top10_t2i = []
for line in log_file.readlines():
line = line.strip().split()
if 'top10_t2i' not in line[-2]:
continue
loss.append(line[1])
top1_i2t.append(line[3])
top10_i2t.append(line[5])
top1_t2i.append(line[7])
top10_t2i.append(line[9])
log_file.close()
plt.figure('loss')
plt.plot(loss)
plt.figure('accuracy')
plt.subplot(211)
plt.plot(top1_i2t, label = 'top1')
plt.plot(top10_i2t, label = 'top10')
plt.legend(['image to text'], loc = 'upper right')
plt.subplot(212)
plt.plot(top1_t2i, label = 'top1')
plt.plot(top10_i2t, label = 'top10')
plt.legend(['text to image'], loc = 'upper right')
plt.savefig(result_root)
plt.show()
if __name__ == '__main__':
log_root = 'data/logs/train.log'
visualize_curve(log_root)
标签:loss,01,torch,文生,text,self,labels,实操,image
From: https://blog.csdn.net/asd343442/article/details/141396749