首页 > 其他分享 >深度学习武器库-timm-非常好用的pytorch CV模型库 - 常用模型操作

深度学习武器库-timm-非常好用的pytorch CV模型库 - 常用模型操作

时间:2024-08-11 21:40:11浏览次数:11  
标签:权重 nn 模型库 模型 pytorch model CV pretrained timm

简要介绍

timm库,全称pytorch-image-models,是最前沿的PyTorch图像模型、预训练权重和实用脚本的开源集合库,其中的模型可用于训练、推理和验证。

github源码链接
https://github.com/huggingface/pytorch-image-models

文档教程
文档:https://huggingface.co/docs/hub/timm
上手教程:https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055


优点

1、方便使用。在python环境中安装timm库,即可用几行代码创建网络模型,并可选择导入在imagenet等数据集上得到的预训练权重;无需再去扒每个模型的源代码,这对于跑模型对比实验是非常方便的,可以节省大量的时间;

2、灵活性高。导入模型的原始做法是,直接用.pth等权重文件导入,但这通常受到保存模型方法的限制,可能出现权重键名称不匹配、网络中间张量操作丢失(只能导入模型权重 却无法导入网络层之间的中间张量操作)等问题;而timm中的模型是对于整个模型的封装,在创建模型后,可以对相应网络层进行改动调整,非常灵活;

