1.准备好网络模型代码
import torch
import torch.nn as nn
import torch.optim as optim
# BP_36: 输入2个节点,中间层36个节点,输出25个节点
class BP_36(nn.Module):
def __init__(self):
super(BP_36, self).__init__()
self.fc1 = nn.Linear(2, 36) # 输入2个节点,中间层36个节点
self.fc2 = nn.Linear(36, 25) # 输出25个节点
def forward(self, x):
x = torch.relu(self.fc1(x)) # 使用ReLU激活函数
x = self.fc2(x)
return x
# BP_64: 输入2个节点,中间层64个节点,输出25个节点
class BP_64(nn.Module):
def __init__(self):
super(BP_64, self).__init__()
self.fc1 = nn.Linear(2, 64) # 输入2个节点,中间层64个节点
self.fc2 = nn.Linear(64, 25) # 输出25个节点
def forward(self, x):
x = torch.relu(self.fc1(x)) # 使用ReLU激活函数
x = self.fc2(x)
return x
# Bi-LSTM: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_LSTM(nn.Module):
def __init__(self):
super(Bi_LSTM, self).__init__()
self.lstm = nn.LSTM(input_size=2, hidden_size=36, bidirectional=True, batch_first=True) # 双向LSTM
self.fc1 = nn.Linear(72, 25) # LSTM的输出72维,经过线性层后输出25个节点
def forward(self, x):
# x的形状应该是(batch_size, seq_len, input_size)
x, _ = self.lstm(x) # 输出LSTM的结果
x = self.fc1(x)
return x
# Bi-GRU: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_GRU(nn.Module):
def __init__(self):
super(Bi_GRU, self).__init__()
self.gru = nn.GRU(input_size=2, hidden_size=36, bidirectional=True, batch_first=True) # 双向GRU
self.fc1 = nn.Linear(72, 25) # GRU的输出72维,经过线性层后输出25个节点
def forward(self, x):
# x的形状应该是(batch_size, seq_len, input_size)
x, _ = self.gru(x) # 输出GRU的结果
x = self.fc1(x)
return x
2.运行计算参数量和复杂度的脚本
import torch
# from net import BP_36
# from net import BP_64
# from net import Bi_LSTM
from net import Bi_GRU
from ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 统计Transformer模型的参数量和计算复杂度
model_transformer = Bi_GRU()
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (256,2), as_strings=True, print_per_layer_stat=False)
print('模型参数量:' + params_transformer)
print('模型计算复杂度:' + flops_transformer)
标签:__,25,nn,模型,self,36,复杂度,深度,节点
From: https://www.cnblogs.com/fly-smart/p/18650943