首页 > 其他分享 >昇思MindSpore进阶教程-参数初始化

昇思MindSpore进阶教程-参数初始化

时间:2024-09-24 08:55:27浏览次数:3  
标签:初始化 arr 进阶 shape 参数 fan MindSpore out

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

使用内置参数初始化

MindSpore提供了多种网络参数初始化的方式,并在部分算子中封装了参数初始化的功能。本节以Conv2d为例,分别介绍如何使用Initializer子类,字符串进行参数初始化。

Initializer初始化

Initializer是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn中提供的神经网络层封装均提供weight_init、bias_init等入参,可以直接使用实例化的Initializer进行参数初始化。样例如下:

import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore.common.initializer import Normal, initializer

input_data = ms.Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))
# 卷积层,输入通道为3,输出通道为64,卷积核大小为3*3,权重参数使用正态分布生成的随机数
net = nn.Conv2d(3, 64, 3, weight_init=Normal(0.2))
# 网络输出
output = net(input_data)

字符串初始化

除使用实例化的Initializer外,MindSpore也提供了参数初始化简易方法,即使用参数初始化方法名称的字符串。此方法使用Initializer的默认参数进行初始化。样例如下:

import numpy as np
import mindspore.nn as nn
import mindspore as ms

net = nn.Conv2d(3, 64, 3, weight_init='normal')
output = net(input_data)

自定义参数初始化

通常情况下,MindSpore提供的默认参数初始化可以满足常用神经网络层的初始化需求,在遇到需要自定义的参数初始化方法时,可以继承Initializer自定义参数初始化方法。下面以XavierNormal为例介绍自定义参数初始化方法:

import math
import numpy as np
from mindspore.common.initializer import Initializer


def _calculate_fan_in_and_fan_out(arr):
    # 计算fan_in和fan_out。fan_in是 `arr` 中输入单元的数量,fan_out是 `arr` 中输出单元的数量。
    shape = arr.shape
    dimensions = len(shape)
    if dimensions < 2:
        raise ValueError("'fan_in' and 'fan_out' can not be computed for arr with fewer than"
                         " 2 dimensions, but got dimensions {}.".format(dimensions))
    if dimensions == 2:  # Linear
        fan_in = shape[1]
        fan_out = shape[0]
    else:
        num_input_fmaps = shape[1]
        num_output_fmaps = shape[0]
        receptive_field_size = 1
        for i in range(2, dimensions):
            receptive_field_size *= shape[i]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


class XavierNormal(Initializer):
    def __init__(self, gain=1):
        super().__init__()
        # 配置初始化所需要的参数
        self.gain = gain

    def _initialize(self, arr): # arr为需要初始化的Tensor
        fan_in, fan_out = _calculate_fan_in_and_fan_out(arr) # 计算fan_in, fan_out值

        std = self.gain * math.sqrt(2.0 / float(fan_in + fan_out)) # 根据公式计算std值
        data = np.random.normal(0, std, arr.shape) # 使用numpy构造初始化好的ndarray

        arr[:] = data[:] # 将初始化好的ndarray赋值到arr

完成自定义初始化方法后,我们可以像内置初始化方法一样进行调用:

net = nn.Conv2d(3, 64, 3, weight_init=XavierNormal())
# 网络输出
output = net(input_data)

Cell遍历初始化

除了使用weight_init, bias_init等mindspore.nn接口提供的入参外,我们也习惯于先构造完整神经网络,然后对weight、bias等参数进行统一管理。此时需要先构造网络并实例化,然后对Cell进行遍历,并对参数进行赋值。下面是一个简单的样例:

for name, param in net.parameters_and_names():
    if 'weight' in name:
        param.set_data(initializer(Normal(), param.shape, param.dtype))
    if 'bias' in name:
        param.set_data(initializer('zeros', param.shape, param.dtype))

标签:初始化,arr,进阶,shape,参数,fan,MindSpore,out
From: https://blog.csdn.net/weixin_42553583/article/details/142478518

相关文章

  • Vue 2&3进阶面试题:(第五天)
    目录17.keep-alive18.$router和$route的区别19.vue-router路由模式有几种?20.vue的路由传参param和query的区别17.keep-alivekeep-alive是Vue的内置组件,当它包裹动态组件时,会缓存不活动的组件实例,而不是销毁它们。keep-alive是一个抽象组件:它自身不会渲染成一个DO......
  • 筛质数(线性筛法--进阶版)(面对大部分都直接ac)
    给定一个正整数 n,请你求出 1∼n中质数的个数。输入格式共一行,包含整数 n。输出格式共一行,包含一个整数,表示 1∼n中质数的个数。数据范围1≤n≤10^6输入样例:8输出样例:4思路:给一个数:将质数筛到的同时,筛去它的倍数,并且该倍数一定是在给定的数内的这样在下次......
  • “RAII资源获取就是初始化”的好处
    RAII指的是“资源获取就是初始化”(ResourceAllocationIsInitialization),它被视作C++中最强大的编程范式之一。简单说来,它指的是,用构造函数来获取一个对象的资源,相应的,借助析构函数来释放对象的资源。为了理解这一范式的用处,让我们考虑某个函数使用文件句柄时的情况:voiddoSo......
  • SpringBoot 初始化资源
    1、使用接口ApplicationRunner和CommandLineRunner这两个接口都是在容器运行后执行的,如下图示 如果项目需要在系统启动时,初始化资源,可以继承这两个接口,实现诸如缓存预热、DB连接等。实现ApplicationRunner接口@ComponentpublicclassMyApplicationRunnerimplementsApp......
  • C++类成员变量初始化顺序
    C++类成员变量初始化顺序类成员初始化顺序与其在类中声明顺序一致。比如classDemo{public: Demo(intd) :_d1{d},_d2{_d1+10} { } voidshow(){ std::cout<<"d1="<<_d1<<std::endl; std::cout<<"d2="<<_d2<<std:......
  • Dockerfile全面指南:从基础到进阶,掌握容器化构建的核心工具
    目录Dockerfile全面指南:从基础到进阶,掌握容器化构建的核心工具引言一、什么是Dockerfile二、Dockerfile的基本结构三、Dockerfile的常见配置项1、多阶段构建(Multi-stageBuilds)2、缓存优化3、合并RUN命令四、Dockerfile使用须知五、一个完整的Dockerfile实......
  • 2024年华为杯研究生数学建模竞赛C题 波形机理建模+GBDT 进阶完整文章+代码+高级可视化
    2024年华为杯研究生数学建模竞赛C题波形机理建模+GBDT完整文章代码|进阶可视化全部问题已经更新完成,可视化图表20余张,代码量千余行,实在累到了…由于篇幅原因,此处放出部分内容供参考~完整内容可以从底部名片的群中获取~问题重述该题目围绕磁性元件的磁芯损耗建模......
  • 场景初始化
    获取初始化的元素//常规consttargetdom = document.getElementById('targetdom')//vue3consttargetdom = ref('targetdom')//reactconsttargetdom = ref('targetdom')初始化相机、场景、光源、renderconstcamera=newThree.PerspectiveCamera(......
  • Android15音频进阶之新播放器HwAudioSource(八十六)
    简介:CSDN博客专家、《Android系统多媒体进阶实战》一书作者新书发布:《Android系统多媒体进阶实战》......
  • C++ 列表初始化 {}
    花括号的形式{},进行列表初始化,在C++11中初始化变量到了全面的应用。可参看《C++Primer》P39P76P88等相关内容信息。Note:当我们提供一个类内初始值时,必须以符号=或者花括号表示。《C++Primer》P246。如下:classDog{public:Dog(intage):m_age(age){}......