Code being run:
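import os
import timm
import torch

# Build deit_small_patch16_224 from a local weight file instead of downloading it
model = timm.create_model(
    'deit_small_patch16_224',
    pretrained=True,
    num_classes=6,  # custom head; timm skips the mismatched classifier weights
    pretrained_cfg_overlay=dict(file='/home/lingdu/zyt/works/pretrained_models/deit_small_patch16_224-cd65a155.pth'),
)
os.makedirs('timm_models', exist_ok=True)  # torch.save does not create directories
torch.save(model, 'timm_models/deit_small.pth')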
The goal is to use the timm library to create a deit_small_patch16_224 model from a local weight file.
Error message:
  File "/home/lingdu/zyt/works/PD_6/get_model.py", line 10, in <module>
    model = timm.create_model(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_factory.py", line 117, in create_model
    model = create_fn(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/deit.py", line 258, in deit_small_patch16_224
    model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/deit.py", line 123, in _create_deit
    model = build_model_with_cfg(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_builder.py", line 418, in build_model_with_cfg
    load_pretrained(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_builder.py", line 168, in load_pretrained
    state_dict = load_state_dict(pretrained_loc)
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_helpers.py", line 54, in load_state_dict
    checkpoint = torch.load(checkpoint_path, map_location=device)
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/torch/serialization.py", line 1025, in load
    return _load(opened_zipfile,
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/torch/serialization.py", line 1446, in _load
    result = unpickler.load()
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/torch/serialization.py", line 1439, in find_class
    return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'timm.models.layers.patch_embed'
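DeiT is a Transformer-family model, so it necessarily uses the patch_embed module; the bug here is essentially a module path error.
The timm source on GitHub shows that in newer versions of timm, patch_embed lives under timm.layers.patch_embed rather than timm.models.layers.patch_embed.
The error occurs because the timm version installed in the environment does not match the timm version the checkpoint was saved with, so the module path recorded inside the file (timm.models.layers.patch_embed) cannot be found when torch.load unpickles it.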
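To see which module paths a checkpoint depends on before loading it, one can scan its pickle for class references without importing anything. The sketch below is a hypothetical helper (list_pickle_globals is not part of timm or PyTorch). It assumes the zip-based format that torch.save has used by default since PyTorch 1.6, where the pickle is stored as data.pkl, and torch's default pickle protocol 2, where classes are referenced through GLOBAL opcodes. Legacy (pre-1.6) checkpoints are not handled, and no tensors are loaded.

```python
import pickletools
import sys
import zipfile
from typing import List


def list_pickle_globals(path: str) -> List[str]:
    """Return the unique "module name" strings referenced by GLOBAL opcodes.

    Hypothetical helper. Zip files are assumed to be torch.save archives
    containing a data.pkl entry; any other file is read as one plain pickle.
    """
    if zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as zf:
            # The top-level folder name inside the archive varies, so match the suffix.
            pkl_name = next(
                n for n in zf.namelist() if n == "data.pkl" or n.endswith("/data.pkl")
            )
            data = zf.read(pkl_name)
    else:
        with open(path, "rb") as f:
            data = f.read()

    found: List[str] = []
    try:
        # genops walks the opcodes of the first pickle and stops at STOP.
        for opcode, arg, _pos in pickletools.genops(data):
            # With protocol 2, a class reference is GLOBAL with arg "module name".
            if opcode.name == "GLOBAL" and arg not in found:
                found.append(arg)
    except ValueError:
        # Malformed or truncated pickle data: return what was collected so far.
        pass
    return found


if __name__ == "__main__":
    # Usage: python list_globals.py path/to/checkpoint.pth
    for ref in list_pickle_globals(sys.argv[1]):
        print(ref)
```

Run against the weight file from the traceback, one would expect an entry beginning with timm.models.layers.patch_embed, since that is the module the unpickler failed to import.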
Solution:
Uninstall the old timm:
pip uninstall timm
Then install the latest version:
pip install timm==1.0.8
I looked up the latest version number (at the time of writing) on PyPI:
https://pypi.org/project/timm/
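After reinstalling, a quick way to confirm what the environment now provides is to print the installed timm version and check whether each module path can be located. This is a minimal sketch using only the standard library (importlib.metadata needs Python 3.8+, matching the python3.8 environment in the traceback); installed_version and module_exists are hypothetical helper names. Note that find_spec imports parent packages, so checking a timm submodule will import timm itself.

```python
import importlib.util
from importlib import metadata
from typing import Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def module_exists(name: str) -> bool:
    """Return True if a dotted module path can be located.

    find_spec imports the parent packages first and raises ModuleNotFoundError
    when one of them is missing, so that case is reported as False.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False


if __name__ == "__main__":
    print("timm version:", installed_version("timm"))
    for mod in ("timm.layers.patch_embed", "timm.models.layers.patch_embed"):
        print(f"{mod}: {'found' if module_exists(mod) else 'missing'}")
```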