首页 > 其他分享 >利用torch.nn实现前馈神经网络解决二分类问题

利用torch.nn实现前馈神经网络解决二分类问题

时间:2022-10-24 13:24:08浏览次数:78  
标签:acc loss num nn torch 前馈 train ls test

5、利用torch.nn实现前馈神经网络解决二分类问题

#导入必要的包
import torch 
import torch.nn as nn
from torch.utils.data import TensorDataset,DataLoader
from torch.nn import init
import torch.optim as optim
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
#创建数据集
num_inputs,num_example = 200,10000
x1 = torch.normal(2,1,(num_example,num_inputs))
y1 = torch.ones((num_example,1))
x2 = torch.normal(-2,1,(num_example,num_inputs))
y2 = torch.zeros((num_example,1))
x_data = torch.cat((x1,x2),dim=0)
y_data = torch.cat((y1,y2),dim = 0)
train_x,test_x,train_y,test_y = train_test_split(x_data,y_data,shuffle=True,test_size=0.3,stratify=y_data)
#读取数据
batch_size = 256
train_dataset = TensorDataset(train_x,train_y)
train_iter = DataLoader(
    dataset = train_dataset,
    shuffle = True,
    num_workers = 0,
    batch_size = batch_size
)
test_dataset = TensorDataset(test_x,test_y)
test_iter = DataLoader(
    dataset = test_dataset,
    shuffle = True,
    num_workers = 0,
    batch_size = batch_size
)
#定义模型
num_input,num_hidden,num_output = 200,256,1
class net(nn.Module):
    def __init__(self,num_input,num_hidden,num_output):
        super(net,self).__init__()
        self.linear1 = nn.Linear(num_input,num_hidden,bias =False)
        self.linear2 = nn.Linear(num_hidden,num_output,bias=False)
    def forward(self,input):
        out = self.linear1(input)
        out = self.linear2(out)
        return out
model = net(num_input,num_hidden,num_output)
print(model)
net(
  (linear1): Linear(in_features=200, out_features=256, bias=False)
  (linear2): Linear(in_features=256, out_features=1, bias=False)
)
#参数初始化
for param in model.parameters():
    init.normal_(param,mean=0,std=0.001)
#定义训练函数
lr = 0.001 #学习率
loss = nn.BCEWithLogitsLoss() #损失函数
optimizer = optim.SGD(model.parameters(),lr) #优化器
def train(net,train_iter,test_iter,loss,num_epochs,batch_size):
    train_ls,test_ls,train_acc,test_acc = [],[],[],[]
    for epoch in range(num_epochs):
        train_ls_sum,train_acc_sum,n = 0,0,0
        for x,y in train_iter:
            y_pred = model(x)
            l = loss(y_pred,y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_ls_sum +=l.item()
            train_acc_sum += (((y_pred>0.5)==y)+0.0).sum().item()
            n += y_pred.shape[0]
        train_ls.append(train_ls_sum)
        train_acc.append(train_acc_sum/n)
        
        test_ls_sum,test_acc_sum,n = 0,0,0
        for x,y in test_iter:
            y_pred = model(x)
            l = loss(y_pred,y)
            test_ls_sum +=l.item()
            test_acc_sum += (((y_pred>0.5)==y)+0.0).sum().item()
            n += y_pred.shape[0]
        test_ls.append(test_ls_sum)
        test_acc.append(test_acc_sum/n)
        print('epoch %d, train_loss %.6f,test_loss %f, train_acc %.6f,test_acc %f'
              %(epoch+1, train_ls[epoch],test_ls[epoch], train_acc[epoch],test_acc[epoch]))
    return train_ls,test_ls,train_acc,test_acc
#训练次数和学习率
num_epochs = 10
train_loss,test_loss,train_acc,test_acc = train(model,train_iter,test_iter,loss,num_epochs,batch_size)
#结果可视化
x = np.linspace(0,len(train_loss),len(train_loss))
plt.plot(x,train_loss,label="train_loss",linewidth=1.5)
plt.plot(x,test_loss,label="test_loss",linewidth=1.5)

plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()

标签:acc,loss,num,nn,torch,前馈,train,ls,test
From: https://www.cnblogs.com/cyberbase/p/16821143.html

相关文章

  • 在多分类任务实验中用torch.nn实现
    12、在多分类任务实验中用torch.nn实现......
  • 在多分类任务实验中用torch.nn实现dropout
    10、在多分类任务实验中用torch.nn实现dropout#导入必要的包importtorchimporttorch.nnasnnimportnumpyasnpimporttorchvisionimporttorchvision.transform......
  • Xshell连接虚拟机的Centos报错Could not connect to 主机地址感慨
    网上搜的一堆方法对我都没用,搜到的几乎都是一致的答案什么改ens33的配置,关防火墙。首先应该明确自己问题出在哪里,比如我就是ping外网显示没有这个名字(也就是失败了,可能......
  • Mysql索引原理揭秘之——MyISAM和InnoDB
    MyISAM引擎的索引实现在MyISAM里面,另外有两个文件,一个是.MYD文件,D代表Data,是MyISAM的数据文件,存放数据记录,比如我们的user_myisam表的所有的表数据;一个是.MYI文件,I代表Inde......
  • Scanner类
    ScannerScannerScanner类可以获取用户的输入基本语法Scanners=newScanner(System.in);通过Scanner类的next()和nextLine()方法来获取输入的字符串,读取之前一般会......
  • RuntimeError: Deterministic behavior was enabled with either `torch.use_determin
    在CORL的代码中,出现了一种error:  可经过如下方法解决:cuda10.1os.environ['CUDA_LAUNCH_BLOCKING']='1'cuda10.2及以上os.environ['CUBLAS_WORKSPACE_CONFIG']......
  • innordb并发基础知识
    说明:以下内容仅仅针对innordb引擎,其他的不一定通用 1、并发中可能存在的问题(1)读读不会有问题(2)读写有脏读、不可重复读、幻读的问题(3)写写有写丢失的问题 2、事......
  • 【EF Core】Data Annotations之ComplexType 复杂类型
    EFCoreCodeFirst代码优先中的复杂类型复杂类型在EF4.1中很容易实现。想象客户实体类有一些像城市,邮政编码和街道的属性,我们发现把这些属性组织成一个叫地址的复杂类......
  • 《PyTorch 深度学习实践 》 刘二大人 第十讲
    课堂练习:1importtorch2fromtorchvisionimporttransforms3fromtorchvisionimportdatasets4fromtorch.utils.dataimportDataLoader5importtorch.......
  • Faster R-CNN理论合集
    FasterR-CNNR-CNN(RegionwithCNNfeature)算法流程​ RCNN算法流程可分为4个步骤一张图像生成1k~2k个候选区域(使用SelectiveSearch方法)对每个候选区域,使用......