首页 > 编程语言 >MegEngine Python 层模块串讲(下)

MegEngine Python 层模块串讲(下)

时间:2023-07-31 16:35:37浏览次数:60  
标签:串讲 Python 模型 Module channels MegEngine 量化 self

前面的文章中,我们简单介绍了在 MegEngine imperative 中的各模块以及它们的作用。对于新用户而言可能不太了解各个模块的使用方法,对于模块的结构和原理也是一头雾水。Python 作为现在深度学习领域的主流编程语言,其相关的模块自然也是深度学习框架的重中之重。

模块串讲将对 MegEngine 的 Python 层相关模块分别进行更加深入的介绍,会涉及到一些原理的解释和代码解读。Python 层模块串讲共分为上、中、下三个部分,本文将介绍 Python 层的 quantization 模块。量化是为了减少模型的存储空间和计算量,从而加速模型的推理过程。在量化中,我们将权重和激活值从浮点数转换为整数,从而减少模型的大小和运算的复杂性。通过本文读者将会对量化的基本原理和使用 MegEngine 得到量化模型有所了解。

降低模型内存占用利器 —— quantization 模块

量化是一种对深度学习模型参数进行压缩以降低计算量的技术。它基于这样一种思想:神经网络是一个近似计算过程,不需要其中每个计算过程的绝对的精确。因此在某些情况下可以把需要较多比特存储的模型参数转为使用较少比特存储,而不影响模型的精度。

量化通过舍弃数值表示上的精度来追求极致的推理速度。直觉上用低精度/比特类型的模型参数会带来较大的模型精度下降(称之为掉点),但在经过一系列精妙的量化处理之后,掉点可以变得微乎其微。

如下图所示,量化通常是将浮点模型(常见神经网络的 Tensor 数据类型一般是 float32)处理为一个量化模型(Tensor 数据类型为 int8 等)。

1.png

量化基本流程

