首页 > 其他分享 >基于词嵌入的逻辑回归文本分类

基于词嵌入的逻辑回归文本分类

时间:2023-04-19 17:32:06浏览次数:40  
标签:dim 逻辑 嵌入 torch 50 train input model 文本

简述逻辑回归(Logistic Regression)原理,并用torch实现逻辑回归文本分类,原始数据一共有100条句子,每个样本是一条句子,每个句子有50个单词,每个单词用长为50的词向量表示。现在需要用一条句子预测一个类别,本文给出torch案例

逻辑回归是一种常用的分类算法,它是一种线性分类模型。逻辑回归的目标是通过给定的输入特征,预测输出的二分类结果。它的原理是将输入特征与一组权重进行线性组合,然后将这个结果输入到一个逻辑函数中,得到一个0~1之间的概率值。逻辑回归的核心思想是用概率来表示分类结果,通过设置阈值来将概率值转换为二分类结果。

下面是一个用PyTorch实现逻辑回归文本分类的例子,假设我们有100条句子,每个样本是一条句子,每个句子有50个单词,每个单词用长为50的词向量表示。我们需要将这些数据输入到逻辑回归模型中,来预测每个句子所属的类别。

import torch
import torch.nn as nn

# 定义逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

# 定义模型参数
input_dim = 50 * 50
output_dim = 2

# 创建模型实例
model = LogisticRegression(input_dim, output_dim)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 定义训练函数
def train(model, optimizer, criterion, num_epochs, train_loader):
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (inputs, labels) in enumerate(train_loader):
            # 将输入数据转换为一维张量
            inputs = inputs.view(-1, input_dim)
            labels = labels.squeeze()

            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 输出统计信息
            if (i+1) % 10 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                       .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# 定义数据集
data = torch.randn(100, 50, 50)
labels = torch.randint(0, 2, (100, ))

# 定义数据加载器
batch_size = 10
train_dataset = torch.utils.data.TensorDataset(data, labels)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 训练模型
num_epochs = 10
train(model, optimizer, criterion, num_epochs, train_loader)

# 使用模型进行预测
test_input = torch.randn(1, 50, 50).view(-1, input_dim)
output = model(test_input)
predicted = torch.argmax


# 使用模型进行预测
test_input = torch.randn(1, 50, 50).view(-1, input_dim)
output = model(test_input)
predicted = torch.argmax(output, dim=1)

print("Predicted class:", predicted.item())

# 保存模型
torch.save(model.state_dict(), "logistic_regression_model.pt")

这个例子中,我们使用了PyTorch实现了一个简单的逻辑回归模型。首先我们定义了一个LogisticRegression类,它继承自nn.Module,并定义了一个包含线性层和Sigmoid函数的前向传播函数。然后我们定义了模型参数、损失函数和优化器,并创建了一个数据加载器,用于加载数据集。

在训练过程中,我们使用了一个简单的循环来迭代数据集,并在每个迭代步骤中计算损失、反向传播和优化器更新。最后,我们使用训练好的模型来进行预测,并将预测结果输出到控制台。

注意,这个例子中的数据是随机生成的,实际情况下需要根据具体的数据集进行修改。另外,在实际应用中,我们可能需要更复杂的模型和更多的训练数据,以获得更好的预测效果。

标签:dim,逻辑,嵌入,torch,50,train,input,model,文本
From: https://blog.51cto.com/guog/6207039

