"""
#coding:utf-8
__project_ = 'TF2learning'
__file_name__ = 'quantization'
__author__ = 'qilibin'
__time__ = '2021/3/17 9:18'
__product_name = PyCharm
"""
import h5py
import pandas as pd
import numpy as np
'''
Read the original weights-only H5 model, iterate over its layers, apply
16-bit or 8-bit quantization to every weight array in each layer, and save
the quantized values to a new H5 file with the same group layout.
'''
def quantization16bit(old_model_path, new_model_path, bit_num):
    '''
    Quantize every weight array of a weights-only Keras H5 model and write
    the result to a new H5 file, preserving the layer/group layout.

    :param old_model_path: path to the unquantized model (weights only, no
        network structure saved)
    :param new_model_path: path the quantized model is written to
    :param bit_num: quantization width — 16 (cast to float16) or 8 (affine
        min/max mapping of each array onto uint8 [0, 255])
    :return: None
    :raises ValueError: if bit_num is neither 8 nor 16 (the original code
        silently produced an undefined-variable error mid-write instead)
    '''
    if bit_num not in (8, 16):
        raise ValueError('bit_num must be 8 or 16, got %r' % (bit_num,))
    # Context managers guarantee both files are closed even if an error
    # occurs while copying (the original leaked handles on exception).
    with h5py.File(old_model_path, 'r') as f, \
         h5py.File(new_model_path, 'w') as f2:
        for layer in f.keys():
            # layer: name of the layer group
            print(layer)
            # Parameter-less layers (e.g. relu) have no sub-keys.
            length = len(list(f[layer].keys()))
            if length > 0:
                g1 = f2.create_group(layer)
                g1.attrs["weight_names"] = layer
                g2 = g1.create_group(layer)
                for weight in f[layer][layer].keys():
                    print("weight name is :" + weight)
                    oldparam = f[layer][layer][weight][:]
                    print('-----------------------------------------old-----------------------')
                    print(oldparam)
                    if isinstance(oldparam, np.ndarray):
                        if bit_num == 16:
                            newparam = np.float16(oldparam)
                        else:  # bit_num == 8
                            min_val = np.min(oldparam)
                            max_val = np.max(oldparam)
                            span = max_val - min_val
                            # Guard against constant arrays: span == 0 would
                            # divide by zero; map them all to 0 instead.
                            if span == 0:
                                scaled = np.zeros_like(oldparam)
                            else:
                                scaled = np.round((oldparam - min_val) / span * 255)
                            newparam = np.uint8(scaled)
                    else:
                        newparam = oldparam
                    print('-----------------------------------------new-----------------------')
                    # NOTE: writing back into the source model in place does
                    # not work; datasets must be created in the new file.
                    dtype = np.float16 if bit_num == 16 else np.uint8
                    g2.create_dataset(weight, data=newparam, dtype=dtype)
            else:
                # Keep an empty group for parameter-less layers so the new
                # file mirrors the original structure.
                g1 = f2.create_group(layer)
                g1.attrs["weight_names"] = layer
if __name__ == '__main__':
    # Example: 8-bit quantize a YOLOX-S weights file. Guarded so importing
    # this module does not trigger file I/O as a side effect.
    old_model_path = './yolox_s.h5'
    new_model_path = './yolox_sq.h5'
    quantization16bit(old_model_path, new_model_path, 8)
    # print (f['batch_normalization']['batch_normalization']['gamma:0'][:])