首页 > 其他分享 >Logistic 回归测试代码

Logistic 回归测试代码

时间:2023-05-03 16:34:49浏览次数:44  
标签:loss pred 回归 torch test epoch Logistic 测试代码 data

简单概念

Logistic 回归是一种经典的分类方法,多用于二分类的问题。通过寻找合适的分类函数,用以对输入的数据进行预测,并给出判断结果。使用 sigmoid 函数(逻辑函数)将线性模型的结果压缩到 [0, 1] 之间,使输出的结果具有概率意义,实现输入值到输出概率的转换。

sigmoid 函数:$ g(z) = \frac{1}{1+e^{-z}} $

img

测试代码

目标:大于 3 的数输出结果为 1,小于等于 3 的数输出结果为 0。

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # cuda 加速

# 测试数据,≤3 结果为 0,>3 结果为 1
x_data = torch.Tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
y_data = torch.Tensor([[0], [0], [0], [1], [1], [1]])


# Logistic 回归测试模型
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)  # 线性层

    def forward(self, x):
        output = torch.nn.functional.sigmoid(self.linear(x))
        return output


model = LogisticRegressionModel().to(device)

# 损失函数和优化器
criterion = torch.nn.BCELoss(reduction='sum').to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练
for epoch in range(1, 40001):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    if epoch % 1000 == 0:
        print(f'epoch: {epoch}, loss: {loss.item():.4f}')

    optimizer.zero_grad()  # 梯度清零
    loss.backward()        # 反向传播
    optimizer.step()       # 利用优化器对参数 x 进行更新

# 预测结果
test_data = [-2.0, -1.0, 0.0, 7.0, 8.0, 9.0]
for data in test_data:
    x_test = torch.Tensor([[data]])
    y_test = model(x_test)
    print(f'{x_test.item()} pred: {y_test.item():.8f}')

输出结果

epoch: 1000, loss: 1.2079
epoch: 2000, loss: 0.8453
epoch: 3000, loss: 0.6859
epoch: 4000, loss: 0.5891
epoch: 5000, loss: 0.5216
epoch: 6000, loss: 0.4706
epoch: 7000, loss: 0.4302
epoch: 8000, loss: 0.3970
epoch: 9000, loss: 0.3691
epoch: 10000, loss: 0.3452
epoch: 11000, loss: 0.3245
epoch: 12000, loss: 0.3062
epoch: 13000, loss: 0.2900
epoch: 14000, loss: 0.2755
epoch: 15000, loss: 0.2624
epoch: 16000, loss: 0.2506
epoch: 17000, loss: 0.2397
epoch: 18000, loss: 0.2298
epoch: 19000, loss: 0.2207
epoch: 20000, loss: 0.2123
epoch: 21000, loss: 0.2046
epoch: 22000, loss: 0.1973
epoch: 23000, loss: 0.1906
epoch: 24000, loss: 0.1843
epoch: 25000, loss: 0.1784
epoch: 26000, loss: 0.1729
epoch: 27000, loss: 0.1677
epoch: 28000, loss: 0.1628
epoch: 29000, loss: 0.1582
epoch: 30000, loss: 0.1538
epoch: 31000, loss: 0.1497
epoch: 32000, loss: 0.1458
epoch: 33000, loss: 0.1421
epoch: 34000, loss: 0.1385
epoch: 35000, loss: 0.1352
epoch: 36000, loss: 0.1319
epoch: 37000, loss: 0.1289
epoch: 38000, loss: 0.1260
epoch: 39000, loss: 0.1232
epoch: 40000, loss: 0.1205
-2.0 pred: 0.00000000
-1.0 pred: 0.00000000
0.0 pred: 0.00000000
7.0 pred: 1.00000000
8.0 pred: 1.00000000
9.0 pred: 1.00000000

标签:loss,pred,回归,torch,test,epoch,Logistic,测试代码,data
From: https://www.cnblogs.com/zhangxiaochn/p/17369220.html

