首页 > 其他分享 >昇思25天学习打卡营第8天|模型权重与 MindIR 的保存加载

昇思25天学习打卡营第8天|模型权重与 MindIR 的保存加载

时间:2024-07-03 13:03:06浏览次数:15  
标签:load 25 MindIR 模型 param 打卡 model mindspore 加载

目录

导入Python 库和模块

创建神经网络模型

保存和加载模型权重

保存和加载MindIR


导入Python 库和模块


        上一章节着重阐述了怎样对超参数予以调整,以及如何开展网络模型的训练工作。在网络模型训练的整个进程当中,事实上我们满怀期望能够留存中间阶段以及最终的成果,以便用于细微的调整(fine-tune)以及后续的模型推理和部署操作。在本章节,我们将会为您介绍怎样去保存以及加载模型。

        首先,我们进行了一系列的 Python 库和模块的导入操作:我们导入了 NumPy 库,并将其简称为 np 。要知道,NumPy 通常被广泛应用于数值计算领域以及数组相关的操作之中。此外,我们还导入了 MindSpore 库,MindSpore 乃是一个极为出色的深度学习框架。不仅如此,我们从 MindSpore 库中导入了 nn 模块,这里面或许涵盖了与神经网络相关联的各类类和函数。最后,我们还从 MindSpore 库中导入了 Tensor 类,其主要作用在于创建张量这种数据结构。

        代码如下:

import numpy as np  
import mindspore  
from mindspore import nn  
from mindspore import Tensor  

创建神经网络模型


        定义了一个被称作“network”的函数,此函数旨在创建一个神经网络模型。在该函数的内部,通过运用“nn.SequentialCell”构建了一个按照顺序相互连接的神经网络。最终,这个函数会返回构建完成的模型。

        代码如下:

def network():  
    model = nn.SequentialCell(  
                #用于将输入数据展平为一维向量  
                nn.Flatten(),  
                #全连接层,输入维度为 28*28,输出维度为 512。  
                nn.Dense(28*28, 512),  
                #激活函数 ReLU 层。  
                nn.ReLU(),  
                #全连接层,输入维度为 512,输出维度为 512。  
                nn.Dense(512, 512),  
                #激活函数 ReLU 层。  
                nn.ReLU(),  
                #全连接层,输入维度为 512,输出维度为 10。  
                nn.Dense(512, 10))  
    return model  

保存和加载模型权重


        当对模型进行保存操作时,将采用 save_checkpoint 这一接口,并将网络和特定指定的保存路径传入其中。

        代码如下:

model = network()  
mindspore.save_checkpoint(model, "model.ckpt")  

        分析:在 MindSpore 框架中,“model = network()”这行代码一般而言是创建了一个被命名为“model”的对象。此对象是通过对名为“network”的函数或者类的调用而得以生成。而“mindspore.save_checkpoint(model, "model.ckpt")”这行代码,其发挥的作用是借助 MindSpore 框架所提供的“save_checkpoint”函数,把创建好的“model”对象的当前状态保存至一个叫做“model.ckpt”的文件之中。之所以要进行这样的操作,通常是出于如下目的:在后续的一系列操作里,能够重新加载这个模型的状态,从而便于继续开展训练工作、执行预测任务,或者实现模型的迁移以及部署等相关操作。

        要实现模型权重的加载,第一步是创建相同的模型实例,接下来则要通过 load_checkpoint 和 load_param_into_net 方法对参数予以加载。

        代码如下:

model = network()  
param_dict = mindspore.load_checkpoint("model.ckpt")  
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)  
print(param_not_load)  

         分析:首先,通过 model = network() 创建了一个名为 model 的网络对象。

        然后,使用 mindspore.load_checkpoint("model.ckpt") 从名为 "model.ckpt" 的文件中加载模型的检查点数据,并将其存储在 param_dict 中。

        接着,通过 mindspore.load_param_into_net(model, param_dict) 尝试将加载的参数数据 param_dict 加载到模型 model 中。同时,返回未成功加载的参数以及一个相关的标识,未成功加载的参数存储在 param_not_load 中。

        最后,使用 print(param_not_load) 输出未成功加载的参数。

        运行结果:

        []

保存和加载MindIR


        除了 Checkpoint 之外,MindSpore 为云侧(训练)和端侧(推理)提供了统一的中间表示(Intermediate Representation,IR)。用户能够通过 export 接口,直接将模型保存为 MindIR 格式。这种统一的中间表示和便捷的模型保存方式,为模型的训练和推理提供了高效且便捷的支持,极大地提升了开发和应用的效率。

        代码如下:

