首页 > 其他分享 >(4-7-01)文生图大模型实操:基于深度学习的图文匹配系统(1)工具类

(4-7-01)文生图大模型实操:基于深度学习的图文匹配系统(1)工具类

时间:2024-08-21 16:24:59浏览次数:14  
标签:loss 01 torch 文生 text self labels 实操 image

4.7  跨模态配对实战:基于深度学习的图文匹配系统

本项目旨在构建一个多模态学习系统,专注于处理图像和文本数据的配对任务,主要基于CUHK-PEDES数据集。本项目实现了多种深度学习模型,包括LSTM、MobileNetV1和ResNet,以分别处理文本和图像特征的提取与融合。通过这些模型的结合,系统能够有效地理解和匹配图像与对应的文本描述,旨在提高图像检索和描述生成的精度。整体架构包括数据处理、模型训练和评估环节,充分利用深度学习技术提升多模态任务的性能。

实例4-30基于深度学习的图文匹配系统(源码路径:codes\4\Image-Text-Matching

在本系统中,通过深度学习模型有效地将图像和其对应的文本描述进行匹配。具体而言,本项目采用了多种神经网络架构,包括LSTM用于文本特征提取,以及MobileNetV1和ResNet用于图像特征提取。这些模型的输出经过处理后进行联合嵌入,最终实现图像与文本之间的高效对应。另外,还结合了损失函数的设计,特别是约束损失,确保具有相同标签的图像和文本在特征空间中更接近,从而提高了匹配的准确性。通过在CUHK-PEDES数据集上的训练与评估,项目展示了多模态学习在图文配对任务中的有效性与潜力。

1. 工具类

在本项目的“utils”目录中提供了一系列实用工具,用于数据处理、统计分析和可视化,帮助用户对数据集进行深入分析,计算图像和标题的数量,并可视化训练过程中的损失和准确率。这些工具为项目的后续数据分析和结果展示奠定了基础,提升了工作效率。

(1)文件directory.py提供了与文件和目录操作相关的功能,主要用于确保在进行数据读写时所需的目录存在,并且能够将数据保存为 JSON 格式。

import os
import json

def makedir(root):
    if not os.path.exists(root):
        os.makedirs(root)


def write_json(data, root):
    with open(dir, 'w') as f:
        json.dump(data, f)

def check_exists(root):
    if os.path.exists(root):
        return True
    return False

对上述代码的具体说明如下所示:

  1. makedir(root):检查指定的目录是否存在,如果不存在,则创建该目录。
  2. write_json(data, root):将数据以 JSON 格式写入指定的文件。注意,此处的 dir 应该更改为 root,以确保函数能正确运行。
  3. check_exists(root):检查指定的路径是否存在,如果存在,则返回 True,否则返回 False。

(2)文件metric.py实现了“图像-文本”匹配任务所需的度量和损失函数,包括计算成对距离、独热编码、约束损失,以及交叉模态投影分类和匹配损失。此外,文件中还定义了用于计算模型性能的 Top-K 准确率的函数和管理平均值的工具类。通过这些功能,文件metric.py 支持模型的训练、评估和性能监控,帮助优化“图像-文本”的匹配效果。

  1. 类EMA实现了指数移动平均操作,用于更新和存储参数的平滑值。
class EMA():
    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.cpu().detach()

    def get(self, name):
        return self.shadow[name]

    def update(self, name, x):
        assert name in self.shadow
        new_average = (1.0 - self.decay) * x.cpu().detach() + self.decay * self.shadow[name]
        self.shadow[name] = new_average.clone()
  1. 方法pairwise_distance(A, B)的功能是计算两个点集之间的成对距离,返回距离矩阵。
def pairwise_distance(A, B):
    A_square = torch.sum(A * A, dim=1, keepdim=True)
    B_square = torch.sum(B * B, dim=1, keepdim=True)

    distance = A_square + B_square.t() - 2 * torch.matmul(A, B.t())

    return distance
  1. 方法one_hot_coding(index, k)的功能是将索引转换为独热编码格式。
def constraints_old(features, labels):
    distance = pairwise_distance(features, features)
    labels_reshape = torch.reshape(labels, (features.shape[0], 1))
    labels_dist = labels_reshape - labels_reshape.t()
    labels_mask = (labels_dist == 0).float()

    num = torch.sum(labels_mask) - features.shape[0]
    if num == 0:
        con_loss = 0.0
    else:
        con_loss = torch.sum(distance * labels_mask) / num

    return con_loss
  1. 方法constraints_old 的功能是计算约束损失,该损失用于衡量特征之间的距离。通过计算特征的成对距离,并根据标签构建匹配掩码,进而求出匹配对的平均距离,以评估模型的特征学习效果。
def constraints_old(features, labels):
    distance = pairwise_distance(features, features)
    labels_reshape = torch.reshape(labels, (features.shape[0], 1))
    labels_dist = labels_reshape - labels_reshape.t()
    labels_mask = (labels_dist == 0).float()

    num = torch.sum(labels_mask) - features.shape[0]
    if num == 0:
        con_loss = 0.0
    else:
        con_loss = torch.sum(distance * labels_mask) / num

    return con_loss
  1. 方法constraints 的功能是改进的约束损失计算方法,与 constraints_old 类似,但采用了不同的方式计算每个类别的损失。方法constraints通过遍历标签中的唯一值,选择与每个类别相关的特征,并计算这些特征之间的成对距离,从而得到更精确的约束损失。
def constraints(features, labels):
    labels = torch.reshape(labels, (labels.shape[0],1))
    con_loss = AverageMeter()
    index_dict = {k.item() for k in labels}
    for index in index_dict:
        labels_mask = (labels == index)
        feas = torch.masked_select(features, labels_mask)
        feas = feas.view(-1, features.shape[1])
        distance = pairwise_distance(feas, feas)
        num = feas.shape[0] * (feas.shape[0] - 1)
        loss = torch.sum(distance) / num
        con_loss.update(loss, n = num / 2)
    return con_loss.avg
  1. 方法constraints_loss 的功能是计算整个数据集的约束损失,首先收集所有图像和文本的嵌入特征,并根据给定的标签计算图像和文本的约束损失,这为后续的模型训练和评估提供了约束损失值。
def constraints_loss(data_loader, network, args):
    network.eval()
    max_size = args.batch_size * len(data_loader)
    images_bank = torch.zeros((max_size, args.feature_size)).cuda()
    text_bank = torch.zeros((max_size,args.feature_size)).cuda()
    labels_bank = torch.zeros(max_size).cuda()
    index = 0
    con_images = 0.0
    con_text = 0.0
    with torch.no_grad():
        for images, captions, labels, captions_length in data_loader:
            images = images.cuda()
            captions = captions.cuda()
            interval = images.shape[0]
            image_embeddings, text_embeddings = network(images, captions, captions_length)
            images_bank[index: index + interval] = image_embeddings
            text_bank[index: index + interval] = text_embeddings
            labels_bank[index: index + interval] = labels
            index = index + interval
        images_bank = images_bank[:index]
        text_bank = text_bank[:index]
        labels_bank = labels_bank[:index]
    
    if args.constraints_text:
        con_text = constraints(text_bank, labels_bank)
    if args.constraints_images:
        con_images = constraints(images_bank, labels_bank)

    return con_images, con_text
  1. 类Loss的功能是定义模型的损失函数,它根据输入的参数初始化权重,并实现了交叉模态投影分类损失(CMPC)和交叉模态投影匹配损失(CMPM)的计算。这些损失函数用于训练和优化图像和文本嵌入的对齐。   
class Loss(nn.Module):
    def __init__(self, args):
        super(Loss, self).__init__()
        self.CMPM = args.CMPM
        self.CMPC = args.CMPC
        self.epsilon = args.epsilon
        self.num_classes = args.num_classes
        if args.resume:
            checkpoint = torch.load(args.model_path)
            self.W = Parameter(checkpoint['W'])
            print('=========> Loading in parameter W from pretrained models')
        else:
            self.W = Parameter(torch.randn(args.feature_size, args.num_classes))
            self.init_weight()

    def init_weight(self):
        nn.init.xavier_uniform_(self.W.data, gain=1)
  1. 方法compute_cmpc_loss 的功能是计算交叉模态投影分类损失(CMPC),该损失用于评估图像和文本嵌入的分类能力。它通过对图像和文本嵌入进行归一化和投影,计算交叉熵损失,以确保模型在图像和文本之间的相互映射。
    def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels):
        """
        criterion = nn.CrossEntropyLoss(reduction='mean')
        self.W_norm = self.W / self.W.norm(dim=0)
        #labels_onehot = one_hot_coding(labels, self.num_classes).float()
        image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

        image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm
        text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm

        image_logits = torch.matmul(image_proj_text, self.W_norm)
        text_logits = torch.matmul(text_proj_image, self.W_norm)
        
        cmpc_loss = criterion(image_logits, labels) + criterion(text_logits, labels)
        image_pred = torch.argmax(image_logits, dim=1)
        text_pred = torch.argmax(text_logits, dim=1)

        image_precision = torch.mean((image_pred == labels).float())
        text_precision = torch.mean((text_pred == labels).float())

        return cmpc_loss, image_precision, text_precision
  1. 方法compute_cmpm_loss 的功能是计算交叉模态投影匹配损失(CMPM),用于评估图像和文本嵌入的匹配能力。它通过计算正负样本对之间的相似性,并利用归一化标签掩码来优化嵌入的匹配性能。
    def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels):

        batch_size = image_embeddings.shape[0]
        labels_reshape = torch.reshape(labels, (batch_size, 1))
        labels_dist = labels_reshape - labels_reshape.t()
        labels_mask = (labels_dist == 0)
        
        image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
        image_proj_text = torch.matmul(image_embeddings, text_norm.t())
        text_proj_image = torch.matmul(text_embeddings, image_norm.t())

        # normalize the true matching distribution
        labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1)
         
        i2t_pred = F.softmax(image_proj_text, dim=1)
        #i2t_loss = i2t_pred * torch.log((i2t_pred + self.epsilon)/ (labels_mask_norm + self.epsilon))
        i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + self.epsilon))
        
        t2i_pred = F.softmax(text_proj_image, dim=1)
        #t2i_loss = t2i_pred * torch.log((t2i_pred + self.epsilon)/ (labels_mask_norm + self.epsilon))
        t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + self.epsilon))

        cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))

        sim_cos = torch.matmul(image_norm, text_norm.t())

        pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask))
        neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0))
        
        return cmpm_loss, pos_avg_sim, neg_avg_sim
  1. 方法forward 的功能是执行前向传播,计算总体损失。根据输入的图像和文本嵌入,分别调用 CMPC 和 CMPM 的损失计算方法,并返回各类损失、精度以及正负样本的平均相似度。
    def forward(self, image_embeddings, text_embeddings, labels):
        cmpm_loss = 0.0
        cmpc_loss = 0.0
        image_precision = 0.0
        text_precision = 0.0
        neg_avg_sim = 0.0
        pos_avg_sim =0.0
        if self.CMPM:
            cmpm_loss, pos_avg_sim, neg_avg_sim = self.compute_cmpm_loss(image_embeddings, text_embeddings, labels)
        if self.CMPC:
            cmpc_loss, image_precision, text_precision = self.compute_cmpc_loss(image_embeddings, text_embeddings, labels)
        
        loss = cmpm_loss + cmpc_loss
        
        return cmpm_loss, cmpc_loss, loss, image_precision, text_precision, pos_avg_sim, neg_avg_sim
  1. 类AverageMeter 的功能是用于计算和存储当前值和平均值的工具。它提供了更新和重置的方法,以便在训练过程中跟踪损失和精度的变化,方便模型性能监控。
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count
  1. 方法compute_topk 的功能是计算给定查询和图库的 Top-K 精度,它通过计算查询和图库之间的余弦相似度,并获取前 K 个相似项,以评估模型的检索性能。
def compute_topk(query, gallery, target_query, target_gallery, k=[1,10], reverse=False):
    result = []
    query = query / query.norm(dim=1,keepdim=True)
    gallery = gallery / gallery.norm(dim=1,keepdim=True)
    sim_cosine = torch.matmul(query, gallery.t())
    result.extend(topk(sim_cosine, target_gallery, target_query, k=[1,10]))
    if reverse:
        result.extend(topk(sim_cosine, target_query, target_gallery, k=[1,10], dim=0))
    return result
  1. 方法topk 的功能是实现 Top-K 精度计算的具体逻辑,它通过对相似度进行排序,找到正确标签在前 K 个预测中的数量,并计算其在总样本中的比例,最终返回各个 K 值对应的精度结果。
def topk(sim, target_gallery, target_query, k=[1,10], dim=1):
    result = []
    maxk = max(k)
    size_total = len(target_gallery)
    _, pred_index = sim.topk(maxk, dim, True, True)
    pred_labels = target_gallery[pred_index]
    if dim == 1:
        pred_labels = pred_labels.t()
    correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels))

    for topk in k:
        #correct_k = torch.sum(correct[:topk]).float()
        correct_k = torch.sum(correct[:topk], dim=0)
        correct_k = torch.sum(correct_k > 0).float()
        result.append(correct_k * 100 / size_total)
    return result

(3)文件statistics.py的功能是对数据集进行统计分析和可视化操作,具体实现代码如下所示。

def count_ids(root, flag=0):
    ids_dict = {}
    captions = 0
    with open(root,'r') as f:
        info = json.load(f)
        for data in info:
            label = data['id'] - flag
            ids_dict[label] = ids_dict.get(label,0) + 1
            captions += len(data['captions'])
    return ids_dict, captions


def count_images(root):
    info = pickle.load(open(root, 'rb'))['label_range']
    images_dict = {}
    for label in info:
        num_images = len(info[label]) - 1
        images_dict[num_images] = images_dict.get(num_images, 0) + 1
    return images_dict

def count_captions(root):
    info = pickle.load(open(root, 'rb'))['label_range']
    captions_dict = {}
    for label in info:
        for index in range(0, len(info[label]) - 1):
            num_captions = info[label][index] - info[label][index - 1]
            captions_dict[num_captions] = captions_dict.get(num_captions, 0) + 1
    return captions_dict

def visualize(data):
    keys = list(data.keys())
    keys.sort()
    values = []
    for key in keys:
        values.append(data[key])
    plt.figure('#captions in each image')
    a = plt.bar(keys, values)
    #plt.yticks([1,5,1,100,200,500,1000,5000])
    plt.xticks(list(range(min(keys), max(keys) + 1, 1)))
    autolabel(a)
    plt.xlim(min(keys) - 1, max(keys) + 1)
    plt.show()

def autolabel(rects):
    for rect in rects:
        height = rect.get_height()
        plt.text(rect.get_x() + rect.get_width() / 2 - 0.2, height + 2, '%s' % int(height))

if __name__ == "__main__":
    root = 'data/processed_data/train_sort.pkl'
    data = count_images(root)
    print(data)
    visualize(data)

对上述代码的具体说明如下所示:

  1. 函数count_ids:用于统计每个唯一标识符的出现次数以及总字幕数量。
  2. 函数count_images:用于计算每个标签下的图像数量并返回图像数量的分布。
  3. 函数count_captions:用于统计每个标签下的字幕数量并记录其频率。
  4. 函数visualize:负责生成柱状图,展示输入数据的可视化结果。
  5. 函数autolabel:用于在柱状图上自动标记每个条形的高度。
  6. 在主程序中,文件加载了处理过的数据集,计算图像数量,并可视化结果。

(4)文件visualize.py的功能是可视化训练过程中的损失和准确率曲线,该文件的核心是函数 visualize_curve,该函数接受一个日志文件路径作为输入,读取训练日志并提取损失和准确率信息(包括图像到文本和文本到图像的 top-1 和 top-10 准确率)。然后,它生成两个图形:一个用于显示损失曲线,另一个用于显示准确率曲线。在准确率图中,分别绘制了图像到文本和文本到图像的 top-1 和 top-10 准确率。最后,结果图像保存为 train.jpg 并展示。主程序部分设置了日志文件路径并调用可视化函数。

import matplotlib.pyplot as plot
import os
import cv2
from matplotlib import pyplot as plt

def visualize_curve(log_root):
    log_file = open(log_root, 'r')
    result_root = log_root[:log_root.rfind('/') + 1] + 'train.jpg'
    loss = []
    
    top1_i2t = []
    top10_i2t = []
    top1_t2i = []
    top10_t2i = []
    for line in log_file.readlines():
        line = line.strip().split()
        
        if 'top10_t2i' not in line[-2]:
            continue
        
        loss.append(line[1])
        top1_i2t.append(line[3])
        top10_i2t.append(line[5])
        top1_t2i.append(line[7])
        top10_t2i.append(line[9])

    log_file.close()
    plt.figure('loss')
    plt.plot(loss)
    plt.figure('accuracy')
    plt.subplot(211)
    plt.plot(top1_i2t, label = 'top1')
    plt.plot(top10_i2t, label = 'top10')
    plt.legend(['image to text'], loc = 'upper right')
    plt.subplot(212)
    plt.plot(top1_t2i, label = 'top1')
    plt.plot(top10_i2t, label = 'top10')
    plt.legend(['text to image'], loc = 'upper right')
    plt.savefig(result_root)
    plt.show()

if __name__ == '__main__':
    log_root = 'data/logs/train.log'
    visualize_curve(log_root)

标签:loss,01,torch,文生,text,self,labels,实操,image
From: https://blog.csdn.net/asd343442/article/details/141396749

相关文章

  • postman实操
    一、postman参数化1、{{变量名}}花括号,时两个括号在环境变量中设置变量参数,作用于所有的接口设置变量:当前所有的接口都可以调用这个参数get中设置变量:http://cms.duoceshi.cn/manage/loginJump.do?userAccount={{u1}}&loginPwd={{p1}}二、断言test中的内容详解:常用......
  • Docker快速入门 01 安装、部署环境
    1.简介和安装1.1简介Docker是一个应用打包、分发、部署的工具。打包:需要的环境变成一个“安装包”。分发:将“安装包”上传到云端,供他人获取。部署:将“安装包”下载下来后直接快速搭建运行环境。通俗讲就是轻量级的虚拟机,只虚拟需要的运行环境。1.2安装这里以Docker......
  • SBT30100VFCT-ASEMI无人机专用SBT30100VFCT
    编辑:llSBT30100VFCT-ASEMI无人机专用SBT30100VFCT型号:SBT30100VFCT品牌:ASEMI封装:TO-220F批号:最新最大平均正向电流(IF):30A最大循环峰值反向电压(VRRM):100V最大正向电压(VF):0.70V~0..90V工作温度:-65°C~175°C反向恢复时间:35ns芯片个数:2芯片尺寸:74mil引脚数量:3正向浪涌电流......
  • 多功能便携工具!VH501TC多类型传感器读数仪,助你完成频率、温度、电压和电流测量!
    多功能便携工具!VH501TC多类型传感器读数仪,助你完成频率、温度、电压和电流测量!VH501TC是一款专用的多类型传感器手持式读数仪,主要用于测量单弦式振弦传感器的读数,同时也可以辅助测量电压和电流传感器的数据。该设备内置了LoRA无线技术,可以与我公司的NLM系列产品配合使用,实现传感......
  • COAWST V3.8初学记录002(第二部分001:手册算例运行篇--单独运行ROMS和单独运行SWAN)
    COAWSTV3.8初学记录我是一个完完全全的海洋数值模式初学者,此前没有接触过任何海洋数值模式,在学习COAWST模式的过程中非常难受(起码从安装到算例的运行,是完完全全一个人独立学习完成,此前有求助过一些师兄和老师,但是他们也是爱莫能助,主要是距离太远,我这边的情况他们也不甚了......
  • COAWST V3.8初学记录001(第一部分:安装篇)
    COAWSTV3.8初学记录我是一个完完全全的海洋数值模式初学者,此前没有接触过任何海洋数值模式,在学习COAWST模式的过程中非常难受(起码从安装到算例的运行,是完完全全一个人独立学习完成,此前有求助过一些师兄和老师,但是他们也是爱莫能助,主要是距离太远,我这边的情况他们也不甚了......
  • MBR30100CT-ASEMI低压降肖特基MBR30100CT
    编辑:llMBR30100CT-ASEMI低压降肖特基MBR30100CT型号:MBR30100CT品牌:ASEMI封装:TO-220批号:最新恢复时间:35ns最大平均正向电流(IF):30A最大循环峰值反向电压(VRRM):100V最大正向电压(VF):0.70V~0.90V工作温度:-65°C~175°C芯片个数:2芯片尺寸:mil正向浪涌电流(IFMS):250AMBR30100CT特......
  • [题解]P3311 [SDOI2014] 数数
    P3311[SDOI2014]数数看到多模式匹配,我们考虑先对所有模式串建立AC自动机。然后发现这道题和P4052文本生成器(题解)挺像的,后者让求包含至少一个模式串的个数,这道题让求一个也不包含的个数,这个就是一个用不用\(26^m\)去减的问题,很好处理。但这道题还多了一个条件,“幸运数”必须\(......
  • 编写类A01,定义方法max,实现求某个double数组的最大值,并返回
    1publicclassHomework01{23//编写一个main方法4publicstaticvoidmain(String[]args){5A01a01=newA01();6double[]arr={1,1.4,-1.3,89.8,123.8,66};//;{};7Doubleres=a01.max(arr);8if......
  • CF2001C Guess The Tree
    欢迎前往我的博客获得也许更好的阅读体验!题意简述这是一个交互式问题。Misuki选择了一棵有\(n\)个节点的秘密树,节点编号为\(1\)到\(n\),并要求你通过以下类型的查询来猜出这棵树:“?ab”—Misuki会告诉你哪个节点\(x\)最小化了\(|d(a,x)-d(b,x)|\),其中\(d(x,......