Soft Q-Learning (SQL) in Reinforcement Learning: Implementing SVGD (Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm)

Code implementation:

https://openi.pcl.ac.cn/devilmaycry812839668/softlearning/src/branch/master/softlearning/misc/kernel.py


from distutils.version import LooseVersion

import numpy as np
import tensorflow as tf


def adaptive_isotropic_gaussian_kernel(xs, ys, h_min=1e-3):
    """Gaussian kernel with dynamic bandwidth.

    The bandwidth is adjusted dynamically to median_squared_distance / log(Kx).
    See [2] for more information.

    Args:
        xs(`tf.Tensor`): A tensor of shape (N x Kx x D) containing N sets of Kx
            particles of dimension D. This is the first kernel argument.
        ys(`tf.Tensor`): A tensor of shape (N x Ky x D) containing N sets of Ky
            particles of dimension D. This is the second kernel argument.
        h_min(`float`): Minimum bandwidth.

    Returns:
        `dict`: Returned dictionary has two fields:
            'output': A `tf.Tensor` object of shape (N x Kx x Ky) representing
                the kernel matrix for inputs `xs` and `ys`.
            'gradient': A `tf.Tensor` object of shape (N x Kx x Ky x D)
                representing the gradient of the kernel with respect to `xs`.

    Reference:
        [2] Qiang Liu, Dilin Wang, "Stein Variational Gradient Descent: A General
            Purpose Bayesian Inference Algorithm," Neural Information Processing
            Systems (NIPS), 2016.
    """
    Kx, D = xs.get_shape().as_list()[-2:]
    Ky, D2 = ys.get_shape().as_list()[-2:]
    assert D == D2

    leading_shape = tf.shape(input=xs)[:-2]

    # Compute the pairwise distances of left and right particles.
    diff = tf.expand_dims(xs, -2) - tf.expand_dims(ys, -3)
    # ... x Kx x Ky x D

    if LooseVersion(tf.__version__) <= LooseVersion('1.5.0'):
        # Older TF releases only accept the `keep_dims` keyword for this op.
        dist_sq = tf.reduce_sum(input_tensor=diff**2, axis=-1, keep_dims=False)
    else:
        dist_sq = tf.reduce_sum(input_tensor=diff**2, axis=-1, keepdims=False)
    # ... x Kx x Ky

    # Get median.
    input_shape = tf.concat((leading_shape, [Kx * Ky]), axis=0)
    values, _ = tf.nn.top_k(
        input=tf.reshape(dist_sq, input_shape),
        k=(Kx * Ky // 2 + 1),  # This is exactly true only if Kx*Ky is odd.
        sorted=True)  # ... x (Kx * Ky // 2 + 1)

    medians_sq = values[..., -1]  # ... (shape) (last element is the median)

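    # Bandwidth heuristic from the SVGD paper [2]: with h = median^2 / log(Kx),
    # each particle's total kernel weight sum_j kappa(x_i, x_j) is roughly
    # Kx * exp(-log(Kx)) = 1.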
    h = medians_sq / np.log(Kx)  # ... (shape)
    h = tf.maximum(h, h_min)
    h = tf.stop_gradient(h)  # Just in case.
    h_expanded_twice = tf.expand_dims(tf.expand_dims(h, -1), -1)
    # ... x 1 x 1

    kappa = tf.exp(-dist_sq / h_expanded_twice)  # ... x Kx x Ky

    # Construct the gradient
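    # d/dx exp(-||x - y||^2 / h) = -(2 / h) * (x - y) * exp(-||x - y||^2 / h),
    # computed below for every (x, y) particle pair and every dimension.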
    h_expanded_thrice = tf.expand_dims(h_expanded_twice, -1)
    # ... x 1 x 1 x 1
    kappa_expanded = tf.expand_dims(kappa, -1)  # ... x Kx x Ky x 1

    kappa_grad = -2 * diff / h_expanded_thrice * kappa_expanded
    # ... x Kx x Ky x D

    return {"output": kappa, "gradient": kappa_grad}

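In soft Q-learning this dictionary feeds the SVGD update of the sampling network: "output" supplies the attractive kernel weights and "gradient" the repulsive term that keeps the particles spread out. Below is a minimal sketch of that update direction, assuming grad_log_p holds the action-gradient of the soft Q-function (the target log-density up to a constant) at the fixed particles xs, with shape N x Kx x D; the function and argument names are hypothetical rather than taken from the repository.

def svgd_direction(kernel_dict, grad_log_p):
    """Stein variational direction: mean over the Kx source particles of
    kappa * grad_log_p (attraction) plus grad_x kappa (repulsion)."""
    kappa = tf.expand_dims(kernel_dict["output"], axis=-1)  # N x Kx x Ky x 1
    grads = tf.expand_dims(grad_log_p, axis=2)              # N x Kx x 1 x D
    return tf.reduce_mean(
        kappa * grads + kernel_dict["gradient"], axis=1)    # N x Ky x D

The result gives, for each of the Ky "updated" particles, the direction along which it should be moved (or backpropagated through the sampling network) to better approximate the target distribution.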