首页 > 其他分享 >可微TopK算子

可微TopK算子

时间:2024-09-25 17:03:15浏览次数:1  
标签:可微 torch text Delta TopK vec 算子 xs grad

形式及推导

形式:前向计算如下所示,

\[\text{TopK}(\vec{x}, k) = \sigma(\vec{x}+\Delta(\vec{x}, k)) \]

注意\(\Delta(\cdot)\)满足限制条件\(\sum \Delta(\vec{x}, k) = k\),并且\(\sigma(x) = \frac{1}{1+\exp\{-x\}}\)


梯度推导:
令\(f(\vec{x}, k) = \sigma(\vec{x}+\Delta(\vec{x}, k))\)

\[\frac{ \text{d} f(\vec{x}, k)_{i} }{ \text{d} x_{j} } = \frac{ \text{d} \sigma(x_{i}+\Delta(\vec{x}, k)) }{ \text{d} x_{j} } = \sigma'(x_i + \Delta(\vec{x})) \Big( \mathbb{I}_{i=j} + \frac{\text{d}\Delta(\vec{x})}{\text{d}x_j} \Big) \]

难点在于如何计算\(\frac{\text{d}\Delta(\vec{x})}{\text{d}x_j}\)。

我们通过利用条件\(\sum \Delta(\vec{x}) = k\)来计算上述导数:

\[\frac{\text{d}k}{\text{d}x_j} = 0 = \sum_{i}\sigma'(x_i+\Delta(\vec{x}))\Big( \mathbb{I}_{i=j} + \frac{\text{d}{\Delta(\vec{x})}}{\text{d}x_{j}} \Big) = \sigma'(x_{\color\red j}+\Delta(\vec{x})) + \frac{\text{d}{\Delta(\vec{x})}}{\text{d}x_{j}} \sum_{i}\sigma'(x_i+\Delta(\vec{x})) \]

因此,我们可以得到:

\[\frac{\text{d}{\Delta(\vec{x})}}{\text{d}x_{j}} = \frac{ - \sigma'(x_{\color\red j} +\Delta(\vec{x})) }{ \sum_{i}\sigma'(x_i+\Delta(\vec{x})) } \]

向量版本:如果令\(v = \sigma'(\vec{x}+\Delta(\vec{x}))\),则雅可比矩阵为

\[J_{\text{TopK}}(\vec{x}) = \text{diag}(\vec{v}) - \frac{\vec{v}\vec{v}^{\top}}{\Vert\vec{v}\Vert_1} \]

其他细节:如何计算出\(\Delta(\vec{x})=k\)?可以通过二分法快速找到该函数的合适值。

实现

# %% differentiable top-k function
import torch
from torch.func import vmap, grad
from torch.autograd import Function
import torch.nn as nn

sigmoid = torch.sigmoid
sigmoid_grad = vmap(vmap(grad(sigmoid)))


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, k):
        ts, ps = _find_ts(xs, k)
        ctx.save_for_backward(xs, ts)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute vjp, that is grad_output.T @ J.
        xs, ts = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = sigmoid_grad(xs + ts)
        s = v.sum(dim=1, keepdims=True)
        # Jacobian is -vv.T/s + diag(v)
        uv = grad_output * v
        t1 = -uv.sum(dim=1, keepdims=True) * v / s
        return t1 + uv, None


@torch.no_grad()
def _find_ts(xs, k):
    # (batch_size, input_dim)
    _, n = xs.shape
    assert 0 < k < n
    # Lo should be small enough that all sigmoids are in the 0 area.
    # Similarly Hi is large enough that all are in their 1 area.
    # (batch_size, 1)
    lo = -xs.max(dim=1, keepdims=True).values - 10
    hi = -xs.min(dim=1, keepdims=True).values + 10
    for iteration in range(64):
        mid = (hi + lo) / 2
        subject = sigmoid(xs + mid).sum(dim=1)
        mask = subject < k
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    return ts, sigmoid(xs + ts)


def test_check():
    topk = TopK.apply
    xs = torch.randn(2, 10)
    ps = topk(xs, 2)
    print(f"{xs=}")
    print(f"{ps=}")
    print(f"{ps.sum(dim=1)=}")

    from torch.autograd import gradcheck

    input = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
    for k in range(1, 10):
        print(k, gradcheck(topk, (input, k), eps=1e-6, atol=1e-4))


def sgd_update():
    topk = TopK.apply
    batch_size = 2
    k = 2
    tau = 10
    xs = torch.randn(batch_size, 10, dtype=torch.double, requires_grad=True)
    target = torch.zeros_like(xs)
    target[torch.arange(batch_size), torch.argsort(xs, descending=True)[:, :k].T] = 1.0
    print(f"{xs=}")
    print(f"{target=}")
    loss_fn = nn.MSELoss()
    learning_rate = 1

    def fn(x):
        x = x * tau
        return topk(x, k)

    for iteration in range(1, 1000 + 1):
        ws = nn.Parameter(data=xs, requires_grad=True)
        ps = fn(ws)
        loss = loss_fn(ps.view(-1), target.view(-1))
        loss.backward()
        xs = ws - learning_rate * ws.grad
        if iteration % 100 == 0:
            print(f"{iteration=} {fn(xs)=}")


sgd_update()

相关资料

Differentiable top-k function - Stach Exchange
Softmax后传:寻找Top-K的光滑近似 - 科学空间

标签:可微,torch,text,Delta,TopK,vec,算子,xs,grad
From: https://www.cnblogs.com/WrRan/p/18431596

相关文章

  • 【算法】topk之字节题
    1.合并两个有序列表......
  • 大数据-128 - Flink 并行度设置 细节详解 全局、作业、算子、Slot
    点一下关注吧!!!非常感谢!!持续更新!!!目前已经更新到了:Hadoop(已更完)HDFS(已更完)MapReduce(已更完)Hive(已更完)Flume(已更完)Sqoop(已更完)Zookeeper(已更完)HBase(已更完)Redis(已更完)Kafka(已更完)Spark(已更完)Flink(正在更新!)章节内容上节我们完成了如下的内容:ManageOperatorStateStateBackendCheckpoint......
  • 大数据-123 - Flink 并行度 相关概念 全局、作业、算子、Slot并行度 Flink并行度设置
    点一下关注吧!!!非常感谢!!持续更新!!!目前已经更新到了:Hadoop(已更完)HDFS(已更完)MapReduce(已更完)Hive(已更完)Flume(已更完)Sqoop(已更完)Zookeeper(已更完)HBase(已更完)Redis(已更完)Kafka(已更完)Spark(已更完)Flink(正在更新!)章节内容上节我们完成了如下的内容:FlinkTimeWatermarkJava代码实例测试简单介......
  • IP地址、地址分类、子网掩码、子网划分、使用Python计算子网划分
    IP地址(InternetProtocolAddress)乃是用于明确标识网络中各类设备的独一无二的地址。IP地址主要存在两种重要类型,即IPv4和IPv6。IPv4地址IPv4地址实则是一个由32位二进制数字所构成的标识,通常会以四个十进制数字的形式呈现出来,每一个数字均处于0至255的区间范围内,且通......
  • YOLOv9改进策略【Neck】| 有效且轻量的动态上采样算子:DySample
    一、本文介绍本文记录的是利用DySample上采样对YOLOv9的颈部网络进行改进的方法研究。YOLOv9采用传统的最近邻插值的方法进行上采样可能无法有效地捕捉特征的细节和语义信息,从而影响模型在密集预测任务中的性能。DySample通过动态采样的方式进行上采样,能够更好地处理特征的......
  • YOLOv9改进策略【Neck】| 使用CARAFE轻量级通用上采样算子
    一、本文介绍本文记录的是利用CARAFE上采样对YOLOv9的颈部网络进行改进的方法研究。YOLOv9采用传统的最近邻插值的方法,仅考虑子像素邻域,无法捕获密集预测任务所需的丰富语义信息,从而影响模型在密集预测任务中的性能。CARAFE通过在大感受野内聚合信息、能够实时适应实例特定......
  • Ascend C算子开发(中级)—— 编写Sinh算子
    AscendC算子开发(中级)——编写Sinh算子文章目录AscendC算子开发(中级)——编写Sinh算子准备工作香橙派与PC连接Add算子调用体验Sinh算子开发(AscendC算子开发中级认证考试内容)准备工作一块香橙派AIpro开发板,一根Type-c口的电源线,一根网线,一个网线转接器,一台......
  • 说说Canny边缘检测算子?
    Canny边缘检测算子什么是Canny边缘检测算子Canny边缘检测算子原理Canny边缘检测算子用途什么是Canny边缘检测算子Canny边缘检测算子是一种旨在以最优方式从图像中提取边缘信息的算法。其“最优”体现在三个方面:低错误率:算法应尽可能多地标识出图像中的实际边缘,同时......
  • 图像边缘检测技术详解:利用OpenCV实现Sobel算子
    图像边缘检测技术详解:利用OpenCV实现Sobel算子前言Sobel算子的原理代码演示结果展示结语前言  在数字图像处理的广阔领域中,边缘检测技术扮演着至关重要的角色。无论是在科学研究、工业自动化,还是在日常生活中的智能设备中,我们都需要从图像中提取有用的信息。边缘,作......
  • [Python手撕]TOPK
    TOPK问题描述:从arr[1,n]这n个数中,找出最大的k个数,这就是经典的TopK问题。栗子:从arr[1,12]={5,3,7,1,8,2,9,4,7,2,6,6}这n=12个数中,找出最大的k=5个。整体排序排序是最容易想到的方法,将n个数排序之后,取出最大的k个,即为所得。伪代码:sort(arr,1,n);returnarr[1,k];......