首页 > 其他分享 >PyTorch实现多分类任务

PyTorch实现多分类任务

时间:2024-06-07 16:57:40浏览次数:26  
标签:torch nn 分类 任务 PyTorch num input classes size

import torch
import torch.nn as nn
import torch.optim as optim

'''定义模型'''
class SimpleModel(nn.Module):
    '''
    方便理解,这里只定义了一层网络
    input_size: 输入维度(这里表示每个样本的特征数量)
    num_classes: 输出维度(这里表示类别数量)
    '''
    def __init__(self, input_size, num_classes) -> None:
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)
        
    def forward(self, x):
        x = self.linear(x)
        # 此处没有使用激活函数  
        return x    
    
'''构造模拟数据'''
input_data = torch.randn(5, 10)  # 5个样本,每个样本有10个特征  
labels = torch.tensor([0, 1, 2, 0, 1])  # 5个样本,每个样本对应一个类别标签(三分类:0,1,2)

'''实例化模型'''
model = SimpleModel(input_size=10, num_classes=3)

'''定义损失函数'''
criterion = nn.CrossEntropyLoss()

'''定义优化器'''
optimizer = optim.SGD(model.parameters(), lr=0.1)   

epoch = 1000
for i in range(epoch):
    '''前向传播'''
    y = model(input_data)

    '''计算损失'''
    # 此处,criterion会自动将y进行softmax处理,所以不需要显式地在模型中定义softmax  
    # 同时,labels不需要进行one-hot编码,因为criterion会自动完成这一操作    
    loss = criterion(y, labels)
    print(f'Epoch {i+1}/{epoch}, Loss: {loss}')   
    
    '''反向传播 和 更新参数'''
    # 梯度清零
    optimizer.zero_grad()
    # 计算梯度
    loss.backward()
    # 更新参数
    optimizer.step() 
    
    

  

标签:torch,nn,分类,任务,PyTorch,num,input,classes,size
From: https://www.cnblogs.com/zhangyh-blog/p/18237490

相关文章

  • 【医疗器械产品分类规则了解】
    分类目录由国家食品药品监督管理部门依据医疗器械分类规则制定:医疗器械按照风险程度由低到高,管理类别依次分为第一类、第二类和第三类。医疗器械风险程度,应当根据医疗器械的预期目的,通过结构特征、使用形式、使用状态、是否接触人体等因素综合判定。第一类医疗器械是风险程度......
  • BERT+P-Tuning文本分类模型
    基于BERT+P-Tuning方式文本分类模型搭建模型搭建本项目中完成BERT+P-Tuning模型搭建、训练及应用的步骤如下(注意:因为本项目中使用的是BERT预训练模型,所以直接加载即可,无需重复搭建模型架构):一、实现模型工具类函数二、实现模型训练函数,验证函数三、实现模型预测函......
  • Python实现投递多线程任务
    使用Python的apscheduler库中的BackgroundScheduler实现投递多线程任务的示例代码。这个示例将展示如何根据任务ID投递和停止任务,设置任务同时执行的上限,以及删除全部任务。首先,确保你已经安装了apscheduler库:``pipinstallapscheduler``代码示例:``fromapscheduler.sched......
  • 【已解决】Python报错Pytorch:ModuleNotFoundError: No module named ‘torch’
    本文摘要:本文已解决Pytorch:ModuleNotFoundError:Nomodulenamed‘torch’的相关报错问题,并总结提出了几种可用解决方案。同时结合人工智能GPT排除可能得隐患及错误。......
  • teamcenter 按照审批节点和节点的目标分组统计任务数量
    selectcount(*),--(case--when'L8_DesignRevision'then'图对象'--when'L8_DocumentRevision'then'文档'--when'L8_JcsjDocumentRevision'then'检测数据'--WHENINSTR(v.pobject_type,'PartRevi......
  • 48.线程池提交任务的方法
     execute方法submit方法提交任务task,用返回值Future获得任务执行结果。Future用于主线程接受线程池中线程的返回结果。ExecutorServiceexecutorService=Executors.newFixedThreadPool(2);//提交第一个任务返回结果Future<String>future=execu......
  • 基于springboot的相亲网站管理系统,相亲管理系统,附源码+数据库+论文+开题报告+任务书+P
    1、项目介绍相亲网站根据使用权限的角度进行功能分析,并运用用例图来展示各个权限需要操作的功能。管理员权限操作的功能包括管理婚礼公司,管理婚礼公司预约信息,管理结婚案例,管理相亲信息,管理相亲留言,管理用户等。用户权限操作的功能包括预约婚礼公司,收藏婚礼公司,查看结婚......
  • 苍穹外卖笔记-06-菜品管理-菜品分类,公共字段填充
    菜品分类1菜品分类模块1.1需求分析与设计1.1.1产品原型1.1.2接口设计1.1.3表设计1.3代码实现1.4测试分类分页查询启用禁用分类修改分类信息新增菜品分类删除菜品分类2公共字段自动填充2.1问题分析2.2实现思路自定义注解AutoFill自定义切面AutoFillAspectMap......
  • 环境配置·Ubuntu1804安装CUDA和Pytorch
    InitUbuntuandchangedeb&pipsourcewgethttps://github.com/blueflylabor/blueflylabor.github.io/blob/main/toolbox/initUbuntu/initUbuntu.shbash./initUbuntu.shCUDA11.6wgethttps://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64......
  • QShop商城-Quartz定时任务
    QShop商城-Quartz定时任务编写任务代码在Qs.App中编写定时任务的执行代码。比如添加订单完成定时器[JobOrderDone]namespaceQs.App.Jobs{publicclassJobOrderDone:IJob{privateApp......