首页 > 其他分享 >67、基于长短期记忆网络的心电图(ECG)信号分类(matlab)

67、基于长短期记忆网络的心电图(ECG)信号分类(matlab)

时间:2024-07-05 12:56:14浏览次数:21  
标签:ECG ... 训练 matlab 信号 67 cellfun LSTM

1、基于长短期记忆网络的心电图(ECG)信号分类原理及流程

基于长短期记忆网络(LSTM)的心电图(ECG)信号分类通常用于分析和识别心脏方面的问题,如心律失常。下面是基于LSTM的ECG信号分类的原理和流程:

原理:

  1. 长短期记忆网络(LSTM):LSTM是一种特殊的循环神经网络(RNN),能够更好地捕捉时间序列数据中的长期依赖关系。在ECG信号分类中,LSTM可以有效地保留和利用ECG信号中的时间相关信息。

  2. 特征提取:LSTM网络通过学习ECG信号中的特征,从而能够将其向量化表示,以便进行分类。

  3. 分类器:在LSTM网络的输出端,一般会连接一个分类器来对ECG信号进行分类,例如softmax分类器用于将ECG信号分为不同的类别,如正常心电图、心律失常等。

流程:

  1. 数据准备:收集具有标注类别(如心律失常类型)的ECG信号数据集,确保数据质量和标记的准确性。

  2. 数据预处理:对ECG信号进行预处理,包括信号降噪、滤波、标准化等操作,以确保信号质量和一致性。

  3. 特征提取:将经过处理的ECG信号输入到LSTM网络中,LSTM网络会学习并提取出关键的时间序列特征。

  4. 模型训练:使用训练数据来训练LSTM模型,通过反向传播算法不断调整模型参数,使其能够最好地拟合训练数据。

  5. 模型评估:使用测试数据集来评估训练好的模型性能,包括准确度、召回率、精确度等指标。

  6. 结果解释:根据模型的输出结果进行结果解释和分析,判断ECG信号的类别,如是否存在心律失常等。

  7. 优化和调整:根据评估结果对模型进行优化和调整,以提高模型的效果和泛化能力。

通过以上流程,基于LSTM的ECG信号分类可以帮助医生和研究人员更准确地分析和识别心脏疾病。

2、基于长短期记忆网络的心电图(ECG)信号分类说明

方案

用长短期记忆网络和时频分析对来自 PhysioNet 2017 Challenge 的心电图 (ECG) 数据进行分类

第一步:使用原始数据训练 LSTM 网络得到的结果

第二步:使用提取的特征训练相同的 LSTM 网络得到的结果

说明

ECG 记录一段时间内人体心脏的电活动。

示例使用的 ECG 数据来自 PhysioNet 2017 Challenge [1]、[2]、[3],可从 https://physionet.org/challenge/2017/ 获得。

数据由一组以 300 Hz 采样的 ECG 信号组成,由一组专家分成四个不同类:正常 (N)、AFib (A)、其他心律 (O) 和含噪记录 (~)。

使用深度学习来自动化分类过程

使用一个二类分类器,该二类分类器可以将正常 ECG 信号和显示 AFib 符号的信号区分开来

 

长短期记忆 (LSTM) 网络是一种循环神经网络 (RNN),非常适合研究序列和时间序列数据。

LSTM 网络可以学习序列的时间步之间的长期相关性。LSTM 层 (lstmLayer) 可以前向分析时间序列,而双向 LSTM 层 (bilstmLayer) 可以前向和后向分析时间序列。

 

3、加载数据

1)数据加载

Signals保存 ECG 信号的元胞数组。Labels 分类数组,它保存信号的对应真实值标签。

load PhysionetData

代码实现

load PhysionetData
Signals(1:5)'
Labels(1:5)

ans=1×5 cell array
    {1×9000 double}    {1×9000 double}    {1×18000 double}    {1×9000 double}    {1×18000 double}

ans = 5×1 categorical
     N 
     N 
     N 
     A 
     A 

 2)使用 summary 函数查看数据中包含多少 AFib 信号和正常信号

代码实现

summary(Labels)

A       738 
     N      5050

3) 信号长度的直方图

代码实现

L = cellfun(@length,Signals);
h = histogram(L);
xticks(0:3000:18000);
xticklabels(0:3000:18000);
title('Signal Lengths')
xlabel('Length')
ylabel('Count')

 视图效果