model = network()  
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))  
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")  
mindspore.set_context(mode=mindspore.GRAPH_MODE)  
graph = mindspore.load("model.mindir")  
model = nn.GraphCell(graph)  
outputs = model(inputs)  
print(outputs.shape) 

        分析:首先定义了一个叫做 model 的网络模型,接着准备了一个输入数据 inputs ,这个数据的值全是 1 ,并且是张量的形式。然后通过 mindspore.export 把模型和这个输入保存成 MINDIR 格式的文件,文件名就叫 model 。接下来设置 MindSpore 的运行环境为图模式。再去加载之前保存的 model.mindir 文件,并把它转变为 GraphCell 类型的模型。之后使用之前准备好的输入数据 inputs 来对模型进行推理运算,从而得到输出 outputs 。最后把输出的形状给打印出来。

        运行结果:

        (1, 10)

        运行截图:

标签:load,25,MindIR,模型,param,打卡,model,mindspore,加载
From: https://blog.csdn.net/chinayun_6401/article/details/140148996

相关文章

  • YC307B [ 20240625 CQYC省选模拟赛 T2 ] 一个题(ynoi)
    题意你需要维护一个可重集\(S\),支持插入删除以及查询最大的方案使得给定正整数\(k\),划分为\(k\)个非空子集的按位与结果之和最大。\(n\le10^5\)Sol先上个trie。然后考虑一次查询怎么搞。先转化一下,如果需要划分为\(k\)个子集,显然需要合并\(n-k\)次。我们只......
  • Spring Boot 中 PGSQL 判断打卡点是否经过轨迹优化代码,循环查询物理表修改生成临时表,
    记录一下一个业务问题,流程是这样的,我现在有一个定时任务,5分钟执行一次,更新车辆打卡的情况。现在有20俩车,每辆车都分配了路线,每条路线都有打卡点,每个打卡点分配了不同的时间段,也就是说,一条路线可能有几百个打卡点,这几百个打卡点中每一个都分配了时间段,有可能是1个时间段,比如8......
  • 2025秋招计算机视觉面试题(七)-NMS详细工作机制及代码实现
    问题看到一句话:NMS都不懂,还做什么Detection!虎躯一震……懂是大概懂,但代码能写出来吗???在目标检测网络中,产生proposal后使用分类分支给出每个框的每类置信度,使用回归分支修正框的位置,最终会使用NMS方法去除同个类别当中IOU重叠度较高且scores即置信度较低的那些......
  • CS253 Laboratory session
    CS253 Laboratorysession4Part 1: Disassembling code, going backwards, converting an executable back to Assembly Language.Preamble: Remember that whatever language you are using ultimately it runs as Machine Code onthe processor......
  • ADS1256芯片说明
    本篇文章先总结一下24位的8通道24bit高精度采集的24位ADS1256,本篇文章不是纯粹的datasheet的抄袭,而是datasheet的总结,高度概括,以及对我们编程有用的思路,我大概看了一下网上流传的版本,大多数都是STM32,另外有一份是verilog不知道是谁写的,各个网都有,它是多通道采集,仅仅使用了一种模式......
  • UOJ #807. 【UR #25】装配序列
    题面传送门首先根据Dliworth定理,原问题等价于前缀LIS。考虑如何做到\(O(n^2)\)求出LIS的变化点(显然这只有\(n\)个)。按照值从小到大考虑,记\(f_{i,j}\)表示考虑到第\(i\)个值,长度为\(j\)的LIS最早在哪个前缀处出现,转移只需要two-pointers一遍就能更新。这个转......
  • 2023-2025年最值得选择的Java毕业设计选题大全:1000个热门选题推荐✅✅✅
    ......
  • 【打卡】003 p3 Pytorch实现天气识别
    打卡~555我的环境:●语言环境:Python ●编译器:jupyternotebook●深度学习环境:Pytorch>-**......
  • 昇思25天学习打卡营第13天| 数据变换 Transforms
    IT专业入门,高考假期预习指南七月来临,各省高考分数已揭榜完成。而高考的完结并不意味着学习的结束,而是新旅程的开始。对于有志于踏入IT领域的高考少年们,这个假期是开启探索IT世界的绝佳时机。作为该领域的前行者和经验前辈,你是否愿意为准新生们提供一份全面的学习路线图呢?快来......
  • 昇思25天学习打卡营第12天|网络构建
    IT专业入门,高考假期预习指南七月来临,各省高考分数已揭榜完成。而高考的完结并不意味着学习的结束,而是新旅程的开始。对于有志于踏入IT领域的高考少年们,这个假期是开启探索IT世界的绝佳时机。作为该领域的前行者和经验前辈,你是否愿意为准新生们提供一份全面的学习路线图呢?快来......