首页 > 其他分享 >MindSpore反向传播配置关键字参数

MindSpore反向传播配置关键字参数

时间:2024-05-09 15:00:42浏览次数:21  
标签:msnp Tensor value 关键字 反向 import 参数 net MindSpore

技术背景

在MindSpore深度学习框架中,我们可以向construct函数传输必备参数或者关键字参数,这跟普通的Python函数没有什么区别。但是对于MindSpore中的自定义反向传播bprop函数,因为标准化格式决定了最后的两位函数输入必须是必备参数outdout用于接收函数值和导数值。那么对于一个自定义的反向传播函数而言,我们有可能要传入多个参数。例如这样的一个案例:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp

class Net(nn.Cell):
    def bprop(self, x, y=1, out, dout):
        return msnp.cos(x) + y
    def construct(self, x, y=1):
        return msnp.sin(x) + y

x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

但是因为在Python的函数传参规则下,必备参数必须放在关键字参数之前,也就是out和dout这两个参数要放在前面,否则就会出现这样的报错:

  File "test_rand.py", line 53
    def bprop(self, x, y=1, out, dout):
             ^
SyntaxError: non-default argument follows default argument

按照普通Python函数的传参规则,我们可以把y这个关键字参数的放到最后面去:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp

class Net(nn.Cell):
    def bprop(self, x, out, dout, y=1):
        return msnp.cos(x) + y
    def construct(self, x, y=1):
        return msnp.sin(x) + y

x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

经过这一番调整之后,我们发现没有报错了,可以正常输出结果,但是这个结果似乎不太正常:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 1.25169754e-06]))

因为这里x传入了一个近似的\(\pi\),所以在construct函数计算函数值时,得到的结果应该是\(\sin(\pi)+y\),那么这里面\(y\)取\(0\)和\(1\)所得到的结果都是对的。但是关键问题在反向传播函数的计算,原本应该是\(\cos(\pi)+y=y-1\),但是在这里输入的\(y=0\),而导数的计算结果却是\(0\)而不是正确结果\(-1\)。这就说明,在MindSpore的自定义反向传播函数中,并不支持传入关键字参数。

解决方案

刚好前面写了一篇关于PyTorch的文章,这篇文章中提到的两个Issue就针对此类问题。受到这两个Issue的启发,我们在MindSpore中如果需要自定义反向传播函数,可以这么写:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp

class Net(nn.Cell):
    def bprop(self, x, y, out, dout):
        return msnp.cos(x) + y if y is not None else msnp.cos(x)
    def construct(self, x, y=1):
        return msnp.sin(x) + y

x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

简单来说就是,把原本要传给bprop的关键字参数,转换成必备参数的方式进行传入,然后做一个条件判断:当给定了该输入的时候,执行计算一,如果不给定参数值,或者给一个None,执行计算二。上述代码的执行结果如下所示:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [-9.99998748e-01]))

这里输出的结果都是正确的。

当然,这里因为我们其实是强行把关键字参数按照顺序变成了必备参数进行输入,所以在顺序上一定要严格遵守bprop所定义的必备参数的顺序,否则计算结果也会出错:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp

class Net(nn.Cell):
    def bprop(self, x, w, y, out, dout):
        return w*msnp.cos(x) + y if y is not None else msnp.cos(x)
    def construct(self, x, w=1, y=1):
        return msnp.sin(x) + y

x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0, w=2))

输出的结果为:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))

那么很显然,这个结果就是因为在执行函数时给定的关键字参数跟必备参数顺序不一致,所以才出错的。

总结概要

继上一篇文章从Torch的两个Issue中找到一些类似的问题之后,可以发现深度学习框架对于自定义反向传播函数中的传参还是比较依赖于必备参数,而不是关键字参数,MindSpore深度学习框架也是如此。但是我们可以使用一些临时的解决方案,对此问题进行一定程度上的规避,只要能够自定义的传参顺序传入关键字参数即可。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/bprop-kwargs.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://www.cnblogs.com/dechinphy/p/18179248/torch

标签:msnp,Tensor,value,关键字,反向,import,参数,net,MindSpore
From: https://www.cnblogs.com/dechinphy/p/18182055/bprop-kwargs

