首页 > 其他分享 >MATLAB神经网络——如何自定义属于自己的训练流程

MATLAB神经网络——如何自定义属于自己的训练流程

时间:2024-03-14 17:05:38浏览次数:27  
标签:自定义 训练 temp 神经网络 train MATLAB test data size

网络上大部分matlab神经网络训练流程都应用matlab内置的相关训练函数进行训练,如何让matlab神经网络训练过程拥有像pytorch一样的训练过程呢?本文将通过一个案例介绍如何利用matlab自定义自己的训练流程,希望对你有所启迪,让我们开始吧!

clear,clc

加载并处理原始数据  


我们使用mnist手写数字识别任务作为案例进行讲解。首先我们将原始数据进行加载。

data_train = load('mnist_train.csv');  % 加载训练数据
data_test = load('mnist_test.csv');  % 加载测试数据

此时我们可以发现训练数据维度为60000*785,测试数据维度为10000*785。你会发现,数据中第一列数据为标签,剩下的784项内容,则是把原本的28*28的图像展平后形成的数据。因此训练数据有60000张图片,测试数据有10000张图片。
 

 接下来我们先将标签提取出来。

labels_train = data_train(:,1);
labels_test = data_test(:,1);

我们将data中图像数据转化为BCHW的形式,即(批量大小,通道数,高度,宽度)的形式,例如data_train中图像为60000*684,则转换后为60000*1*28*28

images_train = reshape(data_train(:,2:end),[size(data_train,1),1,28,28]);
images_test = reshape(data_test(:,2:end),[size(data_test,1),1,28,28]);

网络结构设计

我们使用一个简单的卷积神经网络来完成此任务。网络结构如下图所示。

size_input = [28,28,1];  % 输入尺寸
size_output = 10;  % 输出尺寸