13f9e4e1bc144c7799ebf6ce681d6b88.png

4) 可视化每个类中一个信号的一段

说明

AFib 心跳间隔不规则,而正常心跳会周期性发生。AFib 心跳信号还经常缺失 P 波,P 波在正常心跳信号的 QRS 复合波之前出现。正常信号的绘图会显示 P 波和 QRS 复合波。

代码实现

normal = Signals{1};
aFib = Signals{4};

subplot(2,1,1)
plot(normal)
title('Normal Rhythm')
xlim([4000,5200])
ylabel('Amplitude (mV)')
text(4330,150,'P','HorizontalAlignment','center')
text(4370,850,'QRS','HorizontalAlignment','center')

subplot(2,1,2)
plot(aFib)
title('Atrial Fibrillation')
xlim([4000,5200])
xlabel('Samples')
ylabel('Amplitude (mV)')

视图效果

 adcf7fa196d9461881a4207e193f8831.png

4、准备要训练的数据

说明1

在训练期间,trainnet 函数将数据分成小批量。然后,该函数在同一个小批量中填充或截断信号,使它们都具有相同的长度。

避免过度填充或截断,请对 ECG 信号应用 segmentSignals 函数,使它们的长度都为 9000 个采样

代码实现

[Signals,Labels] = segmentSignals(Signals,Labels);

说明2

查看 Signals 数组的前五个元素,以验证每个条目的长度现在为 9000 个采样。

代码实现 

Signals(1:5)'
ans=1×5 cell array
    {1×9000 double}    {1×9000 double}    {1×9000 double}    {1×9000 double}    {1×9000 double}

5、第一步:使用原始数据训练 LSTM 网络得到的结果

1)数据分化比例

将信号分成一个训练集(用于训练分类器)和一个测试集(用于基于新数据测试分类器的准确度)

使用 summary 函数显示 AFib 信号与正常信号的比率约为 1:7

实现代码

summary(Labels)
 A       718 
     N      4937 

2) 统一数据

通过复制数据集中的 AFib 信号来增加 AFib 数据,以便正常信号和 AFib 信号的数量相同

实现代码

afibX = Signals(Labels=='A');
afibY = Labels(Labels=='A');

normalX = Signals(Labels=='N');
normalY = Labels(Labels=='N');

 3)使用 dividerand 将每个类的目标随机分为训练集、验证集和测试集

实现代码

rng("default")
[trainIndA,validIndA,testIndA] = dividerand(length(afibX),0.8,0.1,0.1);
[trainIndN,validIndN,testIndN] = dividerand(length(normalX),0.8,0.1,0.1);
XTrainA = afibX(trainIndA);
YTrainA = afibY(trainIndA);
XTrainN = normalX(trainIndN);
YTrainN = normalY(trainIndN);

XValidA = afibX(validIndA);
YValidA = afibY(validIndA);
XValidN = normalX(validIndN);
YValidN = normalY(validIndN);

XTestA = afibX(testIndA);
YTestA = afibY(testIndA);
XTestN = normalX(testIndN);
YTestN = normalY(testIndN);

4) 平衡数据集

数据集不平衡。要获得相似数量的 AFib 信号和正常信号,请重复七次 AFib 信号。

默认情况下,神经网络会在训练前随机对数据进行乱序处理,以确保相邻信号不都有相同的标签。

实现代码

XTrain = [repmat(XTrainA,7,1); XTrainN];
YTrain = [repmat(YTrainA,7,1); YTrainN];

XValid = [repmat(XValidA,7,1); XValidN];
YValid = [repmat(YValidA,7,1); YValidN];

XTest = [repmat(XTestA,7,1); XTestN];
YTest = [repmat(YTestA,7,1); YTestN];

 5)正常信号和 AFib 信号在训练集、验证集和测试集中均衡分布

实现代码

summary(YTrain)
summary(YValid)
summary(YTest)

     A      4018 
     N      3949 

A      504 
     N      494 


A      504 
     N      494 

5.1、定义 LSTM 网络架构

1)使用双向 LSTM 层 bilstmLayer

使用双向 LSTM 层 bilstmLayer,其前向和后向检测序列

将输入大小指定是大小为 1 的序列,

指定输出大小为 50 的一个双向 LSTM 层,并输出序列的最后一个元素

