首页 > 其他分享 >pytorch动态量化函数

pytorch动态量化函数

时间:2024-06-16 16:33:38浏览次数:16  
标签:函数 nn torch quantization pytorch 模块 量化 qconfig

PyTorch 动态量化 API

PyTorch 提供了丰富的动态量化 API,可以帮助开发者轻松地将模型转换为动态量化模型。主要 API 包括:

torch.quantization.quantize_dynamic:将模型转换为动态量化模型。
torch.quantization.QuantStub:观察模型层的输入和输出分布。
torch.quantization.Observer:收集模型层的统计信息。
torch.quantization.DeQuantStub:将定点结果转换回浮点数。

PyTorch torch.quantization.quantize_dynamic 函数详解

torch.quantization.quantize_dynamic 函数是 PyTorch 提供的用于动态量化模型的主要 API。该函数可以将浮点模型转换为动态量化模型,从而显著降低模型大小和提高推理速度。

函数定义

torch.quantization.quantize_dynamic(
    model: torch.nn.Module,
    qconfig: Dict[Type[torch.nn.Module], Dict],
    dtype: torch.qscheme = torch.qint8
) -> torch.nn.Module

参数说明

  • model: 要转换的浮点模型。
  • qconfig: 指定要量化的模块类型和量化配置。
  • dtype: 指定量化的定点数据类型,可以是 torch.qint8torch.float16

函数返回值

quantize_dynamic 函数返回一个新的动态量化模型,该模型与原始模型具有相同的架构和功能。

函数功能

quantize_dynamic 函数主要执行以下操作:

  • 遍历模型中的每个模块。
  • 对于每个模块,检查其类型是否在 qconfig 中定义。
  • 如果模块类型在 qconfig 中定义,则根据 qconfig 中的配置对该模块进行动态量化。
  • 将量化的模块替换到新的模型中。

动态量化配置

qconfig 参数用于指定要量化的模块类型和量化配置。qconfig 是一个字典,其中键是模块类型,值是量化配置字典。量化配置字典可以包含以下键:

  • ``activation`: 指定激活的量化配置。
  • ``weight`: 指定权重的量化配置。
  • ``qscheme: 指定量化方案,可以是 torch.per_tensortorch.per_channel`。
  • ``dynamic`: 指定是否动态量化。

示例

以下是一个简单的示例,演示如何使用 quantize_dynamic 函数将模型转换为动态量化模型:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

# 定义量化配置
qconfig = {
    nn.Linear: {
        'activation': {'dtype': torch.qint8},
        'weight': {'dtype': torch.qint8}
    }
}

# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig,
    dtype=torch.qint8
)

# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)

在这个示例中,我们定义了一个简单的模型,并使用 qconfig 参数指定了量化配置。qconfig 参数指示 quantize_dynamic 函数对模型中的所有 nn.Linear 模块进行动态量化,并将激活和权重量化为 torch.qint8 格式。

注意事项

在使用 torch.quantization.quantize_dynamic 函数时,需要注意以下几点:

  • 动态量化可能会导致模型精度下降,需要根据具体情况权衡性能和精度。
  • 动态量化目前还不支持所有模型类型和操作。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

PyTorch torch.quantization.QuantStub 模块详解

torch.quantization.QuantStub 模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。

模块定义

