直接抛弃预训练模型的最后一层输出层并非最佳方案。更好的做法是按新的类别数裁剪输出层的权重，这样保留下来的那部分输出层参数仍然来自预训练模型，而不是从头随机初始化。下面以 mmdetection 加载预训练模型为例。
import torch
def faster_rcnn(num_classes, src_path="e:/14_model.pth", dst_path=None):
    """Shrink a Faster R-CNN checkpoint's bbox-head output layers to ``num_classes``.

    Instead of discarding the output layers of a pretrained mmdetection
    checkpoint, keep their leading rows so those weights stay pretrained:

    * keys containing ``bbox_head.fc_cls.`` -> first ``num_classes`` rows
    * keys containing ``bbox_head.fc_reg.`` -> first ``num_classes * 4`` rows

    Every other entry of the checkpoint is written back unchanged.

    Args:
        num_classes: Number of classification outputs to keep.
            NOTE(review): presumably this count includes the background class
            (mmdetection v1.x convention) — confirm against the target config.
        src_path: Checkpoint to read; must be a dict with a ``"state_dict"`` entry.
        dst_path: Output file. Defaults to
            ``"../checkpoints/faster_rcnn_r50_fpn_1x_<num_classes>.pth"``.

    Returns:
        The path the new checkpoint was written to.

    Raises:
        ValueError: if ``num_classes`` is not positive, or is larger than the
            number of rows in an ``fc_cls`` tensor.
    """
    if num_classes < 1:
        raise ValueError("num_classes must be positive, got %d" % num_classes)
    if dst_path is None:
        dst_path = "../checkpoints/faster_rcnn_r50_fpn_1x_%d.pth" % num_classes

    # map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only machine.
    model_coco = torch.load(src_path, map_location="cpu")
    state_dict = model_coco["state_dict"]

    # Iterate over a snapshot because entries are reassigned inside the loop.
    for key, tensor in list(state_dict.items()):
        if "bbox_head.fc_cls." in key:
            # Slicing beyond the row count would silently keep every row and
            # produce a head of the wrong size, so fail loudly instead.
            if tensor.shape[0] < num_classes:
                raise ValueError(
                    "num_classes=%d exceeds the %d outputs of %s"
                    % (num_classes, tensor.shape[0], key)
                )
            # clone(): a sliced view would serialize its whole original storage.
            state_dict[key] = tensor[:num_classes].clone()
        elif "bbox_head.fc_reg." in key:
            # 4 box deltas per class; a tensor with fewer rows is kept whole.
            state_dict[key] = tensor[: num_classes * 4].clone()

    torch.save(model_coco, dst_path)
    return dst_path
def cascade_rcnn(
    num_classes,
    src_path="../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-aaa877cc.pth",
    dst_path=None,
    num_stages=3,
):
    """Shrink the classification layers of a Cascade R-CNN checkpoint to ``num_classes``.

    For each cascade stage ``k`` in ``range(num_stages)``, the tensors
    ``bbox_head.<k>.fc_cls.weight`` and ``bbox_head.<k>.fc_cls.bias`` are cut
    to their first ``num_classes`` rows. All other entries (including the
    ``fc_reg`` layers) are written back unchanged.
    NOTE(review): leaving ``fc_reg`` alone presumably relies on the heads using
    class-agnostic regression — confirm against the target config.

    Args:
        num_classes: Number of classification outputs to keep per stage.
        src_path: Checkpoint to read; must be a dict with a ``"state_dict"`` entry.
        dst_path: Output file. Defaults to
            ``"../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_<num_classes>.pth"``.
        num_stages: Number of cascade bbox heads to trim (default 3).

    Returns:
        The path the new checkpoint was written to.

    Raises:
        KeyError: if an expected ``bbox_head.<k>.fc_cls.*`` key is missing.
        ValueError: if ``num_classes`` is not positive, or is larger than the
            number of rows of an ``fc_cls`` tensor.
    """
    if num_classes < 1:
        raise ValueError("num_classes must be positive, got %d" % num_classes)
    if dst_path is None:
        dst_path = (
            "../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_%d.pth" % num_classes
        )

    # map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only machine.
    model_coco = torch.load(src_path, map_location="cpu")
    state_dict = model_coco["state_dict"]

    for stage in range(num_stages):
        for param in ("weight", "bias"):
            key = "bbox_head.%d.fc_cls.%s" % (stage, param)
            tensor = state_dict[key]
            # Slicing beyond the row count would silently keep every row.
            if tensor.shape[0] < num_classes:
                raise ValueError(
                    "num_classes=%d exceeds the %d outputs of %s"
                    % (num_classes, tensor.shape[0], key)
                )
            # clone(): a sliced view would serialize its whole original storage.
            state_dict[key] = tensor[:num_classes].clone()

    torch.save(model_coco, dst_path)
    return dst_path
if __name__ == "__main__":
    # Guarded so importing this module does not rewrite checkpoints as a side effect.
    # To convert the Cascade R-CNN checkpoint instead, call cascade_rcnn(num_classes=11).
    faster_rcnn(4)