实现代码

layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]

layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 1 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax

2) 指定分类器的训练选项

将 'MaxEpochs' 设置为 100,以允许基于训练数据对网络进行 100 轮训练。

'MiniBatchSize' 为 300 指示网络一次分析 300 个训练信号

'InitialLearnRate' 为 0.01 有助于加快训练过程

将 'Plots' 指定为 'training-progress',以生成显示训练随迭代次数的增加而变化的进度图

将 'Verbose' 设置为 false 以隐藏对应于图中所示数据的表输出

使用自适应矩估计 (ADAM) 求解器

实现代码

options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize', 200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate', 1e-3, ...
    'ExecutionEnvironment','auto', ...
    'plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValid,YValid}, ...
    'Verbose',false, ...
    'OutputNetwork','last-iteration');

5.2、训练 LSTM 网络

说明

使用 trainnet 用指定的训练选项和层架构训练 LSTM 网络

此例训练准确度很高,但验证准确度并没有相应提高。这可能指示过拟合,意味着模型无法泛化,而是与训练数据集过于接近。

实现代码

net = trainnet(XTrain,YTrain,layers,"crossentropy",options);

视图效果

 e2b08c23059245aa95b4fb9e65d11e9f.png

5.3、可视化训练和测试准确度

1)使用多个观测值进行预测

实现代码

classNames = categories(YTrain);
scores = minibatchpredict(net,XTrain,"InputDataFormats","CTB");
trainPred = scores2label(scores,classNames);

 2)使用 confusionchart 命令计算用于测试数据预测的总体分类准确度

实现代码

LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 99.0335

3)混淆矩阵可视化分类器

实现代码 

figure
confusionchart(YTrain,trainPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

视图效果 

54f3fe12a6214059bad042afa0a1764b.png

5.4、测试网络

1)用相同的网络对测试数据进行分类

实现代码

scores = minibatchpredict(net,XTest,InputDataFormats="CTB");
testPred = scores2label(scores,classNames);

 2)计算测试准确度,并使用混淆矩阵将分类性能可视化

实现代码

LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
figure
confusionchart(YTest,testPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');



LSTMAccuracy = 61.1222

 视图效果

9704a446e3e64dbbaf5231621d884b24.png

6、第二步:使用提取的特征训练相同的 LSTM 网络得到的结果

6.1、从数据中提取特征有助于提高分类器的性能

1)从数据中提取特征有助于提高分类器的性能

为了决定提取哪些特征,先计算时频图像(如频谱图),然后使用它们来训练卷积神经网络 (CNN)

可视化每个信号类型的频谱图

实现代码

fs = 300;

figure
subplot(2,1,1);
pspectrum(normal,fs,'spectrogram','TimeResolution',0.5)
title('Normal Signal')

subplot(2,1,2);
pspectrum(aFib,fs,'spectrogram','TimeResolution',0.5)
title('AFib Signal')

视图效果

76385f43baa347ec95b4e519cee7f9c3.png 

2) 可视化每个信号类型的瞬时频率

instfreq 函数估计信号的时变频率,作为功率谱图的第一个矩。

实现代码

[instFreqA,tA] = instfreq(aFib,fs);
[instFreqN,tN] = instfreq(normal,fs);

figure
subplot(2,1,1);
plot(tN,instFreqN)
title('Normal Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

subplot(2,1,2);
plot(tA,instFreqA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

 视图效果

51655d03b8424707949a14c37256867e.png

 3)使用 cellfun 将 instfreq 函数应用于训练集中和测试集中的每个单元

实现代码

instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,'UniformOutput',false);
instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,'UniformOutput',false);
instfreqValid = cellfun(@(x)instfreq(x,fs)',XValid,'UniformOutput',false);

 4)谱熵测量信号的频谱的尖度或平坦度

具有尖峰频谱的信号(如正弦波之和)具有低谱熵。

% 具有平坦频谱的信号(如白噪声)具有高谱熵。pentropy 函数基于功率谱估计谱熵。

实现代码

[pentropyA,tA2] = pentropy(aFib,fs);
[pentropyN,tN2] = pentropy(normal,fs);

figure
subplot(2,1,1)
plot(tN2,pentropyN)
title('Normal Signal')
ylabel('Spectral Entropy')

