首页 > 其他分享 >jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别

jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别

时间:2024-01-24 19:47:04浏览次数:30  
标签:jnp jax 为例 jacrev hession 求导 time print hessian

注意:本文相关基础知识不介绍。


给出代码:

from jax import jacfwd, jacrev
import jax.numpy as jnp
def hessian_1(f):
    return jacfwd(jacrev(f))

def hessian_2(f):
    return jacfwd(jacfwd(f))

def hessian_3(f):
    return jacrev(jacfwd(f))

def hessian_4(f):
    return jacrev(jacrev(f))


def f(x):
    return (x ** 2).sum()

print(hessian_1(f)(jnp.ones((100,))))
print(hessian_2(f)(jnp.ones((100,))))
print(hessian_3(f)(jnp.ones((100,))))
print(hessian_4(f)(jnp.ones((100,))))

import time

a=time.time()
hessian_1(f)(jnp.ones((100,)))
b=time.time()
print(b-a)

hessian_2(f)(jnp.ones((100,)))
c=time.time()
print(c-b)

hessian_3(f)(jnp.ones((100,)))
d=time.time()
print(d-b)

hessian_4(f)(jnp.ones((100,)))
e=time.time()
print(e-d)

运算结果:
image

image

image


结论(不一定正确):

两次求导均使用后向模式的要比两次求导均使用前向模式的要速度快,并且两次求导使用相同模式的要比两次求导分别使用不同模式的速度要快;

第一次求导使用后向模式,第二次求导使用前向模式,要比第一次求导使用前向模式,第二次求导使用反向模式的速度要快。



修改代码:

from jax import jacfwd, jacrev
import jax.numpy as jnp
from jax import jit

def hessian_1(f):
    return jacfwd(jacrev(f))

def hessian_2(f):
    return jacfwd(jacfwd(f))

def hessian_3(f):
    return jacrev(jacfwd(f))

def hessian_4(f):
    return jacrev(jacrev(f))


@jit
def f(x):
    return (x ** 2).sum()

x = jnp.ones((100,))

print(hessian_1(f)(x))
print(hessian_2(f)(x))
print(hessian_3(f)(x))
print(hessian_4(f)(x))

import time

a=time.time()
hessian_1(f)(x)
b=time.time()
print(b-a)

hessian_2(f)(x)
c=time.time()
print(c-b)

hessian_3(f)(x)
d=time.time()
print(d-b)

hessian_4(f)(x)
e=time.time()
print(e-d)

运算结果:

image


得出另一种结论(之所以上下两次结论不同,个人估计是这个函数太过于简单造成的):

(不一定正确)

两次求导均使用后向模式的要比两次求导均使用前向模式的要速度慢;

第一次求导使用后向模式,第二次求导使用前向模式,要比第一次求导使用前向模式,第二次求导使用反向模式的速度要快。



标签:jnp,jax,为例,jacrev,hession,求导,time,print,hessian
From: https://www.cnblogs.com/devilmaycry812839668/p/17985697

相关文章

  • 2024-1-23AJAX的概念
    目录AJAX的概念小知识点箭头函数AJAX的作用axios的使用AJAX的概念简单可以理解为想指定的url获取指定的数据。小知识点箭头函数箭头函数是一种新的函数语法,旨在提供一种更简洁的方式来编写函数。它与传统的function相比比较容易传统函数格式varsum=function(a,b){r......
  • jax框架:jax.grad
    官方地址:https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad这里只给出几个样例代码:设置allow_int参数,实现对整数类型求导:未对整数类型求导:importjaxdeffun(x,y):print(x,y)returnjax.numpy.sum(2*x[0]+y[0]+2*x[1]+......
  • springmvc与ajax相互调用, 2.springmvc中如何拦截get请求
    通过JACKson框架可以把java里面的对象转化成js可以识别的json对象。具体步骤:1.加入Jack.jar2.在配置文件配置json映射3.在接受ajax方法里面一颗直接返回OBject,list等,但方法加@ResponseBody注解。  @RequestMapping注解中加上method=RequestMethod.GET参数就可以实现拦......
  • Jax框架:通过显存分析判断操作是否进行jit编译
    相关:https://jax.readthedocs.io/en/latest/device_memory_profiling.html代码:importjaximportjax.numpyasjnpimportjax.profilerdeffunc1(x):returnjnp.tile(x,10)*0.5deffunc2(x):y=func1(x)returny,jnp.tile(x,10)+1x=jax.random.......
  • Ajax(千锋)
    目录Ajax技术一.初识前后端交互AJAX的优势二.原生Ajax1.AJAX基础创建一个ajax对象配置链接信息发送请求一个基本的ajax请求ajax状态码readyStateChangeresponseText2.使用ajax发送请求时携带参数发送一个带有参数的get请求发送一个带有参数的post请求不同的请求......
  • Google的Jax框架的JAX-Triton目前只能成功运行在TPU设备上(使用Pallas为jax编写kernel
    使用Pallas为jax编写kernel扩展,需要使用JAX-Triton扩展包。由于Google的深度学习框架Jax主要是面向自己的TPU进行开发的,虽然也同时支持NVIDIA的GPU,但是支持力度有限,目前JAX-Triton只能在TPU设备上正常运行,无法保证在GPU上正常运行。该结果使用kaggle上的TPU和GPU进行测试获得。......
  • kaggle上的jax框架的环境配置(TPU版本)
    导出时间:2024-01-1821:00:37星期四python版本:Python3.10.13absl-py==1.4.0accelerate==0.25.0aiofiles==22.1.0aiosqlite==0.19.0anyio==4.2.0argon2-cffi==23.1.0argon2-cffi-bindings==21.2.0array-record==0.5.0arrow==1.3.0astroid==3.0.2asttokens==2.4......
  • 以新晋高速公路快村营至营盘段项目为例浅谈AcrelEMS-HIM高速公路综合能效系统的应用
    引言摘要:我国新型工业化、信息化、城镇化和农业现代化加快发展,经济结构加快转型,交通运输总量将保持较快增长态势,各项事业发展要求提高国家公路网的服务能力和水平。高速公路沿线的收费站、互通枢纽、服务区、隧道等配置的供配电、照明、通风、排水等机电设备的数量急聚增加,设计一套......
  • 以青岛公交车停车场为例浅谈电动汽车充电站的电气安全
    1引言1月14日日上午10点左右,青岛市市北区辽宁路63号公交停车场内,一辆报废公交车突然起火,由于大风天气,大火很快引燃了停在旁边的几辆报废车。消防人员快速赶到,迅速控制住火势。11时30分,停车场内的大火已经被完全扑灭,共有8辆公交车被烧毁,没有人员伤亡。消防人员正在现场进一步勘查具......
  • Google的jax框架在TPU上的循环控制 —— 向量计算设备的循环结构控制
    相关:https://jax.readthedocs.io/en/latest/pallas/tpu.html向量计算设备,如:GPU、TPU等,都是通过向量计算来进行加速的,因此在这类设备中进行向量计算的计算单元是成百上千的,但是进行结构控制的电路单元比较少,可以基本认为在向量设备中进行流程控制是标量的,而不是向量的,也就是说......