相关文章

  • Mybatis if判断中使用了Ognl关键字导致报错解决方法
    mybatisxml中使用OGNL解析参数,如果直接使用了关键字则会导致解析失败。常见的关键字有:字段mybatis关键字bor(字符|)的英文xor字符^的英文and字符&&band字符&ed字符==neg字符!=lt字符<gt字符>lte字符<=......
  • 深入浅出const和static关键字
    constconst是constant的缩写,意为不变的。在C++中是用来修饰内置类型变量,自定义对象,成员/普通函数,返回值,函数参数。C++const允许指定一个语义约束,编译器会强制实施这个约束,允许程序员告诉编译器某值是保持不变的。如果在编程中确实有某个值保持不变,就应该明确使用const,这样可......
  • 五一反向旅游,景区“AI+视频监控”将持续助力旅游业发展
    一、建设背景每年五一劳动节出去旅游都是人挤人状态,这导致景区的体验感极差。今年“五一反向旅游”的话题冲上了热搜,好多人选择了五一之后再出去旅游,避开拥挤的人群,这个时候景区的监管力度和感知能力就更要跟上去!随着人工智能技术的持续发展,景区的视频监控系统也可以融入AI智能......
  • volatile关键字
    volatile关键字概要volatile修饰符并不是Java语言的首创,早在C和C++当中就已经存在。为了理解volatile关键字的作用和原理,需要先了解一些计算机基础知识。请先参考《什么是Java内存模型(JMM)?》我们知道,并发编程时,线程安全涉及三个特性:原子性、可见性、有序性。volatile用于保证......
  • [转帖]【MySQL】字段名与关键字冲突解决办法
    https://www.jianshu.com/p/50e59feb3e83   首先,不推荐使用MySQL的关键词来作为字段名,但是有时候的确没有注意,或者因为之前就这么写了,没办法,那怎么办呢?方法1,改字段名,改了肯定就没问题了。这个就不细说了。方法2,使用引号`来处理。  下面就详细的说明一下怎样使用方法......
  • 使用快捷键的方式把多个关键字文本快速替换(快速替换AE脚本代码)
     首先,需要用到的这个工具:度娘网盘提取码:qwu2蓝奏云提取码:2r1z这里做AE(AdobeAfterEffact)里的脚本规则,把英文替换成中文,如下swap=thisComp.layer(“Segmentsettings”).effect("%")(“Checkbox”);if(swap==true){s=thisComp.layer(“Segmentsettings”).effect(“Se......
  • VScode自定义折叠代码快 region和endregion 关键字
    前言全局说明VScode自定义折叠代码快region和endregion关键字一、说明vscode有自带的代码折叠功能,但是因为某些内容不是标准的代码或不被识别就不能正常被折叠比如很多的单行注释,或者被注释的代码就能不能自动折叠。这里就要用到region和endregion关键字使用时r......
  • 使用PowerDesigner连接数据库并反向工程生成所有表及关系
    配置对数据库的JDBC连接时,总是提示连接失败!也没有任何其他信息,查阅网上资料并实际验证,按如下步骤可以成功:1、因为PowerDesigner是32位的程序,需要使用x86-32位版本的JDK2、配置PowerDesigner-》Tools-》GeneralOptions-》variables ,配置jar、java等路径配置为32位JDK3、......
  • MindSpore强化学习:使用PPO配合环境HalfCheetah-v2进行训练
    本文分享自华为云社区《MindSpore强化学习:使用PPO配合环境HalfCheetah-v2进行训练》,作者:irrational。半猎豹(HalfCheetah)是一个基于MuJoCo的强化学习环境,由P.Wawrzyński在“ACat-LikeRobotReal-TimeLearningtoRun”中提出。这个环境中的半猎豹是一个由9个链接和8个关节......
  • Nginx反向代理的好处
    负载均衡:好处:负载均衡可以将传入的请求分发到多个后端服务器,从而提高系统的性能和可靠性,同时避免单个服务器过载。例子:假设有一个电子商务网站,每天有大量用户同时访问,使用Nginx的负载均衡功能可以将请求分发到多个商品服务器上,确保每个用户都能够快速访问到商品信息,而不会因......