相关文章

  • font-display 文本在网页字体加载期间保持可见状态
    为确保文本在网页字体加载期间保持可见状态,可以考虑以下几种方法:1.使用系统默认字体或web-safe字体:这些字体通常是已经在大多数操作系统和浏览器中安装和加载的,因此在页面加载期间可以立即呈现。这样,即使自定义字体尚未加载,文本也将始终可见。2.通过CSS实现字体预加载:可以在C......
  • tinymce 复制粘贴时去除文本里面的样式
    第一步需要引入tinymce自带的一个粘贴插件"paste",代码如下;import"tinymce/plugins/paste";tinymce.init({...其他配置,plugins:["paste"],});第二步将以下几个参数放到配置项中,亲测这些参数都有效,比如从excel表格中复制过来的内容可以清除table样式。import"tinymce/plugin......
  • vue利用正则去除富文本的标签和样式
    constremoveHtmlStyle=(html:any)=>{letrelStyle=/style\s*?=\s*?([‘"])[\s\S]*?\1/g;//去除样式letrelTag=/<.+?>/g;//去除标签letrelClass=/class\s*?=\s*?([‘"])[\s\S]*?\1/g;//清除类名letnewHtml="";  if(html){......
  • NULL值引入导致新增的unknown逻辑值 以及 SQL server中ANSI_NULLS的使用
    部分参考文章:https://www.bbsmax.com/A/A7zgEOVl54/ [BBSMAX]Lumia1020 2022-11-08https://www.cnblogs.com/SFAN/p/4343703.htmlcnblogs@ sunnyboy 2015-03-1710:17wikipedia三值逻辑:https://zh.wikipedia.org/wiki/%E4%B8%89%E5%80%BC%E9%80%BB......
  • 论文阅读记录3——基于提示学习的小样本文本分类方法——计算机应用
     方法:首先,利用预训练模型BERT在标注样本上学习到最优的提示模板;然后,在每条样本中补充提示模板和空缺,将文本分类任务转化为完形填空任务;最后,通过预测空缺位置的填充词,结合填充词与标签之间的映射关系得到最终的标签。原因:文本分类任务通常依赖足量的标注数据,针对低资源场景......
  • css文本
    1、colorcolor:red;设置字体颜色2、text-aligntext-align:center;设置文本的水平对齐方式,可选项:center居中对齐,right向右对齐,left向左对齐,justify两端对齐3、文本修饰text-decoration:none;可选项:overline上划线,line-through删除线,underline下划线,none没有4、大小......
  • 痞子衡嵌入式:我被邀请做嵌入式联盟主办的职场奇葩说(上海站)辩手
    「嵌入式联盟」是「科锐国际」联合圈子里一些有影响力的公众号主组建起来的嵌入式行业人才的专属社区。联盟致力于为嵌入式领域从业者提供线下交流与分享的机会,定期进行技术及行业信息等深度的探讨,满足嵌入式人才零距离交流及互助需求。痞子衡有幸被邀请做3月26日联盟首期活动“嵌......
  • 自然语言处理:词嵌入简介
    动动发财的小手,点个赞吧!WordEmbeddings机器学习模型“查看”数据的方式与我们(人类)的方式不同。例如,我们可以轻松理解“我看到一只猫”这一文本,但我们的模型却不能——它们需要特征向量。此类向量或词嵌入是可以输入模型的词的表示。工作原理:查找表(词汇)在实践中,你有一个允许......
  • 痞子衡嵌入式:恩智浦经典LPC系列MCU内部Flash IAP驱动入门
    大家好,我是痞子衡,是正经搞技术的痞子。今天痞子衡给大家介绍的是恩智浦经典LPC系列MCU内部FlashIAP驱动。LPC系列MCU是恩智浦公司于2003年开始推出的非常具有代表性的产品,距今已经有近20年的生命。按时间线演进来说,其主要分为三代:-元老:基于ARM7/9内核的LPC2000......
  • 痞子衡嵌入式:利用i.MXRT1xxx系列ROM集成的DCD功能可轻松配置指定外设
    大家好,我是痞子衡,是正经搞技术的痞子。今天痞子衡给大家介绍的是利用i.MXRT1xxx系列ROM集成的DCD功能可轻松配置指定外设。关于i.MXRT1xxx系列芯片BootROM中集成的DCD功能这个话题,痞子衡早就想写了,但是一直没有动笔,毕竟这个话题比较生涩,单独讲会比较枯燥。最近痞子衡......