class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,.

    It will be swapped as nnq.Quantize in convert .

    Parameters:
        qconfig(Dict): quantization configuration for the tensor, if qconfig is not
          provided, we will use the global qconfig
    """

    def __init__(self, qconfig=None):
        super(QuantStub, self).__init__()
        self.qconfig = qconfig

    def forward(self, x):
        return x

模块属性

  • qconfig: 量化配置字典。

模块方法

  • forward(x): 该方法只是简单地返回输入 x,不做任何处理。

模块功能

QuantStub 模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,QuantStub 模块会被替换为 nnq.Quantize 模块,nnq.Quantize 模块会使用收集的统计信息对输入进行量化。

示例

以下是一个简单的示例,演示如何使用 QuantStub 模块观察模型层的输入和输出分布:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(10, 20),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.ReLU(),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(20, 1)
)

# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)

在这个示例中,我们为模型中的每个层都添加了 QuantStub 模块。QuantStub 模块会观察每个层的输入和输出分布,并收集统计信息。

注意事项

在使用 torch.quantization.QuantStub 模块时,需要注意以下几点:

  • QuantStub 模块只用于观察模型层的输入和输出分布,不进行任何量化操作。
  • QuantStub 模块必须与 torch.quantization.DeQuantStub 模块搭配使用,才能完成动态量化。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

PyTorch torch.quantization.Observer 模块详解

torch.quantization.Observer 模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。

模块定义

class Observer(nn.Module):
    r"""
    Observer module, which observes tensor quantization ranges for dynamic quantization.

    It attaches to the downstream module to observe the output of the module
    and records the min/max values for quantization.

    Parameters:
        dtype(torch.qscheme): quantization dtype, e.g torch.qint8
        quant_scheme(torch.qscheme): quantization scheme, e.g torch.per_tensor or
                                 torch.per_channel
    """

    def __init__(self, dtype=torch.qint8, quant_scheme=torch.per_tensor):
        super(Observer, self).__init__()
        assert dtype in [
            torch.qint8, torch.quint8, torch.bfloat16
        ], 'Only support torch.qint8, torch.quint8, torch.bfloat16 for now'
        self.dtype = dtype
        self.quant_scheme = quant_scheme
        self.qmin = None
        self.qmax = None
        self._called_once = False

    def forward(self, x):
        r"""Calculates the min/max values for quantization.

        Args:
            x(torch.Tensor): The input tensor to observe.

        Returns:
            torch.Tensor: The input tensor.
        """
        if not self._called_once:
            self._called_once = True
            if self.quant_scheme == torch.per_tensor:
                self.qmin = x.min()
                self.qmax = x.max()
            elif self.quant_scheme == torch.per_channel:
                self.qmin = x.data.min(dim=1)[0]
                self.qmax = x.data.max(dim=1)[0]
            else:
                raise NotImplementedError
        return x

模块属性

  • dtype: 量化数据类型,可以是 torch.qint8torch.quint8torch.bfloat16
  • quant_scheme: 量化方案,可以是 torch.per_tensortorch.per_channel
  • qmin: 最小值。
  • qmax: 最大值。

模块方法

  • forward(x): 该方法计算输入 x 的最小值和最大值,并将其存储在 qminqmax 属性中。

模块功能

Observer 模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,Observer 模块收集的统计信息将被用于计算量化参数,例如量化尺度和零点。

示例

以下是一个简单的示例,演示如何使用 Observer 模块观察模型层的输入和输出分布:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
    nn.Linear(10, 20),
    Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
    nn.ReLU(),
    Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
    nn.Linear(20, 1)
)

# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)

在这个示例中,我们为模型中的每个层都添加了 Observer 模块。Observer 模块会观察每个层的输入和输出分布,并收集统计信息。

注意事项

在使用 torch.quantization.Observer 模块时,需要注意以下几点:

  • Observer 模块只用于观察模型层的输入和输出分布,不进行任何量化操作。
  • Observer 模块必须与 torch.quantization.DeQuantStub 模块搭配使用,才能完成动态量化。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

PyTorch torch.quantization.DeQuantStub 模块详解

torch.quantization.DeQuantStub 模块是 PyTorch 提供的用于动态量化模型的反量化模块。该模块可以将定点张量转换为浮点张量,从而恢复模型的精度。

模块定义

class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,.

    It will be swapped as nnq.DeQuantize in convert .

    Parameters:
        qconfig(Dict): quantization configuration for the tensor, if qconfig is not
          provided, we will use the global qconfig
    """

    def __init__(self, qconfig=None):
        super(DeQuantStub, self).__init__()
        self.qconfig = qconfig

    def forward(self, x):
        return x

模块属性

  • qconfig: 量化配置字典。

模块方法

  • forward(x): 该方法只是简单地返回输入 x,不做任何处理。

模块功能

DeQuantStub 模块主要用于将定点张量转换为浮点张量。在动态量化过程中,DeQuantStub 模块会被替换为 nnq.DeQuantize 模块,nnq.DeQuantize 模块会将定点张量转换为浮点张量,从而恢复模型的精度。

示例

以下是一个简单的示例,演示如何使用 DeQuantStub 模块将定点张量转换为浮点张量:

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
model = nn.Sequential(
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(10, 20),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.ReLU(),
    QuantStub(qconfig={'dtype': torch.qint8}),
    nn.Linear(20, 1),
    DeQuantStub(qconfig={'dtype': torch.qint8})
)

# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear: torch.quantization.QuantStub, nn.ReLU: torch.quantization.QuantStub},
    dtype=torch.qint8
)

# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)

