首页 > 其他分享 >jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别

jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别

时间:2024-01-27 20:05:05浏览次数:13  
标签:jnp jax 为例 jacrev hession 求导 time print hessian

注意:本文相关基础知识不介绍。


给出代码:

from jax import jacfwd, jacrev
import jax.numpy as jnp
def hessian_1(f):
    return jacfwd(jacrev(f))

def hessian_2(f):
    return jacfwd(jacfwd(f))

def hessian_3(f):
    return jacrev(jacfwd(f))

def hessian_4(f):
    return jacrev(jacrev(f))


def f(x):
    return (x ** 2).sum()

print(hessian_1(f)(jnp.ones((100,))))
print(hessian_2(f)(jnp.ones((100,))))
print(hessian_3(f)(jnp.ones((100,))))
print(hessian_4(f)(jnp.ones((100,))))

import time

a=time.time()
hessian_1(f)(jnp.ones((100,)))
b=time.time()
print(b-a)

hessian_2(f)(jnp.ones((100,)))
c=time.time()
print(c-b)

hessian_3(f)(jnp.ones((100,)))
d=time.time()
print(d-b)

hessian_4(f)(jnp.ones((100,)))
e=time.time()
print(e-d)


运算结果:

jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别_系统

jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别_系统_02

jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别_系统_03


结论(不一定正确):

两次求导均使用后向模式的要比两次求导均使用前向模式的要速度快,并且两次求导使用相同模式的要比两次求导分别使用不同模式的速度要快;

第一次求导使用后向模式,第二次求导使用前向模式,要比第一次求导使用前向模式,第二次求导使用反向模式的速度要快。



修改代码:

from jax import jacfwd, jacrev
import jax.numpy as jnp
from jax import jit

def hessian_1(f):
    return jacfwd(jacrev(f))

def hessian_2(f):
    return jacfwd(jacfwd(f))

def hessian_3(f):
    return jacrev(jacfwd(f))

def hessian_4(f):
    return jacrev(jacrev(f))


@jit
def f(x):
    return (x ** 2).sum()

x = jnp.ones((100,))

print(hessian_1(f)(x))
print(hessian_2(f)(x))
print(hessian_3(f)(x))
print(hessian_4(f)(x))

import time

a=time.time()
hessian_1(f)(x)
b=time.time()
print(b-a)

hessian_2(f)(x)
c=time.time()
print(c-b)

hessian_3(f)(x)
d=time.time()
print(d-b)

hessian_4(f)(x)
e=time.time()
print(e-d)


运算结果:

jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别_系统_04


得出另一种结论(之所以上下两次结论不同,个人估计是这个函数太过于简单造成的):

(不一定正确)

两次求导均使用后向模式的要比两次求导均使用前向模式的要速度慢;

第一次求导使用后向模式,第二次求导使用前向模式,要比第一次求导使用前向模式,第二次求导使用反向模式的速度要快。



标签:jnp,jax,为例,jacrev,hession,求导,time,print,hessian
From: https://blog.51cto.com/u_15642578/9444218

相关文章

  • 使用CPU运行大语言模型(LLM),以清华开源大模型ChatGLM3为例:无需显卡!用CPU搞定大模型运行
    教程视频地址:无需显卡!用CPU搞定大模型运行部署!【详细手把手演示】按照上面视频进行安装配置之前需要注意,python编程环境需要大于等于python3.10,否则会运行报错。下载好GitHub上的项目代码后需要运行pipinstall-rrequirements.txt配置好后运行效果:相关资料:【ChatGL......
  • jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别
    注意:本文相关基础知识不介绍。给出代码:fromjaximportjacfwd,jacrevimportjax.numpyasjnpdefhessian_1(f):returnjacfwd(jacrev(f))defhessian_2(f):returnjacfwd(jacfwd(f))defhessian_3(f):returnjacrev(jacfwd(f))defhessian_4(f):......
  • 2024-1-23AJAX的概念
    目录AJAX的概念小知识点箭头函数AJAX的作用axios的使用AJAX的概念简单可以理解为想指定的url获取指定的数据。小知识点箭头函数箭头函数是一种新的函数语法,旨在提供一种更简洁的方式来编写函数。它与传统的function相比比较容易传统函数格式varsum=function(a,b){r......
  • jax框架:jax.grad
    官方地址:https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad这里只给出几个样例代码:设置allow_int参数,实现对整数类型求导:未对整数类型求导:importjaxdeffun(x,y):print(x,y)returnjax.numpy.sum(2*x[0]+y[0]+2*x[1]+......
  • springmvc与ajax相互调用, 2.springmvc中如何拦截get请求
    通过JACKson框架可以把java里面的对象转化成js可以识别的json对象。具体步骤:1.加入Jack.jar2.在配置文件配置json映射3.在接受ajax方法里面一颗直接返回OBject,list等,但方法加@ResponseBody注解。  @RequestMapping注解中加上method=RequestMethod.GET参数就可以实现拦......
  • Jax框架:通过显存分析判断操作是否进行jit编译
    相关:https://jax.readthedocs.io/en/latest/device_memory_profiling.html代码:importjaximportjax.numpyasjnpimportjax.profilerdeffunc1(x):returnjnp.tile(x,10)*0.5deffunc2(x):y=func1(x)returny,jnp.tile(x,10)+1x=jax.random.......
  • Ajax(千锋)
    目录Ajax技术一.初识前后端交互AJAX的优势二.原生Ajax1.AJAX基础创建一个ajax对象配置链接信息发送请求一个基本的ajax请求ajax状态码readyStateChangeresponseText2.使用ajax发送请求时携带参数发送一个带有参数的get请求发送一个带有参数的post请求不同的请求......
  • Google的Jax框架的JAX-Triton目前只能成功运行在TPU设备上(使用Pallas为jax编写kernel
    使用Pallas为jax编写kernel扩展,需要使用JAX-Triton扩展包。由于Google的深度学习框架Jax主要是面向自己的TPU进行开发的,虽然也同时支持NVIDIA的GPU,但是支持力度有限,目前JAX-Triton只能在TPU设备上正常运行,无法保证在GPU上正常运行。该结果使用kaggle上的TPU和GPU进行测试获得。......
  • kaggle上的jax框架的环境配置(TPU版本)
    导出时间:2024-01-1821:00:37星期四python版本:Python3.10.13absl-py==1.4.0accelerate==0.25.0aiofiles==22.1.0aiosqlite==0.19.0anyio==4.2.0argon2-cffi==23.1.0argon2-cffi-bindings==21.2.0array-record==0.5.0arrow==1.3.0astroid==3.0.2asttokens==2.4......
  • 以新晋高速公路快村营至营盘段项目为例浅谈AcrelEMS-HIM高速公路综合能效系统的应用
    引言摘要:我国新型工业化、信息化、城镇化和农业现代化加快发展,经济结构加快转型,交通运输总量将保持较快增长态势,各项事业发展要求提高国家公路网的服务能力和水平。高速公路沿线的收费站、互通枢纽、服务区、隧道等配置的供配电、照明、通风、排水等机电设备的数量急聚增加,设计一套......