MegEngine 中支持工业界的两类主流量化技术,分别是训练后量化(PTQ)和量化感知训练(QAT)。

  1. 训练后量化(Post-Training QuantizationPTQ

    训练后量化,顾名思义就是将训练后的 Float 模型转换成低精度/比特模型。

    比较常见的做法是对模型的权重(weight)和激活值(activation)进行处理,把它们转换成精度更低的类型。虽然是在训练后再进行精度转换,但为了获取到模型转换需要的一些统计信息(比如缩放因子 scale),仍然需要在模型进行前向计算时插入观察者(Observer)。

    使用训练后量化技术通常会导致模型掉点,某些情况下甚至会导致模型不可用。可以使用小批量数据在量化之前对 Observer 进行校准(Calibration),这种方案叫做 Calibration 后量化。也可以使用 QAT 方案。

  2. 量化感知训练(Quantization-Aware TrainingQAT

    QAT 会向 Float 模型中插入一些伪量化(FakeQuantize)算子,在前向计算过程中伪量化算子根据 Observer 观察到的信息进行量化模拟,模拟数值截断的情况下的数值转换,再将转换后的值还原为原类型。让被量化对象在训练时“提前适应”量化操作,减少训练后量化的掉点影响。

    而增加这些伪量化算子模拟量化过程又会增加训练开销,因此模型量化通常的思路是:

    • 按照平时训练模型的流程,设计好 Float 模型并进行训练,得到一个预训练模型;
    • 插入 Observer 和 FakeQuantize 算子,得到 Quantized-Float 模型(QFloat 模型)进行量化感知训练;
    • 训练后量化,得到真正的 Quantized 模型(Q 模型),也就是最终用来进行推理的低比特模型。

    过程如下图所示(实际使用时,量化流程也可能会有变化):

2.png

  1. 注意这里的量化感知训练 QAT 是在预训练好的 QFloat 模型上微调(Fine-tune)的(而不是在原来的 Float 模型上),这样减小了训练的开销,得到的微调后的模型再做训练后量化 PTQ(“真量化”),QModel 就是最终部署的模型。

模型(Model)与模块(Module

量化是一个对模型(Model)的转换操作,但其本质其实是对模型中的模块( Module) 进行替换。

在 MegEngine 中,对应与 Float Model 、QFloat Model 和 Q Model 的 Module 分别为:

  1. 进行正常 float 运算的默认 Module
  2. 带有 Observer 和 FakeQuantize 算子的 qat.QATModule
  3. 无法训练、专门用于部署的 quantized.QuantizedModule

以 Conv 算子为例,这些 Module 对应的实现分别在:

量化配置 QConfig

量化配置包括 Observer 和 FakeQuantize 两部分,要设置它们,用户可以使用 MegEngine 预设配置也可以自定义配置。

  1. 使用 MegEngine 预设配置

    MegEngine 提供了多种量化预设配置

    以 ema_fakequant_qconfig 为例,用户可以通过如下代码使用该预设配置:

import megengine.quantization as Q
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
  1. 用户自定义量化配置

    用户还可以自己选择 Observer 和 FakeQuantize,灵活配置 QConfig 灵活选择 weight_observeract_observerweight_fake_quant 和 act_fake_quant)。

    可选的 Observer 和 FakeQuantize 可参考量化 API 参考页面。

QConfig 提供了一系列用于对模型做量化的接口,要使用这些接口,需要网络的 Module 能够在 forward 时给权重、激活值加上 Observer 以及进行 FakeQuantize

模型转换的作用是:将普通的 Float Module 替换为支持这些操作的 QATModule(可以训练),再替换为 QuantizeModule(无法训练、专用于部署)。

以 Conv2d 为例,模型转换的过程如图:

3.png

在量化时常常会用到算子融合(Fusion)。比如一个 Conv2d 算子加上一个 BatchNorm2d 算子,可以用一个 ConvBn2d 算子来等价替代,这里 ConvBn2d 算子就是 Conv2d 和 BatchNorm2d 的融合算子。

MegEngine 中提供了一些预先融合好的 Module,比如 ConvRelu2dConvBn2d 和 ConvBnRelu2d 等。使用融合算子会使用底层实现好的融合算子(kernel),而不会分别调用子模块在底层的 kernel,因此能够加快模型的速度,而且框架还无需根据网络结构进行自动匹配和融合优化,同时存在融合和不需融合的算子也可以让用户能更好的控制网络转换的过程。

实现预先融合的 Module 也有缺点,那就是用户需要在代码中修改原先的网络结构(把可以融合的多个 Module 改为融合后的 Module)。

模型转换的原理是,将父 Module 中的 Quantable (可被量化的)子 Module 替换为新 Module。而这些 Quantable submodule 中可能又包含 Quantable submodule,这些 submodule 不会再进一步转换,因为其父 Module 被替换后的 forward 计算过程已经改变了,不再依赖于这些子 Module

有时候用户不希望对模型的部分 Module 进行转换,而是保留其 Float 状态(比如转换会导致模型掉点),则可以使用 disable_quantize 方法关闭量化。

比如下面这行代码关闭了 fc 层的量化处理:

model.fc.disable_quantize()

由于模型转换过程修改了原网络结构,因此模型保存与加载无法直接适用于转换后的网络,读取新网络保存的参数时,需要先调用转换接口得到转换后的网络,才能用 load_state_dict 将参数进行加载。

量化代码

要从一个 Float 模型得到一个可用于部署的量化模型,大致需要经历三个步骤:

  1. 修改网络结构。将 Float 模型中的普通 Module 替换为已经融合好的 Module,比如 ConvBn2dConvBnRelu2d 等(可以参考 imperative/python/megengine/module/quantized 目录下提供的已融合模块)。然后在正常模式下预训练模型,并且在每轮迭代保存网络检查点。

    以 ResNet18 的 BasicBlock 为例,模块修改前的代码为:

class BasicBlock(M.Module):
      def __init__(self, in_channels, channels):
         super().__init__()
         self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
         self.bn1 = M.BatchNorm2d
         self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
         self.bn2 = M.BatchNorm2d
         self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.Sequential(
            M.Conv2d(in_channels, channels, 1, stride, bias=False)
            M.BatchNorm2d
         )
​
      def forward(self, x):
         identity = x
         x = F.relu(self.bn1(self.conv1(x)))
         x = self.bn2(self.conv2(x))
         identity = self.downsample(identity)
         x = F.relu(x + identity)
         return x

注意到现在的前向中使用的都是普通 Module 拼接在一起,而实际上许多模块是可以融合的。

用可以融合的模块替换掉原先的 Module

class BasicBlock(M.Module):
      def __init__(self, in_channels, channels):
         super().__init__()
         self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
         self.conv_bn2 = M.ConvBn2d(channels, channels, 3, 1, padding=1, bias=False)
         self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.ConvBn2d(in_channels, channels, 1, 1, bias=False)
         )
         self.add_relu = M.Elemwise("FUSE_ADD_RELU")
​
      def forward(self, x):
         identity = x
         x = self.conv_bn_relu1(x)
         x = self.conv_bn2(x)
         identity = self.downsample(identity)
         x = self.add_relu(x, identity)
         return x

注意到此时前向中已经有许多模块使用的是融合后的 Module

再对该模型进行若干论迭代训练,并保存检查点:

for step in range(0, total_steps):
    # Linear learning rate decay
    epoch = step // steps_per_epoch
    learning_rate = adjust_learning_rate(step, epoch)
​
    image, label = next(train_queue)
    image = tensor(image.astype("float32"))
    label = tensor(label.astype("int32"))
​
    n = image.shape[0]
​
    loss, acc1, acc5 = train_func(image, label, net, gm)  # traced
    optimizer.step().clear_grad()
​
    # Save checkpoints

完整代码见:

-   [修改前的模型结构](https://github.com/MegEngine/Models/blob/master/official/vision/classification/resnet/model.py)
-   [修改后的模型结构](https://github.com/MegEngine/Models/blob/master/official/quantization/models/resnet.py)
  1. 调用 quantize_qat 方法 将 Float 模型转换为 QFloat 模型,并进行微调(量化感知训练或校准,取决于 QConfig)。

    使用 quantize_qat 方法将 Float 模型转换为 QFloat 模型的代码大致为:

from megengine.quantization import ema_fakequant_qconfig, quantize_qat
​
model = ResNet18()
​
# QAT
quantize_qat(model, ema_fakequant_qconfig)
​
# Or Calibration:
# quantize_qat(model, calibration_qconfig)

将 Float 模型转换为 QFloat 模型后,加载预训练 Float 模型保存的检查点进行微调 / 校准:

if args.checkpoint:
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)
​
# Fine-tune / Calibrate with new traced train_func
# Save checkpoints

完整代码见:

-   [Finetune](https://github.com/MegEngine/Models/blob/master/official/quantization/finetune.py)
-   [Calibration](https://github.com/MegEngine/Models/blob/master/official/quantization/calibration.py)
  1. 调用 quantize 方法将 QFloat 模型转换为 Q 模型,也就是可用于模型部署的量化模型。

需要在推理的方法中设置 trace 的 capture_as_const=True,以进行模型导出:

from megengine.quantization import quantize
​
@jit.trace(capture_as_const=True)
def infer_func(processed_img):
    model.eval()
    logits = model(processed_img)
    probs = F.softmax(logits)
    return probs
​
quantize(model)
​
processed_img = transform.apply(image)[np.newaxis, :]
processed_img = processed_img.astype("int8")
probs = infer_func(processed_img)
​
infer_func.dump(output_file, arg_names=["data"])

调用了 quantize 后,model 就从 QFloat 模型转换为了 Q 模型,之后便使用这个 Quantized 模型进行推理。

调用 dump 方法将模型导出,便得到了一个可用于部署的量化模型。

完整代码见:

小结

MegEngine Python 层模块串讲系列到这里就结束了,我们介绍了用户在使用 MegEngine 时主要会接触到的 python 层的各个模块的主要功能、结构以及使用方法,此外还有一些原理性的介绍。对于各模块具体实现感兴趣的读者可以参考 MegEngine 官方文档 和 github。之后的文章我们会对 MegEngine 开发相关工具以及 MegEngine 底层的实现做更深入的介绍。

更多 MegEngine 信息获取,您可以:查看文档GitHub 项目,或加入 MegEngine 用户交流 QQ 群:1029741705。欢迎参与 MegEngine 社区贡献,成为 Awesome MegEngineer,荣誉证书、定制礼品享不停。

标签:串讲,Python,模型,Module,channels,MegEngine,量化,self
From: https://www.cnblogs.com/megengine/p/17593771.html

相关文章

  • 【Python&目标识别】Labelimg标记深度学习(yolo)样本
    ​    人工智能、ai、深度学习已经火了很长一段时间了,但是还有很多小伙伴没有接触到这个行业,但大家应该多多少少听过,网上有些兼职就是拿电脑拉拉框、数据标注啥的,其实这就是在标记样本,供计算机去学习。所以今天跟大家分享下如何使用Labelimg去自己标记深度学习样本。......
  • 【Python】一键提取inp文件结构的脚本
    inp=input("输入文件路径:")#print(type(inp))ex_txt=inp+'-Struct.inp'inp=inp+'.inp'importref2=open(ex_txt,'w')withopen(inp,'r',encoding="utf-8")asf1:row_num=0foriinf1:......
  • Python 导入function和导入moudle的区别
    以pprint为例导入moudleimportpprint同比C#创建对象,可以通过moudle名访问其中定义的变量、函数、类是长期过程会将moudle定义加载到内存中,整个程序执行过程中均可使用访问方法moudleName.functionNamepprint.pprint(data)导入functionfrompprintimportpprint......
  • python 比较两个excel A有b没有
    importpandasaspd#读取第一个Excel文件df1=pd.read_excel('excel_file1.xlsx')#读取第二个Excel文件df2=pd.read_excel('excel_file2.xlsx')#找出在df1中存在但不在df2中的行missing_rows=df1[~df1['列名'].isin(df2['列名'])]#保存缺失的数据到新的E......
  • Python去除文本中的NUL(0x00)字符
    问题描述在python中将文本数据存储到PostgreSQL数据库中报以下错误ValueError:AstringliteralcannotcontainNUL(0x00)characters.原因PostgreSQL不支持在文本字段中存储NULL(0x00)字符(这与支持文本中带有NULL值的数据库显然不同)。如果需要存储NULL字符,则可以使用byt......
  • 秋叶整合包如何安装Python包
    前几天写了一篇《手把手教你在本机安装StableDiffusion秋叶整合包》的文章,有同学运行时遇到缺少PythonModule的问题,帮助他处理了一下,今天把这个经验分享给大家,希望能帮助到更多的同学。有时候启动某些插件的时候会出现ModuleNotFoundError的提示,类似下图这样:这时候就需要......
  • python学习_元组
    一、什么是元组?元组也是python内置的数据结构,是一个不可变的序列,他也可以存放不同数据类型的元素不可变序列有:就是不可以改变的序列,没有增、删、改的操作,如元组、字符串就是不可变序列可变序列:可以对序列进行增、删、改操作,对象地址不发生改变,如列表、字典等'''不可变序列与......
  • python Pycharm出现“can't find '__main__' module”解决方案
    是配置没配对,因为在配置时没有选择.py文件,而只选择了工程名。因此选择EditConfigurations。选择EditConfigurations后,查看Scriptpath只选择了工程名innerfuns,而这里应该要选择工程名里面的.py文件(main函数,如果没有,选择你要执行的.py文件)。最终可运行成功......
  • Python第一天
    1、变量名-字母-下划线-数字注:不能是关键字、不能数字开头、不要和内置的东西重复补充:变量名尽量写的有意义,对变量名所指向的东西尽量看名字可识别。技巧:变量名可以用单词,另外用下划线进行断句,已表示清楚(还有可以用首字母大写进行断句,python用下划线比较清晰) 2、字符串1......
  • mobaxterm python
    实现MobaxtermPython的步骤1.下载和安装Mobaxterm首先,你需要下载并安装Mobaxterm,它是一个功能强大的终端仿真器和X服务器,可在Windows上运行。你可以在Mobaxterm的官方网站(2.打开Mobaxterm安装完成后,打开Mobaxterm。你将看到一个类似于命令行的界面,其中包含一个终端窗口和一......