1 简介
原先的 tools/paddle/infer_paddle_model_shape.py 脚本基于 PaddlePaddle 2.5 编写,这里将其中用到的 Paddle 相关 API 升级到 2.6.0。
2 实现过程
在修改推理模型输入 shape 这一功能上,Paddle 2.6 与 Paddle 2.5 的差别主要有两点:读取/保存模型的函数调用方式不同,以及这些函数所在的模块位置不同(由 fluid.io 变为 static.io)。
2.1 修改读取函数
原读取模型函数如下:
[prog, ipts, outs] = fluid.io.load_inference_model(
args.model_dir,
exe,
model_filename=args.model_filename,
params_filename=args.params_filename)
需要修改为
[program, feed_target_names, fetch_targets] = static.io.load_inference_model(args.model_path,
exe)
2.2 修改保存函数
原保存函数如下
fluid.io.save_inference_model(
args.save_dir,
ipts,
outs,
exe,
prog,
model_filename=args.model_filename,
params_filename=args.params_filename)
需要修改为:
feed_vars = [program.global_block().var(name) for name in feed_target_names]
static.io.save_inference_model(
args.save_path,
feed_vars=feed_vars,
fetch_vars=fetch_targets,
executor=exe,
program=program)
3 脚本预览
import argparse
import paddle
import paddle.base as base
import paddle.static as static
def process_old_ops_desc(program):
    """Backfill the ``head_number`` attribute on ``matmul`` ops.

    Walks the ops of the program's first block. Every op whose type is
    ``"matmul"`` and that does not already carry a ``head_number``
    attribute gets that attribute set to ``1``. The program is modified
    in place and nothing is returned.

    Args:
        program: Object exposing ``blocks[0].ops``; each op must provide
            ``type``, ``has_attr(name)`` and ``_set_attr(name, value)``.
    """
    for op in program.blocks[0].ops:
        # Leave non-matmul ops and ops that already have the attribute alone.
        if op.type != "matmul" or op.has_attr("head_number"):
            continue
        op._set_attr("head_number", 1)
def infer_shape(program, input_shape_dict):
    """Overwrite input shapes and re-run shape inference on every op.

    The program is modified in place; nothing is returned.

    Args:
        program: The loaded static program. Its block-0 variables named in
            ``input_shape_dict`` receive the new shapes, then every op in
            every block (except the skipped types below) has
            ``desc.infer_shape`` called with its own block's desc.
        input_shape_dict: Mapping of variable name -> new shape.

    Side effects:
        Prints a warning when the version stored in the program differs
        from ``paddle.__version__``.
    """
    paddle.enable_static()
    # Op types that are skipped when re-running shape inference.
    OP_WITHOUT_KERNEL_SET = {
        'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
        'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
        'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
        'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
        'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
        'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
        'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
        'copy_cross_scope'
    }

    # The program stores its version as one integer:
    # major * 1_000_000 + minor * 1_000 + patch.
    major_ver, remainder = divmod(program.desc._version(), 1000000)
    minor_ver, patch_ver = divmod(remainder, 1000)
    model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver)
    paddle_version = paddle.__version__
    if model_version != paddle_version:
        print(
            "[WARNING] The model is saved by paddlepaddle v{}, but now your paddlepaddle is version of {}, this difference may cause error, it is recommend you reinstall a same version of paddlepaddle for this model".
            format(model_version, paddle_version))

    # New input shapes are applied to variables of the first block only.
    for name, shape in input_shape_dict.items():
        program.blocks[0].var(name).desc.set_shape(shape)

    # Iterate each block's OWN op list. The previous code sized the inner
    # loop from block 0, which raised IndexError for shorter sub-blocks and
    # silently skipped the trailing ops of longer ones.
    for block in program.blocks:
        for op in block.ops:
            if op.type in OP_WITHOUT_KERNEL_SET:
                continue
            op.desc.infer_shape(block.desc)
def parse_arguments():
    """Parse the script's command-line options.

    Returns:
        argparse.Namespace with ``model_path``, ``input_shape_dict`` and
        ``save_path``. All three options are required.
    """
    cli = argparse.ArgumentParser()
    # (flag, help text) pairs; every option is mandatory.
    required_options = (
        ('--model_path',
         'Directory path to input model + model name without suffix.'),
        ('--input_shape_dict', "The new shape information."),
        ('--save_path',
         'Directory path to save model + model name without suffix.'),
    )
    for flag, description in required_options:
        cli.add_argument(flag, required=True, help=description)
    return cli.parse_args()
if __name__ == '__main__':
    args = parse_arguments()
    paddle.enable_static()

    # NOTE(review): eval() executes whatever Python expression is passed on
    # the command line. Only run this script with trusted arguments; if the
    # value is always a plain dict literal, ast.literal_eval would be a
    # safer choice -- confirm with users before switching.
    input_shape_dict = eval(args.input_shape_dict)

    print("Start to load paddle model...")
    exe = base.Executor(paddle.CPUPlace())
    program, feed_target_names, fetch_targets = static.io.load_inference_model(
        args.model_path, exe)

    # Patch matmul ops first, then apply the new shapes and re-infer.
    process_old_ops_desc(program)
    infer_shape(program, input_shape_dict)

    # save_inference_model is given variables, not names, so look each
    # feed name up in the program's global block.
    global_block = program.global_block()
    feed_vars = [global_block.var(name) for name in feed_target_names]
    static.io.save_inference_model(
        args.save_path,
        feed_vars=feed_vars,
        fetch_vars=fetch_targets,
        executor=exe,
        program=program)
4 参考文档
- load_inference_model-API文档-PaddlePaddle深度学习平台
- [Utils] infer_paddle_model_shape.py support paddlepaddle2.6 by Zheng-Bicheng · Pull Request #1214 · PaddlePaddle/Paddle2ONNX