首页 > 其他分享 >手搓自动微分

手搓自动微分

时间:2024-04-18 16:15:39浏览次数:11  
标签:__ cos obj self 微分 自动 np grad

技术背景

自动微分技术,在各大深度学习框架里面得到了广泛的应用。但是其实究其原理,就是一个简单的链式法则。要实现一个自动微分框架是非常容易的事情,难的是高阶的自动微分和端到端的自动微分。这篇文章主要介绍一阶自动微分的基础Python实现,以及一些简单的测试案例。

链式法则

求导的链式法则,这个在高数里面大家就都学过了,形式比较简单:

\[f(g(x))'=f'[g(x)]\cdot g'(x) \]

或者可以写成这种形式:

\[\frac{df}{dx}=\frac{df}{dg}\cdot\frac{dg}{dx} \]

自动微分框架的使用

我们先用一些现成的自动微分框架,如MindSpore,来演示一下自动微分的基本用法:

import numpy as np
from mindspore import grad, Tensor
from mindspore import numpy as msnp
# 定义一个自变量x
x = Tensor(np.array([1., 2., 3.], np.float32))
# 定义一个复合函数
f = lambda x: msnp.sin(msnp.cos(x))
# 函数求导
gf = grad(f)
# 计算自动微分结果
print (gf(x))
# [-0.7216062  -0.831692   -0.07743199]

这里面的函数定义为:

\[f(x) = \sin(\cos(x)) \]

其导数解析形式为:

\[f'(x)=-\cos(\cos(x))\sin(x) \]

也可以用MindSpore做一个简单的验证:

print (-msnp.cos(msnp.cos(x))*msnp.sin(x))
# [-0.7216062  -0.831692   -0.07743199]

可以看到结果是一致的。

手搓自动微分

自己实现自动微分,其实就是把每一个操作函数的导数函数定义好,例如我们可以定义某一个操作的求导函数为__grad__(),而求值函数在python中有一个内置的__call__()函数。例如我们可以基于numpy的函数自定义一个正弦函数的类:

import numpy as np
class SIN:
    def __call__(self, x):
        """计算正弦值"""
        return np.sin(x)
    def __grad__(self, x):
        """计算正弦函数的导数值"""
        return np.cos(x)

然后配套一个grad自动微分函数:

def grad(obj):
    """直接调用输入操作的自动微分函数"""
    return obj.__grad__

甚至可以实现一个value_and_grad函数,同时计算值和导数:

class ValueAndGrad:
    def __init__(self, obj):
        """初始化输入对象的求值函数和求导函数"""
        self.obj1 = obj
        self.obj2 = obj.__grad__
    def __call__(self, x):
        """用元组的形式将值和导数的计算结果返回"""
        return (self.obj1(x), self.obj2(x))
def value_and_grad(obj):
    """初始化求值求导对象"""
    return ValueAndGrad(obj)

需要注意的是,因为大多数的场景下都会涉及到复合函数的计算,这也是自动微分技术的核心之一,因此我们自己实现的自动微分框架要能够接收一些外来的操作,然后在内部递归的计算。对应的带有自动微分的类格式变为:

class SIN:
    def __init__(self, obj=None):
        """给定一个其他的函数"""
        self.obj = obj
    def __call__(self, x):
        """没有复合函数时直接返回结果,有复合函数就递归计算"""
        return np.sin(x) if self.obj is None else np.sin(self.obj(x))
    def __grad__(self, x):
        """没有复合函数时直接返回导数结果,有复合函数就按照链式法则递归计算"""
        return COS()(x) if self.obj is None else COS()(self.obj(x))*self.obj.__grad__(x)

最终形成的自动微分实现案例为:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore import grad as msgrad
from mindspore import numpy as msnp

class SIN:
    """自定义正弦类"""
    def __init__(self, obj=None):
        self.obj = obj
    def __call__(self, x):
        return np.sin(x) if self.obj is None else np.sin(self.obj(x))
    def __grad__(self, x):
        return COS()(x) if self.obj is None else COS()(self.obj(x))*self.obj.__grad__(x)
    
class COS:
    """自定义余弦类"""
    def __init__(self, obj=None):
        self.obj = obj
    def __call__(self, x):
        return np.cos(x) if self.obj is None else np.cos(self.obj(x))
    def __grad__(self, x):
        return -SIN()(x) if self.obj is None else -SIN()(self.obj(x))*self.obj.__grad__(x)

class ValueAndGrad:
    """自定义求值求导类"""
    def __init__(self, obj):
        self.obj1 = obj
        self.obj2 = obj.__grad__
    def __call__(self, x):
        return (self.obj1(x), self.obj2(x))

def grad(obj):
    """自定义求导函数"""
    return obj.__grad__

def value_and_grad(obj):
    """自定义求值求导函数"""
    return ValueAndGrad(obj)

# 定义自变量
x = np.array([0., 1., 2., 3.,], np.float32)
# 单体函数验证
assert np.allclose(SIN()(x), np.sin(x))
# 单体函数求导验证
assert np.allclose(grad(SIN())(x), np.cos(x))
v, g = value_and_grad(SIN())(x)
# 单体函数求值求导验证
assert np.allclose(v, np.sin(x))
assert np.allclose(g, np.cos(x))
# 双复合函数验证
assert np.allclose(SIN(SIN())(x), np.sin(np.sin(x)))
assert np.allclose(SIN(COS())(x), np.sin(np.cos(x)))
assert np.allclose(COS(SIN())(x), np.cos(np.sin(x)))
assert np.allclose(COS(COS())(x), np.cos(np.cos(x)))
# 三复合函数验证
assert np.allclose(SIN(COS(SIN()))(x), np.sin(np.cos(np.sin(x))))
# 双复合函数求导验证
assert np.allclose(grad(SIN(SIN()))(x), np.cos(x)*np.cos(np.sin(x)))
tensor_x = Tensor(x, ms.float32)
ms_func1 = lambda x: msnp.sin(msnp.cos(x))
assert np.allclose(grad(SIN(COS()))(x), msgrad(ms_func1)(tensor_x).asnumpy())
ms_func2 = lambda x: msnp.cos(msnp.sin(x))
assert np.allclose(grad(COS(SIN()))(x), msgrad(ms_func2)(tensor_x).asnumpy())
ms_func3 = lambda x: msnp.cos(msnp.sin(msnp.cos(x)))
# 三复合函数求导验证
assert np.allclose(grad(COS(SIN(COS())))(x), msgrad(ms_func3)(tensor_x).asnumpy())

这里面除了可以跟手推的微分解析形式的计算结果进行比对之外,还可以跟MindSpore等自动微分框架计算出来的结果进行比对,可以看到结果都是一致的。

总结概要

不同于符号微分、手动微分和差分法,自动微分方法有着使用简单、计算精度较高、性能较好等优势,因此在各大深度学习框架中得到了广泛的应用。虽然每个框架所使用的自动微分的原理不尽相同,但大致都是基于链式法则计算结合图计算的一些优化。如果是自己动手来手搓一个自动微分框架的话,大致就只能实现一下一阶的链式法则的自动微分。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/auto-grad.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

标签:__,cos,obj,self,微分,自动,np,grad
From: https://www.cnblogs.com/dechinphy/p/18143594/auto-grad

相关文章

  • gaussdb通过编写shell脚本自动化执行查询和结果收集
    转:https://support.huaweicloud.com/pwp-dws/dws_13_00033.html1、登录ECS,进入到/opt目录下,使用vim命令生成query.conf和run_query.sh两个脚本文件。脚本内容如下,编辑后按:wq!保存脚本配置:run_query.sh脚本如下:#!/bin/bashscript_path=$(cd`dirname$0`;pwd)query_mode=$1......
  • h5 自适应页面背景图无法自动适应的问题
     有时候制作好自适应页面,上面用的是背景图,发现在手机中,宽度会出现不能100%显示的问题,虽然样式中我们使用了width为100%。加入以下语句<html><head><metaname="viewport"content="width=1200px">当只设置width属性值,而不指定initial-scale属性值时,大多数浏览器......
  • 自动生成接口文档coreapi
    drf-yasg只能用于drf去看官方文档2coreapipipinstallcoreapi2.1配置路由fromrest_framework.documentationimportinclude_docs_urlsurlpatterns=[...path('docs/',include_docs_urls(title='站点页面标题'))]2.2drf配置#AttributeError:'......
  • 在博客园平台为博客自动化添加目录
    一、效果预览二、操作方法在设置-页脚HTML代码中添加如下代码:<scriptlanguage="javascript"type="text/javascript">//生成目录索引列表//ref:http://www.cnblogs.com/wangqiguo/p/4355032.html//modifiedby:zzqfunctionGenerateContentList(){varmainC......
  • Python-自动化秘籍(一)
    Python自动化秘籍(一)原文:zh.annas-archive.org/md5/de38d8b70825b858336fa5194110e245译者:飞龙协议:CCBY-NC-SA4.0前言我们都可能花费时间进行一些不太有价值的小手动任务。可能是在信息来源中搜索相关信息的小片段,使用电子表格一遍又一遍生成相同的图表,或者逐个搜索文件......
  • Python-自动化秘籍(二)
    Python自动化秘籍(二)原文:zh.annas-archive.org/md5/de38d8b70825b858336fa5194110e245译者:飞龙协议:CCBY-NC-SA4.0第三章:构建您的第一个Web抓取应用程序在本章中,我们将涵盖以下内容:下载网页解析HTML爬取网络订阅源访问WebAPI与表单交互使用Sel......
  • GridControl列自动匹配宽度(转)
    //自动调整所有字段宽度this.gridView1.BestFitColumns();//调整某列字段宽度this.gridView1.Columns[n].BestFit(); 大多是网上零散找到的,小部分是自己使用的时候自己遇到的。 XtraGrid的关键类就是:GridControl和GridView。GridControl本身不显示数据,数据都是显示在Grid......
  • 接口自动化测试工程实践分享
    本文作者:欧海锋,碧桂园服务高级测试工程师,致力于研究测试技术。一、前言接口自动化测试是一种软件测试技术,它通过模拟用户系统操作来对系统的接口进行自动化测试。接口自动化测试的目的是为了提高测试效率和准确性,同时降低测试成本和周期。以下是为什么需要进行接口自动化测试的......
  • OPC DA通信,自动读写数据
    主打的就是简单,使用非常简单!opcDaTags.Add(newOpcDaTag("numeric.random.int32"));opcDaTags.Add(newOpcDaTag("time.current"));opcDaTags.Add(newOpcDaTag("textual.weekday"));opcDaTags.Add(......
  • 新连点器和bat不弹黑窗口且自动获取管理员权限
    标题好长新的连点器相比原来那个c语言版,这次使用python编写,添加了简单的图形界面,参数调整非常简单(指的是直接编辑源码)直接贴完整代码:#导入模块importtkinterastkimportthreadingimportpyautoguiimportkeyboard#定义全局变量running=False#是否开启连点int......