首页 > 编程语言 >用Python和Pytorch使用softmax和cross-entropy

用Python和Pytorch使用softmax和cross-entropy

时间:2023-03-24 18:14:36浏览次数:32  
标签:pre Python torch cross Pytorch entropy softmax np

softmax激活函数

softmax激活函数将包含K个元素的向量转换到(0,1)之间,并且和为1,因此它们可以用来表示概率。

 

 

 python:

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)
x=np.array([0.1, 0.9, 4.0])
 
output=softmax(x)
 
print('Softmax in Python :',output)
 
#Softmax in Python : [0.04672966 0.10399876 0.84927158]

pytorch

x=torch.tensor(x)
output=torch.softmax(x,dim=0)
print(output)
 
#tensor([0.0467, 0.1040, 0.8493], dtype=torch.float64)

cross-entropy

交叉熵是常用损失,用来衡量两个分布的不同。通常以真实分布和预测分布作为输入。

 

 

#Cross Entropy Loss
 
def cross_entropy(y,y_pre):
  loss=-np.sum(y*np.log(y_pre))
  return loss/float(y_pre.shape[0])
y=np.array([0,0,1]) #class #2
 
y_pre_good=np.array([0.1,0.1,0.8])
y_pre_bed=np.array([0.8,0.1,0.1])
 
l1=cross_entropy(y,y_pre_good)
l2=cross_entropy(y,y_pre_bed)
 
print('Loss 1:',l1)
print('Loss 2:',l2)
 
Loss 1: 0.07438118377140324
Loss 2: 0.7675283643313485
loss =nn.CrossEntropyLoss()

y=torch.tensor([2])
 
y_pre_good=torch.tensor([[1.0,1.1,2.5]])
y_pre_bed=torch.tensor([[3.2,0.2,0.9]])
 
 
l1=loss(y_pre_good,y)
l2=loss(y_pre_bed,y)
 
print(l1.item()) #0.3850
print(l2.item()) #2.4398

 

参考链接:https://androidkt.com/implement-softmax-and-cross-entropy-in-python-and-pytorch/

标签:pre,Python,torch,cross,Pytorch,entropy,softmax,np
From: https://www.cnblogs.com/squirrel-7/p/17252931.html

相关文章

  • python超时处理方法eventlet的eventlet.Timeout
    一、前言在使用python进行接口自动化测试、脚本编写、执行sql的时候,如果遇到以下问题的,都可以用eventlet.timeout这个方法。执行下载数据的接口,数据量较大导致后面接口......
  • python总结
    whypython脚本比起c++更简单代码量更少,省去编译的时间。python比起rubby,pearl等其他脚本也更简洁一些,要的就是最简洁。python数据集合元组,列表,set,字典(相当于map)元组和列......
  • Python中实现获取所有微信好友的头像并拼接成一张图片
    场景实现扫码登录微信并获取所有好友的昵称以及头像,并将所有头像拼接成一张图片。实现新建文件夹weixinImage文件夹下新建文件weixinImge.py#-*-coding:utf-8-*-fromw......
  • linux环境下离线安装python3
    1、卸载旧的python3rpm-qa|greppython3|xargsrpm-ev--allmatches--nodepswhereispython3|xargsrm-frv2、安装python3http://npm.taobao.org/mirrors/python/......
  • Python中提示:UnicodeDecodeError: 'ascii' codec can't decode byte 0xe5 in position
    场景Pycharm中运行:获取所有微信好友的头像并拼接成一张图片提示:UnicodeDecodeError:'ascii'codeccan'tdecodebyte0xe5inposition......
  • python stata转mysql
    importnumpyasnpimportpyreadstataspyreadstatimportjson,re,random,pymysql,configparser,sysimportpandasaspdfromduconfigimportread_inidefdujieg......
  • Python中提示:no module named 'PIL'
    场景实现不要执行pipinstallPIL要执行pipinstallPillow如果提示超时,执行pip--default-timeout=200install-UPillow......
  • 简单介绍最新python 字符串数组互转问题
    字符串转list数组str='1,2,3'arr=str.split(',')gpu_ids分配name=opt.namegpu_ids=[int(item)foriteminopt.gpu_ids.split(',')]#setgpuidsiflen(gpu_i......
  • 从Python的turtle绘图开始学习图形化程序设计
    Turtlepython2.6版本中后引入的一个简单的绘图工具,叫做海龟绘图(TurtleGraphics),turtle库是python的内部库,使用导入即可:importturtle画布画布就是turtle为我们展开用......
  • python之迭代
    一、可迭代对象可迭代对象:窄义来讲:能够通过for……in这种方式,把元素一个个取出来的,这个对象叫可迭代对象。lst=[1,2,3,5]foriinlst:print(i)广义来讲:对象实......