首页 > 其他分享 >修改权重使用预训练模型权重

修改权重使用预训练模型权重

时间:2022-10-29 16:36:32浏览次数:65  
标签:权重 模型 修改 state fc dict bbox coco model


在新数据集上微调时,直接抛弃预训练模型最后的输出层并非最佳方案。可以改为修改(截取)输出层的权重,使其输出维度与新的类别数一致。下面以 mmdetection 使用预训练模型为例。

import torch

def faster_rcnn(num_classes, src="e:/14_model.pth", dst=None):
    """Truncate a Faster R-CNN checkpoint's bbox head to ``num_classes`` outputs.

    Loads the checkpoint at ``src``, then for every entry of its
    ``"state_dict"`` whose key contains ``bbox_head.fc_cls`` keeps the first
    ``num_classes`` rows, and for every entry whose key contains
    ``bbox_head.fc_reg`` keeps the first ``num_classes * 4`` rows. All other
    entries are left as loaded. The resulting checkpoint is written to ``dst``.

    Args:
        num_classes: Number of classification rows to keep; must be >= 1.
            NOTE(review): whether this count has to include a background
            class depends on the mmdetection version -- confirm against the
            config the checkpoint will be used with.
        src: Path of the checkpoint to read.
        dst: Path of the checkpoint to write. When ``None``, defaults to
            ``"../checkpoints/faster_rcnn_r50_fpn_1x_<num_classes>.pth"``.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """
    if num_classes < 1:
        raise ValueError("num_classes must be >= 1, got %r" % (num_classes,))
    if dst is None:
        dst = "../checkpoints/faster_rcnn_r50_fpn_1x_%d.pth" % num_classes

    # map_location="cpu" lets the conversion run on machines without the
    # device the checkpoint was originally saved from.
    checkpoint = torch.load(src, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    for key in state_dict:
        if "bbox_head.fc_cls" in key:
            keep = num_classes
        elif "bbox_head.fc_reg" in key:
            keep = num_classes * 4
        else:
            continue
        # A slice is a view on the full tensor and torch.save serialises the
        # whole underlying storage of a view; clone() so that only the kept
        # rows end up in the new file.
        state_dict[key] = state_dict[key][:keep].clone()

    # save new model
    torch.save(checkpoint, dst)

def cascade_rcnn(
    num_classes,
    src="../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-aaa877cc.pth",
    dst=None,
):
    """Truncate a Cascade R-CNN checkpoint's classifiers to ``num_classes`` outputs.

    Loads the checkpoint at ``src`` and, for each of the three heads
    ``bbox_head.0`` .. ``bbox_head.2``, keeps the first ``num_classes`` rows of
    ``fc_cls.weight`` and ``fc_cls.bias`` in its ``"state_dict"``. No other
    entry is modified. The resulting checkpoint is written to ``dst``.

    NOTE(review): the ``fc_reg`` entries are intentionally left untouched, as
    in the original script -- presumably the regression heads of this
    checkpoint do not depend on the class count; confirm against the config.

    Args:
        num_classes: Number of classification rows to keep; must be >= 1.
        src: Path of the checkpoint to read.
        dst: Path of the checkpoint to write. When ``None``, defaults to
            ``"../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_<num_classes>.pth"``.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
        KeyError: If the checkpoint lacks one of the expected ``fc_cls`` keys.
    """
    if num_classes < 1:
        raise ValueError("num_classes must be >= 1, got %r" % (num_classes,))
    if dst is None:
        dst = "../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_%d.pth" % num_classes

    # map_location="cpu" lets the conversion run on machines without the
    # device the checkpoint was originally saved from.
    checkpoint = torch.load(src, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    for stage in range(3):
        for param in ("weight", "bias"):
            key = "bbox_head.%d.fc_cls.%s" % (stage, param)
            # A slice is a view on the full tensor and torch.save serialises
            # the whole underlying storage of a view; clone() so that only
            # the kept rows end up in the new file.
            state_dict[key] = state_dict[key][:num_classes].clone()

    # save new model
    torch.save(checkpoint, dst)

# Convert the Faster R-CNN checkpoint so its bbox head has 4 class outputs.
# For the Cascade R-CNN checkpoint, call cascade_rcnn(num_classes=11) instead.
faster_rcnn(num_classes=4)


标签:权重,模型,修改,state,fc,dict,bbox,coco,model
From: https://blog.51cto.com/u_15847885/5806239

相关文章