subplot(2,1,2)
plot(tA2,pentropyA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Spectral Entropy')-

视图效果

 43b32cdf9ab94c5bbf5f4cf56cdb2a57.png

5)使用 cellfun 将 pentropy 函数应用于训练集、测试集和验证集中的每个单元

实现代码

pentropyTrain = cellfun(@(x)pentropy(x,fs)',XTrain,'UniformOutput',false);
pentropyTest = cellfun(@(x)pentropy(x,fs)',XTest,'UniformOutput',false);
pentropyValid = cellfun(@(x)pentropy(x,fs)',XValid,'UniformOutput',false);

6)串联这些特征,使新的训练集和测试集中的每个单元都有两个维度(即两个特征)

实现代码 

XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,'UniformOutput',false);
XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);
XValid2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);

 7)可视化新输入的格式

实现代码

XTrain2(1:5)
ans=5×1 cell array
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}

6.2、标准化数据

1)瞬时频率和谱熵的均值

实现代码

mean(instFreqN)
mean(pentropyN)
ans = 5.5551
ans = 0.6324

2)使用训练集均值和标准差来标准化训练集、测试集和验证集

实现代码

XV = [XTrain2{:}];
mu = mean(XV,2);
sg = std(XV,[],2);

XTrainSD = XTrain2;
XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,'UniformOutput',false);
XValidSD = XValid2;
XValidSD = cellfun(@(x)(x-mu)./sg,XValidSD,'UniformOutput',false);
XTestSD = XTest2;
XTestSD = cellfun(@(x)(x-mu)./sg,XTestSD,'UniformOutput',false);

3) 显示标准化瞬时频率和谱熵的均值

实现代码

instFreqNSD = XTrainSD{1}(1,:);
pentropyNSD = XTrainSD{1}(2,:);

mean(instFreqNSD)
mean(pentropyNSD)
ans = 0.1544
ans = 0.1935

6.3、修改 LSTM 网络架构

1)现在每个信号都有两个维度,就有必要通过将输入序列大小指定为 2 来修改网络架构

实现代码

layers = [ ...
    sequenceInputLayer(2)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]

layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 2 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax

2)指定训练选项

将最大轮数设置为 120,以允许基于训练数据对网络进行 120 轮训练

实现代码

options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize', 200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate', 1e-3, ...
    'ExecutionEnvironment','auto',...
    'plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValidSD,YValid}, ...
    'OutputNetwork','last-iteration', ...
    'Verbose',false);

6.4、用时频特征训练 LSTM 网络

1)使用 trainnet 用指定的训练选项和层架构训练 LSTM 网络

实现代码

net2 = trainnet(XTrainSD,YTrain,layers,"crossentropy",options);

视图效果

 06fc4d5330db4a0bbeca5f90f8fdf1e1.png

6.5、可视化训练和测试准确度

1)使用更新后的 LSTM 网络对训练数据进行分类

实现代码

scores = minibatchpredict(net2,XTrainSD,"InputDataFormats","CTB");
trainPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100

LSTMAccuracy = 96.3600

2)将分类性能可视化为混淆矩阵

实现代码

