首页 > 其他分享 >【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

时间:2024-11-04 21:20:06浏览次数:5  
标签:group 语言 neg 模型 torch pos size self 07

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录

文章目录


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
在这里插入图片描述

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。
在这里插入图片描述

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架:BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。
    在这里插入图片描述

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标:CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。
    在这里插入图片描述

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益:BitDistiller需要的训练数据和资源更少,更具成本效益。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

代码

https://github.com/DD-DuDa/BitDistiller.git
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbF

class Round(Function):
    @staticmethod
    def forward(self, input):
        sign = torch.sign(input)
        output = sign * torch.floor(torch.abs(input) + 0.5)
        return output

    @staticmethod
    def backward(self, grad_output):
        grad_input = grad_output.clone()
        return grad_input

# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    elif q_group_size == -1:
        w = w.reshape(-1, w.shape[-1])
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = - 2 ** (n_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w



@torch.no_grad()
def real_quantize_model_weight(
    model, w_bit, q_config,
    init_only=False
):
    from .qmodule import WQLinear
    from .pre_quant import get_blocks, get_named_linears, set_op_by_name
    assert q_config["zero_point"], "We only support zero_point quantization now."
    
    layers = get_blocks(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
        layer = layers[i]
        named_linears = get_named_linears(layer)
        # scale_activations(layer)

        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            else:
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                # scales = scales.t().contiguous()
                # zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
                
    torch.cuda.empty_cache()
    gc.collect()




def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):
    quantizer = SteN2F3Quantizer(q_group_size=q_group_size)
    w = quantizer(w)
    return w


class SteInt3AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 3
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, x.shape[-1])
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteInt2AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=64):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 2
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteN2F3Quantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
    
    def forward(self, x):
        org_w_shape = x.shape

        # reshape to groupsize
        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            qx = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            qx = x.reshape(-1, x.shape[-1])
        assert qx.dim() == 2

        # Get the Min Max
        max_val = qx.amax(dim=1, keepdim=True)
        min_val = qx.amin(dim=1, keepdim=True)

        
        scale_pos = torch.abs(max_val)
        scale_neg = torch.abs(min_val)

        dev = qx.device
        x_pos = torch.zeros_like(qx)
        x_neg = torch.zeros_like(qx)
        x_pos = torch.where(qx >= 0, qx, x_pos)
        x_neg = torch.where(qx < 0, qx, x_neg)
        q_pos = x_pos / scale_pos
        q_neg = x_neg / scale_neg

        q_pos, q_neg = self.round_pass(q_pos, q_neg, dev)

        qx = q_pos * scale_pos + q_neg * scale_neg

        qx = qx.reshape(org_w_shape)

        return qx
    
    def round_n2f3(self, q_pos, q_neg, dev):
        q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)
        q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)

        q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)
        q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)

        return q_pos, q_neg

    def round_pass(self, q_pos, q_neg, dev):
        y_grad_pos, y_grad_neg = q_pos, q_neg
        y_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)
        
        return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术

标签:group,语言,neg,模型,torch,pos,size,self,07
From: https://blog.csdn.net/fyf2007/article/details/143468096

相关文章

  • 大模型的上下文学习
    文章目录上下文学习的形式化定义示例设计底层机制    在GPT-3的论文中,OpenAI研究团队首次提出上下文学习(In-contextlearning,ICL)这种特殊的提示形式。目前,上下文学习已经成为使用大语言模型解决下游任务的一种主流途径。上下文学习的形式化定义   ......
  • 大型语言模型在逻辑推理中的记忆化现象:是真推理还是死记硬背?
    大家好!今天我们要聊一个有趣的话题:那些聪明绝顶的大型语言模型(LLM),它们在解决逻辑推理问题时,究竟是像福尔摩斯一样抽丝剥茧、层层推理,还是像考试前临时抱佛脚的学生一样,只会死记硬背?......
  • 大模型的微调新思路:XGBLoRA的崛起
    ......
  • 引导自我改进:缓解大型语言模型的尾部收敛现象
    在当今人工智能的浪潮中,大型语言模型(LLMs)犹如一颗璀璨的明珠,凭借其出色的推理能力和自我改进的潜力,一直备受瞩目。然而,这颗明珠的光彩却并非始终如一。随着自我改进的迭代进行,性能的提升似乎逐渐趋于平稳,甚至出现了“尾部收敛”的现象。这就像一位优秀的学生,在学习过程中逐......
  • ​Cell Res | 首个知识与数据联合驱动的跨物种生命基础大模型GeneCompass:解析基因调控
    近年来,大语言模型(LLMs)已在自然语言、计算机视觉等通用领域引发了新一轮技术革命,通过大规模语料和模型参数进行预训练,LLMs能够掌握语言的共性规律,能够对多种下游任务产生质的提升,已经形成了新的人工智能范式。在生命科学领域,单细胞组学技术的突破产生了大量不同物种细胞的基因表达......
  • MT1421-MT1430 码题集 (c 语言详解)
    目录        MT1421·异或        MT1422·总位数        MT1423·被3整除        MT1424·卡特兰序列        MT1425·小码哥的序列        MT1426·普洛尼克数        MT1427·素数序列        MT1......
  • 主流AI Agent框架对比,让你轻松构建企业专属大模型!
    大模型的出现为AIAgent提供了足够聪明的“大脑”,并重新定义了AIAgent。各大科技公司正在投入巨额资金来创建AIAgent,包括OpenAI的SamAltman在内的许多专家都表示,AIAgent已成为下一个大热门方向。AIAgent是感知环境并采取行动以实现特定目标或目的的软件或系统。可以......
  • ​Meta AI推出思维偏好优化技术,提升AI模型回应质量
    近日,MetaAI的研究团队与加州大学伯克利分校及纽约大学的研究人员合作,推出了一种名为思维偏好优化(ThoughtPreferenceOptimization,TPO)的方法,旨在提升经过指令微调的大型语言模型(LLM)的回应质量。与传统模型仅关注最终答案不同,TPO方法允许模型在生成回应前进行内部思考......
  • C语言版数据结构算法(考研初试版—3)--链表定义、创建
    2、链表1、链表结构体typedefstructLNode{   intdata;   structLNode*next; }LNode,*LinkList; 2、遍历链表voidPrintList(LinkListL){   LinkListp=L->next;//Skiptheheadnodewhichisadummynode   while(p!=......