import os

import numpy as np
import onnx
from onnxruntime.quantization import (
    CalibrationDataReader,
    QuantType,
    quantize_dynamic,
    quantize_static,
)

# NOTE(review): BertTokenizer (presumably from transformers), read_file and
# ModelConfig are used below but are not imported here — add their imports.
# Build the calibration samples: one token-id array per test sentence.
txt_test_list = read_file(os.path.join(ModelConfig().data_dir_pp, "test_test.txt"))

# Raw string: the previous plain literal only worked because "\p", "\T" and
# "\q" are invalid escape sequences (Python warns about these), and it would
# silently break if a path segment ever started with e.g. "n" or "t".
path = r"E:\py_workspace\TinyBERT-PP-New\quantized"
tokenizer = BertTokenizer.from_pretrained(path, do_lower_case=True)

datas = []
for txt in txt_test_list:
    # NOTE(review): only "[CLS]" is prepended and no "[SEP]" is appended —
    # confirm this matches how the exported model was trained.
    tokens = ["[CLS]"] + tokenizer.tokenize(txt)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Wrap in an outer list so each sample is an int64 array of shape
    # (1, seq_len), i.e. a batch of one.
    token_ids = np.asarray([token_ids], dtype=np.int64)
    datas.append(token_ids)
# Data batch reader.
def _concat_batch(arrays, pad_id=None):
    """Concatenate sample arrays along axis 0 into one batch array.

    Args:
        arrays: Non-empty list of numpy arrays whose shapes agree on every
            axis except 0 (and, when ``pad_id`` is given, the last axis).
        pad_id: If not None, right-pad each array along its last axis with
            this value up to the longest array in the batch, so that
            variable-length token sequences can be stacked.

    Returns:
        The concatenated numpy array.
    """
    if pad_id is not None:
        max_len = max(a.shape[-1] for a in arrays)
        arrays = [
            np.pad(
                a,
                [(0, 0)] * (a.ndim - 1) + [(0, max_len - a.shape[-1])],
                constant_values=pad_id,
            )
            for a in arrays
        ]
    return np.concatenate(arrays, 0)


def batch_reader(datas, batch_size, input_name="input_ids", pad_id=None):
    """Yield model feed dicts that group consecutive samples into batches.

    Args:
        datas: Iterable of numpy arrays, one per sample (the script above
            builds each as an int64 array of shape (1, seq_len)).
        batch_size: Number of samples per batch; must be >= 1. With
            ``batch_size == 1`` each sample is yielded unchanged.
        input_name: Key of each yielded dict; it must match the ONNX
            model's input name.
        pad_id: Optional padding value. Without it, the samples of a batch
            are concatenated as-is, and ``np.concatenate`` raises
            ``ValueError`` if their lengths differ. With it, shorter samples
            are right-padded first. NOTE(review): only ``input_name`` is
            fed, so padded positions are not masked — choose the value
            (typically the tokenizer's pad id) with that in mind.

    Yields:
        ``{input_name: batch_array}``. The last batch may contain fewer than
        ``batch_size`` samples.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces on the first ``next()``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    batch = []
    for data in datas:
        if batch_size == 1:
            # Pass the sample through untouched (no copy via concatenate).
            yield {input_name: data}
            continue
        batch.append(data)
        if len(batch) == batch_size:
            yield {input_name: _concat_batch(batch, pad_id)}
            batch = []

    # Flush the trailing partial batch. Doing this after the loop means
    # `datas` no longer has to support len().
    if batch:
        yield {input_name: _concat_batch(batch, pad_id)}
# Build the calibration data reader.
class DataReader(CalibrationDataReader):
    """Calibration data reader for onnxruntime static quantization.

    In effect an iterator. Each ``get_next`` call returns one dict of the
    form ``{input_1: data_1, ..., input_n: data_n}``, which maps each model
    input name to its preprocessed data, as produced by ``batch_reader``.
    Once every batch has been served it returns ``None``.

    The dict keys must match the ONNX model's input names. Calibration data
    needs no labels.

    Args:
        datas: Sequence of per-sample arrays passed to ``batch_reader``.
        batch_size: Number of samples per yielded batch.
    """

    def __init__(self, datas, batch_size):
        # Keep the inputs so rewind() can rebuild the generator.
        self._source = datas
        self._batch_size = batch_size
        self.datas = batch_reader(datas, batch_size)

    def get_next(self):
        """Return the next feed dict, or ``None`` when the data is exhausted."""
        return next(self.datas, None)

    def rewind(self):
        """Restart iteration so the calibration data can be read again."""
        self.datas = batch_reader(self._source, self._batch_size)
# 静态量化需要传入校准数据，用于计算量化参数；校准数据无需标签，但 batch_reader 中 yield 出的字典（如 {'input_ids': data}）的键名必须与模型的输入名一致。
# 标签: onnx, batch, ids, input, 数据处理, 量化, data, datas, size
# 来源 (From): https://blog.51cto.com/u_12727662/7434934