相关文章

  • 异构图中节点的分类/回归
    异构图中节点的分类/回归导入包importnumpyasnpimporttorchimportdglimporttorch.nnasnnimporttorch.nn.functionalasFimportdgl.nnasdglnn创建一个异构图设置这个图中的节点个数和边的个数n_users=100#user节点个数n_jobspre=500#jobpre......
  • MATLAB实现PSO-SVM多输入单输出回归预测(粒子群算法优化支持向量机)
    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。......
  • 【数据挖掘&机器学习】招聘网站的职位招聘数据的分位数图、分位数-分位数图以及散点图
    一.本次需求背景本文主题:招聘网站的职位招聘数据的分位数图、分位数-分位数图以及散点图、使用线性回归算法拟合散点图处理详解之前的文章我们已经对爬取的数据做了清洗处理,然后又对其数据做了一个薪资数据的倾斜情况以及盒图离群点的探究。我们这次的需求是:使用散点图、使用......
  • python用支持向量机回归(SVR)模型分析用电量预测电力消费|附代码数据
    全文链接:http://tecdat.cn/?p=23921最近我们被客户要求撰写关于SVR的研究报告,包括一些图形和统计输出。本文描述了训练支持向量回归模型的过程,该模型用于预测基于几个天气变量、一天中的某个小时、以及这一天是周末/假日/在家工作日还是普通工作日的用电量关于支持向量机的快速......
  • Python用RNN神经网络:LSTM、GRU、回归和ARIMA对COVID19新冠疫情人数时间序列预测|附代
    全文下载链接: http://tecdat.cn/?p=27042最近我们被客户要求撰写关于新冠疫情的研究报告,包括一些图形和统计输出。在本文中,该数据根据世界各国提供的新病例数据提供。获取时间序列数据df=pd.read_csv("C://global.csv")探索数据此表中的数据以累积的形式呈现,为了找出每天......
  • softmax回归的简洁实现
    softmax回归的简洁实现通过深度学习框架的高级API能够使实现softmax回归模型更方便地实现继续使用Fashion-MNIST数据集,并保持批量大小为256。importtorchfromtorchimportnnfromd2limporttorchasd2lbatch_size=256train_iter,test_iter=d2l.load_data_fash......
  • 多元线性回归
    1绪论2预备知识2.1多元线性回归分析法基本思想2.2多元线性回归分析法的理论模型2.3多元线性回归分析的计算步骤2.3.1参数估计2.3.2假设检验2.4Python语言操作步骤3多元线性回归模型的建立与分析3.1数据收集与分析3.......
  • 线性回归
    线性回归线性模型利用特征的线性函数进行预测,这里的线性指的是参数是线性的。一、普通最小二乘法线性回归(OLS)是最简单&最经典的线性方法,模型寻找截距和系数,使得模型对训练集的预测值与真实值之间的均方误差(MSE)最小,但是线性回归没有办法控制模型的复杂度(模型有大量的非0参数)。......
  • 多元时间序列滚动预测:ARIMA、回归、ARIMAX模型分析|附代码数据
    原文链接:http://tecdat.cn/?p=22849最近我们被客户要求撰写关于多元时间序列滚动预测的研究报告,包括一些图形和统计输出。当需要为数据选择最合适的预测模型或方法时,预测者通常将可用的样本分成两部分:内样本(又称"训练集")和保留样本(或外样本,或"测试集")。然后,在样本中估计模型,并......
  • 数据分享|逻辑回归、随机森林、SVM支持向量机预测心脏病风险数据和模型诊断可视化|附
    原文链接:http://tecdat.cn/?p=24973最近我们被客户要求撰写关于心脏病的研究报告,包括一些图形和统计输出。世界卫生组织估计全世界每年有1200万人死于心脏病。在美国和其他发达国家,一半的死亡是由于心血管疾病简介心血管疾病的早期预后可以帮助决定改变高危患者的生活方式,从......