igure
confusionchart(YTrain,trainPred2,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

视图效果

 94e42b74c08e496faf792fc3a2c7aa3f.png

7、总结

在 MATLAB 中使用长短期记忆网络(LSTM)进行心电图(ECG)信号分类可以通过以下步骤实现:

步骤总结:

  1. 数据准备

    • 收集具有标记的ECG信号数据集,包括正常心电图和心律失常等类别的数据。
    • 确保数据集质量和标记的准确性。
  2. 数据预处理

    • 对ECG信号进行必要的预处理,如去噪、滤波、标准化等,以提高数据质量。
  3. 数据格式转换

    • 将预处理后的ECG信号数据转换为适合LSTM网络输入的格式,通常是一个三维矩阵(样本数 x 时间步长 x 特征数)。
  4. 构建LSTM模型

    • 使用 MATLAB 中提供的深度学习工具箱(Deep Learning Toolbox)构建LSTM模型,指定网络结构、层类型、隐藏单元数量等参数。
  5. 模型训练

    • 划分训练集和测试集,然后使用训练数据对LSTM模型进行训练。
    • 通过定义损失函数和选择优化器(如随机梯度下降),迭代调整模型参数以提高性能。
  6. 模型评估

    • 使用测试数据集评估训练好的模型,计算分类准确度、混淆矩阵等指标来评估模型的性能。
  7. 模型应用

    • 利用训练好的LSTM模型对新的ECG信号进行分类预测,识别心脏疾病或其他异常情况。
  8. 结果分析

    • 对模型分类结果进行解释和分析,查看模型在不同类别上的表现,进一步优化模型和参数设置。

通过以上步骤,可以在 MATLAB 环境中利用LSTM网络对ECG信号进行分类,帮助医生和研究人员更好地理解心脏疾病并进行预测和诊断。深度学习工具箱和更多的相关工具可以帮助简化实现这些步骤的过程,并提高模型性能。

8、源程序

代码

%% 基于长短期记忆网络的心电图(ECG)信号分类
%使用长短期记忆网络和时频分析对来自 PhysioNet 2017 Challenge 的心电图 (ECG) 数据进行分类
%说明:ECG 记录一段时间内人体心脏的电活动。
%示例使用的 ECG 数据来自 PhysioNet 2017 Challenge [1]、[2]、[3],可从 https://physionet.org/challenge/2017/ 获得。
%数据由一组以 300 Hz 采样的 ECG 信号组成,由一组专家分成四个不同类:正常 (N)、AFib (A)、其他心律 (O) 和含噪记录 (~)。
%使用深度学习来自动化分类过程
%使用一个二类分类器,该二类分类器可以将正常 ECG 信号和显示 AFib 符号的信号区分开来

%长短期记忆 (LSTM) 网络是一种循环神经网络 (RNN),非常适合研究序列和时间序列数据。
%LSTM 网络可以学习序列的时间步之间的长期相关性。LSTM 层 (lstmLayer) 可以前向分析时间序列,而双向 LSTM 层 (bilstmLayer) 可以前向和后向分析时间序列。

%第一步:使用原始数据训练 LSTM 网络得到的结果
%第二步:使用提取的特征训练相同的 LSTM 网络得到的结果

%% 加载数据
%Signals保存 ECG 信号的元胞数组。Labels 分类数组,它保存信号的对应真实值标签。
load PhysionetData
Signals(1:5)'
Labels(1:5)
%使用 summary 函数查看数据中包含多少 AFib 信号和正常信号。
summary(Labels)
%信号长度的直方图,
L = cellfun(@length,Signals);
h = histogram(L);
xticks(0:3000:18000);
xticklabels(0:3000:18000);
title('Signal Lengths')
xlabel('Length')
ylabel('Count')
% 可视化每个类中一个信号的一段。
%AFib 心跳间隔不规则,而正常心跳会周期性发生。
%AFib 心跳信号还经常缺失 P 波,P 波在正常心跳信号的 QRS 复合波之前出现。正常信号的绘图会显示 P 波和 QRS 复合波。
normal = Signals{1};
aFib = Signals{4};

subplot(2,1,1)
plot(normal)
title('Normal Rhythm')
xlim([4000,5200])
ylabel('Amplitude (mV)')
text(4330,150,'P','HorizontalAlignment','center')
text(4370,850,'QRS','HorizontalAlignment','center')

subplot(2,1,2)
plot(aFib)
title('Atrial Fibrillation')
xlim([4000,5200])
xlabel('Samples')
ylabel('Amplitude (mV)')
%%  准备要训练的数据
%在训练期间,trainnet 函数将数据分成小批量。然后,该函数在同一个小批量中填充或截断信号,使它们都具有相同的长度。
%避免过度填充或截断,请对 ECG 信号应用 segmentSignals 函数,使它们的长度都为 9000 个采样
[Signals,Labels] = segmentSignals(Signals,Labels);
%查看 Signals 数组的前五个元素,以验证每个条目的长度现在为 9000 个采样。
Signals(1:5)'
%% 第一步:使用原始数据训练 LSTM 网络得到的结果
%将信号分成一个训练集(用于训练分类器)和一个测试集(用于基于新数据测试分类器的准确度)
%使用 summary 函数显示 AFib 信号与正常信号的比率约为 1:7
summary(Labels)
%通过复制数据集中的 AFib 信号来增加 AFib 数据,以便正常信号和 AFib 信号的数量相同
afibX = Signals(Labels=='A');
afibY = Labels(Labels=='A');

normalX = Signals(Labels=='N');
normalY = Labels(Labels=='N');
%使用 dividerand 将每个类的目标随机分为训练集、验证集和测试集
rng("default")
[trainIndA,validIndA,testIndA] = dividerand(length(afibX),0.8,0.1,0.1);
[trainIndN,validIndN,testIndN] = dividerand(length(normalX),0.8,0.1,0.1);
XTrainA = afibX(trainIndA);
YTrainA = afibY(trainIndA);
XTrainN = normalX(trainIndN);
YTrainN = normalY(trainIndN);

XValidA = afibX(validIndA);
YValidA = afibY(validIndA);
XValidN = normalX(validIndN);
YValidN = normalY(validIndN);

XTestA = afibX(testIndA);
YTestA = afibY(testIndA);
XTestN = normalX(testIndN);
YTestN = normalY(testIndN);
% 数据集不平衡。要获得相似数量的 AFib 信号和正常信号,请重复七次 AFib 信号。
%默认情况下,神经网络会在训练前随机对数据进行乱序处理,以确保相邻信号不都有相同的标签。
XTrain = [repmat(XTrainA,7,1); XTrainN];
YTrain = [repmat(YTrainA,7,1); YTrainN];

XValid = [repmat(XValidA,7,1); XValidN];
YValid = [repmat(YValidA,7,1); YValidN];

XTest = [repmat(XTestA,7,1); XTestN];
YTest = [repmat(YTestA,7,1); YTestN];
%正常信号和 AFib 信号在训练集、验证集和测试集中均衡分布。
summary(YTrain)
summary(YValid)
summary(YTest)
%% 定义 LSTM 网络架构
%使用双向 LSTM 层 bilstmLayer,其前向和后向检测序列
%将输入大小指定是大小为 1 的序列,
%指定输出大小为 50 的一个双向 LSTM 层,并输出序列的最后一个元素
layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]

