import torch import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile # 1、导入模型 # 1.1导入预训练好的模型 # model = torchvision.models.mobilenet_v3_small(pretrained=True) # 1.1 注意 如果你的模型是自己定义的,应当像这样进行模型加载 model = MyModel() model.load_state_dict(torch.load('semantic_human_matting.pth')) # 2、模型进入验证模式(部署时直接推理) 不会进行反向传播和梯度计算 model.eval() # 3、定义示例输入 example = torch.rand(1, 3, 320, 320) # 4、记录示例输入在模型的所有张量上执行的操作 即得到所有的前向传播的操作 traced_script_module = torch.jit.trace(model, example) # 5、针对移动端进行优化 optimize_traced_model = optimize_for_mobile(traced_script_module) # 6、将模型保存到本地 optimize_traced_model._save_for_lite_interpreter("semantic_human_matting.pt")
标签:traced,部署,模型,torch,import,model,移动,optimize From: https://www.cnblogs.com/loveDodream-zzt/p/18109836