3、模型收录全面且前沿。在CV这个模型日新月异的领域,timm及时更新收录了最新的模型,比如,2024.8.8 添加了ECCV2024上的新模型RDNet(工作链接:https://github.com/naver-ai/rdnet)。
image

缺点

目前使用中感到一点不方便的是,使用函数请求下载这个库中的权重文件时,链接地址大部分是huggingface上的,而huggingface得用US的节点才能有较好的网速 。。。 我的PC可以访问,但服务器访问不了。。。 但这个缺点可以通过一些操作进行规避。

常用模型操作

  • 查看目前收录模型
    使用代码:
    timm.list_models('*')
    运行效果:
    image
    可以看到目前收录共有946个模型,查看目前已收录的模型,从这个列表中确定要导入的目标模型名称

    还可以通过正则表达式匹配目标模型名称,并通过指定pretrained=True筛选有预训练权重的模型
    如下匹配预训练resnet:timm.list_models('resnet*', pretrained=True)
    image

  • 创建模型
    使用代码(以resnet50为例):
    timm.create_model('resnet50', pretrained=True, in_chans=3, num_classes=6)
    这里的主要参数有四个:
    第一个是模型名称model_name
    第二个是是否预训练pretrained,
    第三个是输入图像的通道数in_chans
    第四个是分类类别数num_classes,指最后输出FC层的维度。

    注意:在创建模型这步中包含了从网络下载模型权重的操作,
    此时就会出现我在“缺点”部分讲到的问题:因为网络无法连接huggingface网站,而导致权重下载请求失败的情况 (多在服务器端出现)。
    image

    下面是解决方法
    先在能够连接huggingface网站的PC上,手动下载权重配置文件,使用代码如下:

    backbone_name = 'resnet50'
    
    pretrained_cfg = timm.create_model(backbone_name).default_cfg
    print(pretrained_cfg)
    

    运行后输出配置信息:

    {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', 'hf_hub_id': 'timm/resnet50.a1_in1k', 'architecture': 'resnet50', 'tag': 'a1_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'conv1', 'classifier': 'fc', 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476'}
    

    其中,url对应了模型的下载请求地址,直接将这个url复制粘贴到浏览器中,手动下载权重文件。
    获得权重文件后,再使用timm.create_model方法,通过将pretrained_cfg_overlay参数指定为权重文件,来创建模型,这样就是本地创建了:

    backbone_name = 'resnet50'
    ckpt_path = './ckpt/resnet50_a1_0-14fe96d1.pth'
    
    model = timm.create_model(backbone_name,
                                       pretrained=True,
                                       pretrained_cfg_overlay=dict(file=ckpt_path))
    
  • 手动调整模型
    在cv中,最常见的操作是将某个网络的主干层,用于特征提取
    timm中有专门的方法可以实现这个目的:
    feature_ouput = model.forward_features(image)
    feature_ouput即为网络在最后的head层之前输出的特征向量。
    但这个操作还无法完美解决问题,因为通常认为提取特征就是排除网络的最后一层,但有的网络最后一层中不仅包括全连接(FC)层,在FC层之前还包含池化层。这就需要更灵活的操作,来调整、构建我们想要的网络。
    下面是我的代码(以手动添加池化层为例):

    feature_extract_model = nn.Sequential(*list(model.children())[:-1],
                                               nn.AdaptiveAvgPool2d(1))
    

    另外,也可以自己定制最后的head层(但个人感觉这个用途不多),例如:

    model.fc = nn.Sequential(
        nn.BatchNorm1d(num_in_features),
        nn.Linear(in_features=num_in_features, out_features=512, bias=False),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.4),
        nn.Linear(in_features=512, out_features=10, bias=False))
    

本期总结

timm是一个非常全面且便捷的CV图像模型库,能够大大提升我们跑实验的效率。我们同样也能运用其中的部分模块类,用到自己的编码中,也可以在它的源码中学习模型的代码写法。本期笔记只是介绍了本人近期跑对比实验,使用后感觉最常用的一些方法和操作,timm还有很多功能和用法需要去探索,比如还有数据增强、数据集和优化器等等功能。

标签:权重,nn,模型库,模型,pytorch,model,CV,pretrained,timm
From: https://www.cnblogs.com/lingdu98/p/18353444

相关文章

  • 最优化 | 凸优化 | 二次规划cvxopt求解,如何确定系数?
    目录一、定义二、系数的确定三、例子四、代码一、定义在凸优化问题中,特别是在二次规划(QuadraticProgramming,QP)问题中,矩阵PPP通常用来定义目标函数中的二次项......
  • CVE-2019-12422~shiro反序列化【春秋云境靶场渗透】
    #今天我们来攻克CVE-2019-12422春秋云境这个靶场漏洞当我们知道了该靶场是shiro反序列化漏洞,所以直接用工具梭哈好小子,离成功又近一步!!!......
  • CVE-2023-38633~XXE注入【春秋云境靶场渗透】
    今天我们来攻克春秋云境CTF的CVE-2023-38633#我们通过抓包来构造POC来查找/etc/passwd<?xmlversion="1.0"encoding="UTF-8"standalone="no"?><svgwidth="1000"height="1000"xmlns:xi="http://www.w3.org/2001/XInclude&......
  • Pytorch入门:tensor张量的构建
    tensor数据结构是pytorch的基础与核心,本文主要介绍三种常用的tensor张量的构建方式。1.从已有其他数据转换为tensor数据常用方法有如下两种:torch.tensortorch.Tensor上述两种方法有细微的差别,具体通过示例来进行展示运行结果为 首先,torch.tensor会对转换前容器内元素......
  • pytorch深度学习实践(刘二大人)课后作业——Titanic数据集分析预测
    一、课后作业构造分类器对Titanic数据集进行预测1.数据集预处理(1)数据集下载与分析下载地址:https://www.kaggle.com/c/titanic/data导入必要的包,并查看训练集、测试集前五行数据importtorchimportnumpyasnpimportpandasaspdimportmatplotlib.pyplotaspltimp......
  • OpenCV的级联分类器训练
    使用增强级联的弱分类器包括两个主要阶段:训练和检测阶段。对象检测教程中有描述使用基于HAAR或LBP模型的检测阶段。这里主要介绍训练增强分类器级联所需的功能,包括:准备训练数据、执行实际模型训练、可视化训练。目录一、训练数据准备1、负样本2、正样本3、命令行参数......
  • OpenCV 膨胀与腐蚀
    目录膨胀腐蚀一:膨胀实现dilate二:实现腐蚀erode相关知识补充  (一)可以看做膨胀是将白色区域扩大,腐蚀是将黑色区域扩大  (二)可以不进行灰度处理,对彩色图片进行处理  (三)getStructuringElement方法  参数:  返回值:膨胀腐蚀一:膨胀实现dilateimportcv2......
  • OpenCV 开闭操作
    目录一:开操作(先腐蚀后膨胀)  特点:消除噪点,去除小的干扰块,而不影响原来的图像二:闭操作(先膨胀后腐蚀)  特点:可以填充闭合区域三:利用开操作完成的任务  (一)提取水平垂直线  原理:  (二)消除干扰线  (三)提取满足要求的形状一:开操作(先腐蚀后膨胀)特点:消除噪......
  • 利用OpenCvSharp进行图像相关操作
    前言程序设计过程,有时也需要对图像进行一些简单操作,C#没有现成的图像处理库,但有人对OpenCV进行了包装,我们可以很方便的使用OpenCvSharp对图像进行操作。当然了,这也需要使用的人员进行一些研究,但相对于C++版本,它已经非常友好了。1、显示图像代码:privatevoidbutton1_Click(......
  • 4.3.3 OpenCV 实现 高斯金字塔和拉普拉斯金字塔
    4.3.3OpenCV实现高斯金字塔和拉普拉斯金字塔参考教程:图像处理中的高斯金字塔和拉普拉斯金字塔_拉普拉斯金字塔插入偶数行,偶数列也是用卷积算法吗-CSDN博客1.安装OpenCV1.1下载OpenCV参考教程:无法定位软件包libjasper-dev的解决办法-CSDN博客视觉slam14讲ch5opencv......