首页 > 其他分享 >如何得到深度学习模型的参数量和计算复杂度

如何得到深度学习模型的参数量和计算复杂度

时间:2025-01-03 21:34:25浏览次数:1  
标签:__ 25 nn 模型 self 36 复杂度 深度 节点

1.准备好网络模型代码

import torch
import torch.nn as nn
import torch.optim as optim

# BP_36: 输入2个节点,中间层36个节点,输出25个节点
class BP_36(nn.Module):
    def __init__(self):
        super(BP_36, self).__init__()
        self.fc1 = nn.Linear(2, 36)  # 输入2个节点,中间层36个节点
        self.fc2 = nn.Linear(36, 25)  # 输出25个节点

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数
        x = self.fc2(x)
        return x

# BP_64: 输入2个节点,中间层64个节点,输出25个节点
class BP_64(nn.Module):
    def __init__(self):
        super(BP_64, self).__init__()
        self.fc1 = nn.Linear(2, 64)  # 输入2个节点,中间层64个节点
        self.fc2 = nn.Linear(64, 25)  # 输出25个节点

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数
        x = self.fc2(x)
        return x

# Bi-LSTM: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_LSTM(nn.Module):
    def __init__(self):
        super(Bi_LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向LSTM
        self.fc1 = nn.Linear(72, 25)  # LSTM的输出72维,经过线性层后输出25个节点

    def forward(self, x):
        # x的形状应该是(batch_size, seq_len, input_size)
        x, _ = self.lstm(x)  # 输出LSTM的结果
        x = self.fc1(x)
        return x

# Bi-GRU: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_GRU(nn.Module):
    def __init__(self):
        super(Bi_GRU, self).__init__()
        self.gru = nn.GRU(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向GRU
        self.fc1 = nn.Linear(72, 25)  # GRU的输出72维,经过线性层后输出25个节点

    def forward(self, x):
        # x的形状应该是(batch_size, seq_len, input_size)
        x, _ = self.gru(x)  # 输出GRU的结果
        x = self.fc1(x)
        return x

2.运行计算参数量和复杂度的脚本

import torch
# from net import BP_36
# from net import BP_64
# from net import Bi_LSTM
from net import Bi_GRU

from ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# 统计Transformer模型的参数量和计算复杂度
model_transformer = Bi_GRU()
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (256,2), as_strings=True, print_per_layer_stat=False)
print('模型参数量:' + params_transformer)
print('模型计算复杂度:' + flops_transformer)

标签:__,25,nn,模型,self,36,复杂度,深度,节点
From: https://www.cnblogs.com/fly-smart/p/18650943

相关文章

  • 机器学习之模型评估——混淆矩阵,交叉验证与数据标准化
    目录混淆矩阵交叉验证数据标准化        0-1标准化        z标准化混淆矩阵混淆矩阵(ConfusionMatrix)是一种用于评估分类模型性能的工具。它是一个二维表格,其中行表示实际的类别,列表示模型预测的类别。假设我们有一个二分类问题(类别为正例和反例),......
  • 国内AI大模型前十排行榜,最后一个你可能没听过
    根据2024年的最新数据和搜索结果,国内AI大模型的前十排行榜阿里云通义千问(Qwen2-72B):在SuperCLUE基准测试中得分最高,超过众多国内外闭源模型,引领全球的开源生态。华为盘古大模型:凭借其强大的技术能力和行业应用得到广泛认可。百度文心一言(ERNIEBot):专注于自然语言理解......
  • 我的AI工具箱Tauri版-SEOManage大模型撰写上稿网站
    本教程基于自研的AI工具箱Tauri版进行SEOManage大模型撰写上稿网站自动SEO。SEOManage网站自动SEO是一款专为网站优化和内容生产设计的AI工具,支持高效撰写关键词文章并实现自动化上稿。基于LMStudio本地大模型,SEOManage通过智能模板匹配和关键词策略生成,为用户提供从文......
  • 数据链路层是OSI模型的第二层,负责在相邻节点间传输数据
    数据链路层是OSI模型的第二层,负责在相邻节点间传输数据。在这一层中,数据以帧(Frame)的形式进行封装和传输。帧是数据链路层的基本传输单位,它不仅包括实际要传输的数据,还包括控制信息,如源地址、目的地址、错误检测码等。这些控制信息帮助接收方正确解读数据,并进行必要的错误处......
  • 深度学习笔记08-YOLOv5-C3模块实现
    本文实现了YOLVv5-C3模块。文章目录前言一、加载数据1.引入库2.导入数据3.自定义transforms4.查看类别5.划分数据集6.加载数据二、建立模型1.搭建模型2.查看模型详情三、训练模型1.训练函数2.测试函数3.main4.结果可视化5.模型评估总结前言......
  • 深度学习基础理论————训练加速(单/半/混合精度训练)/显存优化(gradient-checkpoint)
    主要介绍单精度/半精度/混合精度训练,以及部分框架(DeepSpeed/Apex)不同精度训练单精度训练(single-precision)指的是用32位浮点数(FP32)表示所有的参数、激活值和梯度半精度训练(half-precision)指的是用16位浮点数(FP16或BF16)表示数据。(FP16是IEEE标准,BF16是一种更适合AI计算的......
  • 《机器学习》--线性回归模型详解
    线性回归模型是机器学习中的一种重要算法,以下是对其的详细解释:一、定义与原理线性回归(LinearRegression)是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。线性回归利用称为线性回归方程的最小平方函数对一个或多个自变量和因变......
  • a16z:小模型 + 边缘 AI 将定义 2025;音效模型 TangoFlux:3 秒钟生成 30 秒音频丨RTE 开发
      开发者朋友们大家好: 这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(Real-TimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代表编辑......
  • 小白也能懂文本挖掘之LDA主题模型及代码详解
    文章主要重实际应用,不做过多理论推导  LDA(LatentDirichletAllocation)主题分析模型,即潜在狄利克雷分配模型,是一种文档生成模型,也是一种无监督机器学习技术。(无监督学习即需要手动输入主题数量,下一期进行讲解如何确定LDA主题数)一、LDA模型的基本概念  LDA模型认为一......
  • 江大白 | 基于腾讯混元大模型,业务落地实践汇总!
    本文来源公众号“江大白”,仅用于学术分享,侵权删,干货满满。原文链接:基于腾讯混元大模型,业务落地实践汇总!祝各位同仁元旦快乐!2025继续学习,越来越强!导读本文探讨腾讯大语言模型在内容生成、智能客服等场景的应用,解析RAG技术在文档生成、问答系统的优势,探讨GraphRAG在角色扮演......