%指定分类器的训练选项
%将 'MaxEpochs' 设置为 100,以允许基于训练数据对网络进行 100 轮训练。
%'MiniBatchSize' 为 300 指示网络一次分析 300 个训练信号
%'InitialLearnRate' 为 0.01 有助于加快训练过程
%将 'Plots' 指定为 'training-progress',以生成显示训练随迭代次数的增加而变化的进度图
%将 'Verbose' 设置为 false 以隐藏对应于图中所示数据的表输出
%使用自适应矩估计 (ADAM) 求解器
options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize', 200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate', 1e-3, ...
    'ExecutionEnvironment','auto', ...
    'plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValid,YValid}, ...
    'Verbose',false, ...
    'OutputNetwork','last-iteration');
%% 训练 LSTM 网络 
%使用 trainnet 用指定的训练选项和层架构训练 LSTM 网络
net = trainnet(XTrain,YTrain,layers,"crossentropy",options);
%此例训练准确度很高,但验证准确度并没有相应提高。这可能指示过拟合,意味着模型无法泛化,而是与训练数据集过于接近。

%% 可视化训练和测试准确度
%使用多个观测值进行预测
classNames = categories(YTrain);
scores = minibatchpredict(net,XTrain,"InputDataFormats","CTB");
trainPred = scores2label(scores,classNames);
%使用 confusionchart 命令计算用于测试数据预测的总体分类准确度
LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100

