首页 > 其他分享 >JAX-中文文档-二-

JAX-中文文档-二-

时间:2024-06-21 14:24:48浏览次数:19  
标签:中文 jnp JAX jax tree jit 文档 print

JAX 中文文档(二)

原文:jax.readthedocs.io/en/latest/

JAX 教程

原文:jax.readthedocs.io/en/latest/tutorials.html

  • 快速入门
  • 关键概念

  • 即时编译

  • 自动向量化

  • 自动微分

  • 调试入门

  • 伪随机数

  • 使用 pytrees 工作

  • 分片计算入门

  • 有状态计算

关键概念

原文:jax.readthedocs.io/en/latest/key-concepts.html

本节简要介绍了 JAX 包的一些关键概念。

JAX 数组 (jax.Array)

JAX 中的默认数组实现是 jax.Array。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray 类型相似,但它也有一些重要的区别。

数组创建

我们通常不直接调用 jax.Array 构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy 提供了类似 NumPy 风格的数组构造功能,如 jax.numpy.zeros()jax.numpy.linspace()jax.numpy.arange() 等。

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array) 
True 

如果您在代码中使用 Python 类型注解,jax.Array 是 jax 数组对象的适当注释(参见 jax.typing 以获取更多讨论)。

数组设备和分片

JAX 数组对象具有一个 devices 方法,允许您查看数组内容存储在哪里。在最简单的情况下,这将是单个 CPU 设备:

x.devices() 
{CpuDevice(id=0)} 

一般来说,数组可能会在多个设备上 分片,您可以通过 sharding 属性进行检查:

x.sharding 
SingleDeviceSharding(device=CpuDevice(id=0)) 

在这里,数组位于单个设备上,但通常情况下,JAX 数组可以分布在多个设备或者多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅分片计算介绍## 变换

除了用于操作数组的函数外,JAX 还包括许多用于操作 JAX 函数的变换。这些变换包括

  • jax.jit(): 即时(JIT)编译;参见即时编译

  • jax.vmap(): 向量化变换;参见自动向量化

  • jax.grad(): 梯度变换;参见自动微分

以及其他几个。变换接受一个函数作为参数,并返回一个新的转换后的函数。例如,这是您可能如何对一个简单的 SELU 函数进行 JIT 编译:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0)) 
1.05 

通常情况下,您会看到使用 Python 的装饰器语法来应用变换以方便操作:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

jit()vmap()grad() 等变换对于有效使用 JAX 至关重要,我们将在后续章节中详细介绍它们。## 跟踪

变换背后的魔法是跟踪器的概念。跟踪器是数组对象的抽象替身,传递给 JAX 函数,以提取函数编码的操作序列。

您可以通过打印转换后的 JAX 代码中的任何数组值来看到这一点;例如:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x) 
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)> 

打印的值不是数组 x,而是代表 x 的关键属性的 Tracer 实例,比如它的 shapedtype。通过使用追踪值执行函数,JAX 可以确定函数编码的操作序列,然后在实际执行这些操作之前执行转换:例如 jit()vmap()grad() 可以将输入操作序列映射到变换后的操作序列。 ## Jaxprs

JAX 对操作序列有自己的中间表示形式,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是一个函数程序的简单表示,包含一系列原始操作。

例如,考虑我们上面定义的 selu 函数:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

我们可以使用 jax.make_jaxpr() 实用程序来将该函数转换为一个 jaxpr,给定特定的输入:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x) 
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) } 

与 Python 函数定义相比,可以看出它编码了函数表示的精确操作序列。我们稍后将深入探讨 JAX 内部的 jaxprs:jaxpr 语言。 ## Pytrees

JAX 函数和转换基本上操作数组,但实际上编写处理数组集合的代码更为方便:例如,神经网络可能会将其参数组织在具有有意义键的数组字典中。与其逐案处理这类结构,JAX 依赖于 pytree 抽象来统一处理这些集合。

以下是一些可以作为 pytrees 处理的对象的示例:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)] 
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5] 
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0] 

JAX 提供了许多用于处理 PyTrees 的通用实用程序;例如函数 jax.tree.map() 可以用于将函数映射到树中的每个叶子,而 jax.tree.reduce() 可以用于在树中的叶子上应用约简操作。

你可以在《使用 pytrees 教程》中了解更多信息。

即时编译

原文:jax.readthedocs.io/en/latest/jit-compilation.html

在这一部分,我们将进一步探讨 JAX 的工作原理,以及如何使其性能卓越。我们将讨论 jax.jit() 变换,它将 JAX Python 函数进行即时编译,以便在 XLA 中高效执行。

如何工作 JAX 变换

在前一节中,我们讨论了 JAX 允许我们转换 Python 函数的能力。JAX 通过将每个函数减少为一系列原始操作来实现这一点,每个原始操作代表一种基本的计算单位。