layers = [
    imageInputLayer(size_input,Normalization="none")
    convolution2dLayer(5,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(size_output)
    softmaxLayer];

net = dlnetwork(layers);  % 创建网络

让我们测试一下网络的输入输出吧!

test_data = images_train(1:20,:,:,:);
test_data = dlarray(test_data,'BCSS');  % 将matlab数组转换为dlarray数据,dlarray是matlab
% 中专门用来进行深度学习的一种数据类型,类似于pytorch中的Tensor张量
test_out = predict(net,test_data);

可以发现test_out为一个10*20的dlarray数据,符合网络的输入输出设计。

损失函数设计

要自定义自己的训练流程,合理的损失函数设计是必要的。在代码中,损失函数是为了跟踪变量的输入输出过程,得到网络参数的梯度信息以更新网络参数。

function [loss,Gradients,state] = netLoss_cal(net,inputs,labels)
%% 损失函数设计
[out,state] = forward(net, inputs);  % 网络前向传播获得输出
loss = crossentropy(out,labels);  % 交叉熵损失函数
Gradients = dlgradient(loss, net.Learnables);  % 获得梯度信息
end

设计训练过程

设计好损失函数之后,我们就可以开始自定义训练流程并更新网络参数了,让我们一起来看看吧!

num_train = 10;  % 训练次数
% 在训练的过程中,每一次训练都会把训练集中所有数据用到
batch_size = 256;  % 训练批量大小(指每一次训练时模型输入数据个数)
batch_size_test = 300;  % 测试模型时的输入个数
LearnRate = 0.0001;  % 我们使用SGD方法更新网络参数,此处为学习率
momentum = 0.9;  % SGD方法中的"动量项"值
num_thistrain = ceil(size(data_train,1)/batch_size);  % 一次训练过程中,网络更新次数
vel =[];
for i=1:num_train
    Sequence_random = randperm(size(data_train,1));  % 将训练数据的顺序随机打乱
    % 进行训练
    for j=1:num_thistrain
        % 获得该次的训练输入和对应标签
        index_end = min([j*batch_size,size(data_train,1)]);  % 计算输入数据的末尾下标
        index_train = (j-1)*batch_size+1:index_end;
        temp_input = images_train(Sequence_random(index_train),:,:,:);  % 该次训练的输入
        temp_label = labels_train(Sequence_random(index_train),:);  % 该次训练的标签

        % 更新网络
        temp_input_dlarray = dlarray(temp_input,'BCSS');  % 数据类型转化
        targets = onehotencode(temp_label',1,'ClassNames',0:9);  % 对标签采用onehot编码
        targets = dlarray(targets,'CB');
        % [loss,Gradients,state] = netLoss_cal(net,inputs,labels)
        [loss,Gradients,state] = dlfeval(@netLoss_cal,net,temp_input_dlarray,targets); % 该函数
        % 可以跟踪输入在网络中的计算过程,从而得到梯度信息Gradients
        net.State = state;
        [net,vel] = sgdmupdate(net,Gradients,vel,LearnRate,momentum);  % 利用sgd(随机梯度下降)方法更新网络参数,
        % 其中vel为更新的速度参数,具体概念可见SGD计算方法过程,以及matlab中sgdmupdate函数用法
    end
    % 进行测试
    Sequence_random_test = randperm(size(data_test,1));  % 将训练数据的顺序随机打乱
    temp_input_test = images_test(Sequence_random_test(1:batch_size_test),:,:,:);  % 获得测试数据输入
    temp_label_test = labels_test(Sequence_random_test(1:batch_size_test),:);  % 获得测试数据标签
    temp_input_test_dlarray = dlarray(temp_input_test,'BCSS');  % 将测试数据输入转换为dlarray类型
    temp_out = predict(net,temp_input_test_dlarray);  % 获得输入对应的输出
    temp_out = extractdata(temp_out); % 将dlarray类型数据转换为正常matlab数组数据
    [~,out_labels] = max(temp_out);
    out_labels = out_labels - 1;
    % 计算准确率
    accur = 0;
    for j=1:length(temp_label_test)
        if temp_label_test(j) == out_labels(j)
            accur = accur + 1;
        end
    end
    accur = accur / length(temp_label_test);
    fprintf("第%d轮次训练准确率为%f\n",i,accur)
end

最终得到结果如下,可以看到该网络训练识别准确率可以达到98%

 改进拓展方法

改进1:上述训练过程是在CPU上进行的,其实,如果你有英伟达显卡的话,我们可以把变量放置在GPU上进行训练,可参考MATLAB官方文档:gpuArray——Array stored on GPU

改进2:如果你想把训练数据放到一个迭代器里边,或者让训练过程可视化,或者设计更多丰富多彩的效果,不妨参考这篇文档:定义自定义训练循环、损失函数和网络

标签:自定义,训练,temp,神经网络,train,MATLAB,test,data,size
From: https://blog.csdn.net/weixin_60223645/article/details/136685926

相关文章

  • 小白学视觉 | 神经网络训练trick总结
    本文来源公众号“小白学视觉”,仅用于学术分享,侵权删,干货满满。原文链接:神经网络训练trick总结来自|知乎  作者|Anticoder链接|https://zhuanlan.zhihu.com/p/59918821本文仅作学术交流,如有侵权,请联系删除神经网络构建好,训练不出好的效果怎么办?明明说好的拟合任......
  • Python实现BOA蝴蝶优化算法优化循环神经网络分类模型(LSTM分类算法)项目实战
    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取。1.项目背景蝴蝶优化算法(butterflyoptimizationalgorithm,BOA)是Arora等人于2019年提出的一种元启发式智能算法。该算法受到了蝴蝶觅食和交配行为的启发,......
  • Python实现BOA蝴蝶优化算法优化循环神经网络回归模型(LSTM回归算法)项目实战
    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取。1.项目背景蝴蝶优化算法(butterflyoptimizationalgorithm,BOA)是Arora等人于2019年提出的一种元启发式智能算法。该算法受到了蝴蝶觅食和交配行为的启发,......
  • Matlab的录音与播放
    目的:使用matlab生成特定信号,Speaker循环播放这组信号的同时,Microphone启动录音。一、生成一个单频正弦信号clc;clear;closeall;Fs=48000;%采样率为4800Hzf=18000;%正弦波频率为18000Hzt=0:1/Fs:10;%从0到10秒的时间向量y=sin(2*pi*f*t);%生成正弦波信......
  • HarmonyOS 鸿蒙 arkts 自定义组件插槽
    HarmonyOS鸿蒙arkts中自定义组件中要传入其他组件的时候就可以使用自定义组件插槽。Container组件添加child属性后,表示该组件具备了额外添加子组件的能力,接下来在需要添加子组件的地方使用child属性做占位即可。自定义组件@ComponentexportstructContainer{@Bu......
  • 基于遗传优化的协同过滤推荐算法matlab仿真
    1.算法运行效果图预览  最后得到推荐的商品ID号:推荐商品的ID号:ans=98381758221911149021490212352247322307123499117901547165501655016550......
  • MogDB openGauss 自定义snmptrapd告警信息
    MogDB/openGauss自定义snmptrapd告警信息本文出处:https://www.modb.pro/db/232391在之前的文章MogDB/openGauss监控告警配置介绍了如何通过alertmanager模块将报警通过snmp推送出去,但是在实际使用中,默认的报警规则信息并不能很好的满足snmp服务端的需求,需要定制化报警......
  • ZYNQ自定义IP并使用
    目的:自定义一个IP并添加到设计中使用(产生PWM波形)在ZYNQ系统中将许多特定功能的硬件设计模块封装起来称为IP核,类似于库函数。这种方式极大的提高了设计效率。当遇到设计中的一些特殊的需求且官方没有对应的IP时就需要我们自定义IP来使用。 创建步骤:1、创建新IP进入viv......
  • 反演问题求解:基于MATLAB的反演问题求解算法实现和应用,包括反演问题数值模拟、反演问题
    鱼弦:公众号【红尘灯塔】,CSDN内容合伙人、CSDN新星导师、全栈领域优质创作者 、51CTO(Top红人+专家博主) 、github开源爱好者(go-zero源码二次开发、游戏后端架构https://github.com/Peakchen)基于MATLAB的反演问题求解:原理、应用、实现与分析反演问题是指由间接观测数......
  • SpringBoot 中使用自定义参数解析器修改请求对象
    SpringBoot中使用自定义参数解析器修改请求对象在SpringBoot应用中,有时我们需要在控制器方法执行之前对请求对象进行修改。自定义参数解析器提供了一种灵活的方式来实现这一需求。1.创建自定义参数解析器首先,我们需要创建一个自定义参数解析器来处理对CommonRequest......