%混淆矩阵可视化分类器
figure
confusionchart(YTrain,trainPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');


%% 测试网络
%用相同的网络对测试数据进行分类
scores = minibatchpredict(net,XTest,InputDataFormats="CTB");
testPred = scores2label(scores,classNames);
%计算测试准确度,并使用混淆矩阵将分类性能可视化

LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
figure
confusionchart(YTest,testPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

%% 第二步:使用提取的特征训练相同的 LSTM 网络得到的结果
%从数据中提取特征有助于提高分类器的性能。
%为了决定提取哪些特征,先计算时频图像(如频谱图),然后使用它们来训练卷积神经网络 (CNN)
%可视化每个信号类型的频谱图
fs = 300;

figure
subplot(2,1,1);
pspectrum(normal,fs,'spectrogram','TimeResolution',0.5)
title('Normal Signal')

subplot(2,1,2);
pspectrum(aFib,fs,'spectrogram','TimeResolution',0.5)
title('AFib Signal')
%instfreq 函数估计信号的时变频率,作为功率谱图的第一个矩。
%可视化每个信号类型的瞬时频率
[instFreqA,tA] = instfreq(aFib,fs);
[instFreqN,tN] = instfreq(normal,fs);

figure
subplot(2,1,1);
plot(tN,instFreqN)
title('Normal Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

subplot(2,1,2);
plot(tA,instFreqA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')
% 使用 cellfun 将 instfreq 函数应用于训练集中和测试集中的每个单元
instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,'UniformOutput',false);
instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,'UniformOutput',false);
instfreqValid = cellfun(@(x)instfreq(x,fs)',XValid,'UniformOutput',false);
%谱熵测量信号的频谱的尖度或平坦度。
% 具有尖峰频谱的信号(如正弦波之和)具有低谱熵。
% 具有平坦频谱的信号(如白噪声)具有高谱熵。pentropy 函数基于功率谱估计谱熵。
%可视化每个信号类型的谱熵
[pentropyA,tA2] = pentropy(aFib,fs);
[pentropyN,tN2] = pentropy(normal,fs);

figure
subplot(2,1,1)
plot(tN2,pentropyN)
title('Normal Signal')
ylabel('Spectral Entropy')

subplot(2,1,2)
plot(tA2,pentropyA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Spectral Entropy')

%使用 cellfun 将 pentropy 函数应用于训练集、测试集和验证集中的每个单元
pentropyTrain = cellfun(@(x)pentropy(x,fs)',XTrain,'UniformOutput',false);
pentropyTest = cellfun(@(x)pentropy(x,fs)',XTest,'UniformOutput',false);
pentropyValid = cellfun(@(x)pentropy(x,fs)',XValid,'UniformOutput',false);
%串联这些特征,使新的训练集和测试集中的每个单元都有两个维度(即两个特征)
XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,'UniformOutput',false);
XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);
XValid2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);
%可视化新输入的格式
XTrain2(1:5)
%% 标准化数据
%瞬时频率和谱熵的均值
mean(instFreqN)
mean(pentropyN)
%使用训练集均值和标准差来标准化训练集、测试集和验证集
XV = [XTrain2{:}];
mu = mean(XV,2);
sg = std(XV,[],2);

XTrainSD = XTrain2;
XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,'UniformOutput',false);
XValidSD = XValid2;
XValidSD = cellfun(@(x)(x-mu)./sg,XValidSD,'UniformOutput',false);
XTestSD = XTest2;
XTestSD = cellfun(@(x)(x-mu)./sg,XTestSD,'UniformOutput',false);

%显示标准化瞬时频率和谱熵的均值
instFreqNSD = XTrainSD{1}(1,:);
pentropyNSD = XTrainSD{1}(2,:);

mean(instFreqNSD)
mean(pentropyNSD)

%% 修改 LSTM 网络架构
%现在每个信号都有两个维度,就有必要通过将输入序列大小指定为 2 来修改网络架构
layers = [ ...
    sequenceInputLayer(2)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]

%指定训练选项
%将最大轮数设置为 120,以允许基于训练数据对网络进行 120 轮训练
options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize', 200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate', 1e-3, ...
    'ExecutionEnvironment','auto',...
    'plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValidSD,YValid}, ...
    'OutputNetwork','last-iteration', ...
    'Verbose',false);
%% 用时频特征训练 LSTM 网络
%使用 trainnet 用指定的训练选项和层架构训练 LSTM 网络
net2 = trainnet(XTrainSD,YTrain,layers,"crossentropy",options);
%% 可视化训练和测试准确度
%使用更新后的 LSTM 网络对训练数据进行分类
scores = minibatchpredict(net2,XTrainSD,"InputDataFormats","CTB");
trainPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
%将分类性能可视化为混淆矩阵
figure
confusionchart(YTrain,trainPred2,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');




%使用更新后的网络对测试数据进行分类。
scores = minibatchpredict(net2,XTestSD,InputDataFormats="CTB");
testPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
% 绘制混淆矩阵以检查测试准确度。
figure
confusionchart(YTest,testPred2,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

工程文件

https://download.csdn.net/download/XU157303764/89515137

 

 

标签:ECG,...,训练,matlab,信号,67,cellfun,LSTM
From: https://blog.csdn.net/XU157303764/article/details/140185309

相关文章