在这个示例中,我们在模型的最后添加了一个 DeQuantStub 模块。DeQuantStub 模块会将模型输出的定点张量转换为浮点张量,从而恢复模型的精度。

注意事项

在使用 torch.quantization.DeQuantStub 模块时,需要注意以下几点:

  • DeQuantStub 模块只用于将定点张量转换为浮点张量,不进行任何量化操作。
  • DeQuantStub 模块必须与 torch.quantization.QuantStub 模块搭配使用,才能完成动态量化。
  • 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。

更多资源

标签:函数,nn,torch,quantization,pytorch,模块,量化,qconfig
From: https://www.cnblogs.com/litifeng/p/18250808

相关文章

  • Unity的生命周期函数
    在Unity中,各个生命周期函数是在特定的时机被调用的,它们的执行顺序如下:1.Awake:当脚本实例被加载时调用,用于初始化数据。如果物体上有多个脚本,它们的Awake方法会在Start方法之前执行。2.OnEnable:当对象变为活动状态(enabled)或脚本被启用时调用。如果在场景加载后对象已经......
  • 编写单个函数的ROP链
    什么是ROP链在我初识栈溢出那篇博客已经详细的讲了函数的调用过程(基于X86框架),不了解的可以看一下,没有这个理论基础,是学不好ROP的。现在我们说一下什么是ROP。ROP链就是通过返回地址的修改来完成的编程,调用特定的函数的一种编程模式。我们可以联想一下你做的最简单的栈溢出的题,返......
  • Caffe、PyTorch、Scikit-learn、Spark MLlib 和 TensorFlowOnSpark 概述
    在AI框架方面,有几种工具可用于图像分类、视觉和语音等任务。有些很受欢迎,如PyTorch和Caffe,而另一些则更受限制。以下是四种流行的AI工具的亮点。CaffeeCaffee是贾扬青在加州大学伯克利分校(UCBerkeley)时开发的深度学习框架。该工具可用于图像分类、语音和视觉。但......
  • 要将URL参数转换为JSON对象,可以使用以下函数:
    要将URL参数转换为JSON对象,可以使用以下函数:javascriptfunctiongetQueryParams(url){//使用正则表达式提取URL参数constparamsString=url.split('?')[1];if(!paramsString){return{};}//将参数字符串分割成数组,并解析键值对constparams=......
  • 6、Oracle中的分组函数
    最近项目要用到Oracle,奈何之前没有使用过,所以在B站上面找了一个学习视频,用于记录学习过程以及自己的思考。视频链接:【尚硅谷】Oracle数据库全套教程,oracle从安装到实战应用如果有侵权,请联系删除,谢谢。学习目标:了解组函数。描述组函数的用途。使用GROUPBY子句对数据分......
  • 【Linux】fork()函数详解|多进程
    ......
  • PyTorch学习9:卷积神经网络
    文章目录前言一、说明二、具体实例1.程序说明2.代码示例总结前言介绍卷积神经网络的基本概念及具体实例一、说明1.如果一个网络由线性形式串联起来,那么就是一个全连接的网络。2.全连接会丧失图像的一些空间信息,因为是按照一维结构保存。CNN是按照图像原始结构进......
  • 关于ES6的箭头函数和展开运算符
    使用ES6的箭头函数和展开运算符(...)可以简化使用逻辑与(&&)运算符的代码。这种方法通常用于当你有一组变量,并且想要在单个表达式中检查它们是否都满足特定条件时。以下是一个示例,展示如何使用箭头函数和展开运算符来简化检查多个变量是否都已定义且不为空的代码://假设有以下变量co......
  • 金融量化分析开源工具:TuShare
    TuShare:一站式金融数据解决方案,让量化分析触手可及-精选真开源,释放新价值。概览TuShare,是Github社区上一个专为金融量化分析师和数据爱好者设计的开源工具,提供了从数据采集、清洗加工到数据存储的全流程服务。它以其数据覆盖面广、接口调用简便、响应速度快而广受好评。......
  • 机器视觉入门学习:YOLOV5自定义数据集部署、网络详解、损失函数(学习笔记)
     前言源码学习资源:YOLOV5预处理和后处理,源码详细分析-CSDN博客网络学习资源:YOLOv5网络详解_yolov5网络结构详解-CSDN博客YOLOv5-v6.0学习笔记_yolov5的置信度损失公式-CSDN博客 本文为个人学习,整合各路大佬的资料进行V5-6.0版本的网络分析,在开始学习之前最好先去学习YOL......