查看函数背后原始操作序列的一种方法是使用 jax.make_jaxpr()

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0)) 
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) } 

文档的理解 Jaxprs 部分提供了有关上述输出含义的更多信息。

重要的是要注意,jaxpr 不捕获函数中存在的副作用:其中没有对 global_list.append(x) 的任何内容。这是一个特性,而不是一个错误:JAX 变换旨在理解无副作用(也称为函数纯粹)的代码。如果 纯函数副作用 是陌生的术语,这在

标签:中文,jnp,JAX,jax,tree,jit,文档,print
From: https://www.cnblogs.com/apachecn/p/18260391

相关文章

  • JAX-中文文档-八-
    JAX中文文档(八)原文:jax.readthedocs.io/en/latest/自动微分手册原文:jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.htmlalexbw@,mattjj@JAX拥有非常通用的自动微分系统。在这本手册中,我们将介绍许多巧妙的自动微分思想,您可以根据自己的工作进行选择。i......
  • prometheus 中文说明告警指标
    https://blog.51cto.com/qiangsh/1977449主机和硬件监控可用内存指标主机中可用内存容量不足10%-alert:HostOutOfMemoryexpr:node_memory_MemAvailable_bytes/node_memory_MemTotal_bytes*100<10for:5mlabels:severity:warningannotations:......
  • 软件开发项目全套文档资料参考(规格说明书、详细设计、测试计划、验收报告)
     前言:在软件开发过程中,文档资料是非常关键的一部分,它们帮助团队成员理解项目需求、设计、实施、测试、验收等各个环节,确保项目的顺利进行。以下是各个阶段的文档资料概述:软件项目管理部分文档清单: 工作安排任务书,可行性分析报告,立项申请审批表,产品需求规格说明书,需求调研......
  • 中文检测插件
    大家都知道,做出海应用,尤其是在一些对中国不友好的国家做业务。全面去中文化至关重要。对于开发而言,在代码层如果只靠人为控制这个变量,尤其艰难。所以给大家安利一个我们自研的中文检测插件,他能在您开发过程中时刻检测您的输入是否含有中文。大家先看下效果。如果您有需要,烦......
  • 搜索硬编码中文
    老项目中常常有直接在代码里或者xml布局中硬编码中文的,在后期业务扩展做国际化翻译时,这就是一个巨大的坑,因为我们需要知道哪里硬编码了,然后提取到strings.xml中刚好我最近在弄这个,如何找到代码中所有的硬编码就是核心问题,下面记录下我的步骤 1.首先写好正则,直接百度也行^((?!......
  • MestReNova14.0中文版安装教程
    MestReNova14是一款专业级的核磁共振(NMR)与质谱(MS)数据分析软件,专注于化合物结构解析和验证。该软件以卓越的谱图处理能力和智能化算法为核心,提供自定义参数调整、自动峰识别、精准积分、耦合常数计算等功能。支持多种仪器数据格式导入,可高效处理一维至四维NMR谱图以及各类质谱数据......
  • PDF英语文档怎么翻译成中文?
    外语文献是我们学习和工作中经常遇到的难题,其中包含许多重要工作信息,精确地理解和翻译非常重要。但并不是所有格式的文件都能直接编辑和翻译。例如PDF格式的文件就无法直接进行编辑,当我们需要翻译PDF格式的外语文档时,应该使用什么工具呢?本篇文章就为你提供几个快速翻译PDF文件的方......
  • mac苹果窗口辅助工具:Magnet for mac 2.14.0中文免激活版
    Magnet是一款针对MacOS系统的窗口管理工具软件。它能够帮助用户更加高效地管理和组织桌面上的窗口,通过简单的快捷键操作,可以将窗口自动调整到指定的位置和大小,实现多窗口快速布局。Magnet还支持多显示器环境下的窗口管理,可以让用户更加轻松地在多屏幕之间切换和布局窗口。......
  • 032java jsp ssm大学生第二课堂成绩单系统学生思想道德技术修养文体活动管理(源码+数据
     项目技术:SSM+Maven等等组成,B/S模式+Maven管理等等。环境需要1.运行环境:最好是javajdk1.8,我们在这个平台上运行的。其他版本理论上也可以。2.IDE环境:IDEA,Eclipse,Myeclipse都可以。推荐IDEA;3.tomcat环境:Tomcat7.x,8.x,9.x版本均可4.硬件环境:windows7/8/101G......
  • 026java jsp ssm网络硬盘系统网站系统(源码+数据库+文档)
     项目技术:SSM+Maven等等组成,B/S模式+Maven管理等等。环境需要1.运行环境:最好是javajdk1.8,我们在这个平台上运行的。其他版本理论上也可以。2.IDE环境:IDEA,Eclipse,Myeclipse都可以。推荐IDEA;3.tomcat环境:Tomcat7.x,8.x,9.x版本均可4.硬件环境:windows7/8/101G......