首页 > 其他分享 >深度学习实战之找最大数字

深度学习实战之找最大数字

时间:2024-06-30 22:32:06浏览次数:16  
标签:实战 loss 数字 训练 torch batch 深度 model size

文章目录

前言

之前学习了深度学习的概念与基本过程,今天用一个简单的深度学习框架实现最大数字的找寻,理解深度学习的的基本流程。

问题描述

假设有一个5维数组, X = [ 2 , 3 , 7 , 4 , 5 ] X=[2,3,7,4,5] X=[2,3,7,4,5],则定义 Y = [ 0 , 0 , 1 , 0 , 0 ] Y=[0,0,1,0,0] Y=[0,0,1,0,0], X X X中数字最大的是第三个,所以 Y Y Y的第三个为1,并将他定义为第三类。
步骤1:生成训练集
步骤2:定义神经网络
步骤3:代入数据训练参数
步骤4:模型评估预测

生成训练集

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt##导入各种库

def build_sample():##数据规则
    x=np.random.random(5)
    y=np.zeros(5)
    y[np.argmax(x)]=1
    return x,y

def build_dataset(total_samlple_num):##生成数据
    X=[]
    Y=[]
    for i in range(total_samlple_num):
        x,y=build_sample()
        X.append(x)
        Y.append(y)
    return torch.FloatTensor(np.array(X)),torch.FloatTensor(np.array(Y))

上面的代码块第一部分是导入各种库,第二部分是定义规则,先随机生成5维数组,然后将根据x的最大最小数值确定y的哪个值为1.第三部分为批量生成数据,并返回torch的浮点类型的数字,便于训练。

定义神经网络

我们利用最简单的线性层和激活层实现这个功能,我们的x是5维的,最终输出的y也是5维的,因为线性层有 y = X w + b y=Xw+b y=Xw+b所以 w w w应该也是5*5的。我们生成的y值每一维应该是0到1的,属于五类数字的概率,哪个概率大就为哪类,所以我们激活函数可以选择为sigmoid函数,损失函数选择为交叉熵损失函数(因为是分类问题),最终的神经网络可以如下图(作者的手画):
在这里插入图片描述
利用代码定义这个网络层,有如下代码:

class TorchModel(nn.Module):
    def __init__(self,input_size):
        super(TorchModel,self).__init__()
        self.linear=nn.Linear(input_size,5)
        self.activation=torch.sigmoid
        self.loss=nn.CrossEntropyLoss()
    def forward(self,x,y=None):
        x=self.linear(x)
        y_pred=self.activation(x)
        if y is not None:
            return self.loss(y_pred,y)
        else:
            return y_pred

对于forward,有y值输入时为训练,计算损失函数,没有y值时时进行前向传播,输出预测值。

进行训练

# 测试代码
# 用来测试每轮模型的准确率
def evaluate(model):
    model.eval()
    test_sample_num = 100
    x, y = build_dataset(test_sample_num)
    a,b,c,d,e=0,0,0,0,0
    for i in range(len(y)):
        if y[i].argmax()==0:
            a+=1
        elif y[i].argmax()==1:
            b+=1
        elif y[i].argmax()==2:
            c+=1
        elif y[i].argmax()==3:
            d+=1
        else:
            e+=1
    print("本次预测集中共有%d个一类样本,%d个二类样本,%d个三类样本,%d个四类样本,%d个五类样本" % (a,b,c,d,e))
    correct, wrong = 0, 0
    with torch.no_grad():
        y_pred = model(x)  # 模型预测
        for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比
            if int(y_p.argmax())==int(y_t.argmax()):
                correct += 1  # 样本判断正确
            else:
                wrong += 1
    print("正确预测个数:%d, 正确率:%f" % (correct, correct / (correct + wrong)))
    return correct / (correct + wrong)

def main():
    # 配置参数
    epoch_num = 20  # 训练轮数
    batch_size = 20  # 每次训练样本个数
    train_sample = 5000  # 每轮训练总共训练的样本总数
    input_size = 5  # 输入向量维度
    learning_rate = 0.001  # 学习率
    # 建立模型
    model = TorchModel(input_size)
    # 选择优化器
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
    log = []
    # 创建训练集,正常任务是读取训练集
    train_x, train_y = build_dataset(train_sample)
    # 训练过程
    for epoch in range(epoch_num):
        model.train()
        watch_loss = []
        for batch_index in range(train_sample // batch_size):
            x = train_x[batch_index * batch_size : (batch_index + 1) * batch_size]
            y = train_y[batch_index * batch_size : (batch_index + 1) * batch_size]
            loss = model(x, y)  # 计算loss
            loss.backward()  # 计算梯度
            optim.step()  # 更新权重
            optim.zero_grad()  # 梯度归零
            watch_loss.append(loss.item())
        print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))
        acc = evaluate(model)  # 测试本轮模型结果
        log.append([acc, float(np.mean(watch_loss))])
    # 保存模型
    torch.save(model.state_dict(), "max_number.pt")
    # 画图
    print(log)
    plt.plot(range(len(log)), [l[0] for l in log], label="acc")  # 画acc曲线
    plt.plot(range(len(log)), [l[1] for l in log], label="loss")  # 画loss曲线
    plt.legend()
    plt.show()
    return

我们先定义一个评估函数evalute,在接下来训练时可以计算每轮训练的正确率。main函数为主要训练函数,先开始定义各种训练的超参数,20轮的反向传播。运行main()有如下结果:
在这里插入图片描述
可以看到损失函数在每轮过后逐渐减小,准确率在每轮训练后逐渐增大,准确率最终在百分之90以上。

模型评估预测

对于最终的结果我们可以查看模型 w w w, b b b的值是多少,并生成一些数据看看结果。

def predict(model_path, input_vec):
    input_size = 5
    model = TorchModel(input_size)
    model.load_state_dict(torch.load(model_path))  # 加载训练好的权重
    print(model.state_dict())

    model.eval()  # 测试模式
    with torch.no_grad():  # 不计算梯度
        result = model.forward(torch.FloatTensor(input_vec))  # 模型预测
    for vec, res in zip(input_vec, result):
        print("输入:%s, 预测类别:%d, 概率值:%f" % (vec, int(res.argmax())+1,res.max()))  # 打印结果
  
 test_vec = [[0.07889086,0.15229675,0.31082123,0.03504317,0.18920843],
            [0.94963533,0.5524256,0.95758807,0.95520434,0.84890681],
            [0.78797868,0.67482528,0.13625847,0.34675372,0.19871392],
            [0.79349776,0.59416669,0.92579291,0.41567412,0.1358894]]
 predict("model.pt",test_vec)

结果如下所示:
在这里插入图片描述
可以看到,对于差距比较大数字,该模型可以很大的概率预测准确的,对于精度比较小的值还是有小概率预测错误的。
以上就是深度学习的基本流程,你们还可以尝试用其他规则进行训练,比如说,第n个数字大于第m个数字归为正类,否则归为负类,小伙伴可以自己尝试一下。

标签:实战,loss,数字,训练,torch,batch,深度,model,size
From: https://blog.csdn.net/m0_57922605/article/details/140052461

相关文章

  • 应用数学与机器学习基础 - 深度学习的动机与挑战篇
    序言深度学习,作为当代人工智能领域的核心驱动力,其动机源于对模拟人类智能深层认知机制的渴望。我们追求的是让机器能够像人类一样理解、分析并应对复杂多变的世界。然而,这一追求之路并非坦途,面临着数据获取与处理的挑战、模型复杂度的控制、计算资源的巨大消耗等重重障碍。......
  • 开源语音转文本Speech-to-Text大模型实战之Wav2Vec篇
    前言近年来,语音转文本(Speech-to-Text,STT)技术取得了长足的进步,广泛应用于各种领域,如语音助手、自动字幕生成、智能客服等。本文将详细介绍如何利用开源语音转文本大模型进行实战,从模型选择、环境搭建、模型训练到实际应用,带您一步步实现语音转文本功能。一、模型选择目前,市......
  • 【深度学习】图形模型基础(3):从零开始认识机器学习模型
    1.引言机器学习,这一古老而又充满活力的领域,其历史可追溯至上世纪中叶。然而,直到20世纪90年代初,机器学习才开始展现出其广泛的应用潜力。在过去的十年里,机器学习更是迎来了前所未有的蓬勃发展,其应用范畴广泛,不仅在网络搜索、自动驾驶汽车、医学成像和语音识别等领域大放异彩......
  • 【深度学习】图形模型基础(1):使用潜在变量模型进行数据分析的box循环
    1.绪论探索数据背后的隐藏规律,这不仅是数据分析的艺术,更是概率模型展现其威力的舞台。在这一过程中,潜在变量模型尤为关键,它成为了数据驱动问题解决的核心引擎。潜在变量模型的基本理念在于,那些看似复杂、杂乱无章的数据表象之下,往往隐藏着一种更为简洁、有序的结构和规律,只......
  • 中小企业在数字化转型过程中遇到的挑战有哪些?
    引言:中小企业推进数字化转型的背景是多重因素叠加的结果,包括市场竞争压力、信息技术发展及普及、各级政府政策支持及引导、企业经营发展需求和人才结构变化等。这些因素共同推动了中小企业加快数字化转型的步伐,以应对日益复杂多变的市场环境。那中小企业推进数字化转型过程中......
  • 【全球首个开源AI数字人】DUIX数字人-打造你的AI伴侣!
    目录1.引言1.1数字人技术的发展背景1.2DUIX数字人项目的开源意义1.3DUIX数字人技术的独特价值1.4本文目的与结构2.DUIX数字人概述2.1定义与核心概念2.2硅基智能与DUIX的关系2.3技术架构2.4开源优势2.5应用场景2.6安全与合规性3.DUIX数字人技术特点3.1开......
  • 阿里云服务器数据库迁云: 数据从传统到云端的安全之旅(WordPress个人博客实战教学)
    ......
  • 解决springer期刊提供的LaTex模板参考文献格式为作者+年份时的顺序问题以及如何在正文
    这两天准备投稿springer下的一个期刊,拿到模板后人很麻,期刊给的latex和已出版的论文格式非常不符合,怎么办呐?不要急!下面开始改进!首先非常感谢大佬写的一篇解决方案,链接springer期刊提供的LaTex模板参考文献格式为作者+年份时的顺序问题_sn-article-CSDN博客该大佬提出的解决方......
  • 深度解析:scikit-learn Pipeline记忆功能的秘密
    标题:深度解析:scikit-learnPipeline记忆功能的秘密摘要scikit-learn(简称sklearn)是Python中一个广泛使用的机器学习库,它提供了许多用于数据挖掘和数据分析的工具。Pipeline是sklearn中一个强大的功能,允许用户以流水线的方式组合多个数据转换和/或模型训练步骤。本文将详细......
  • 大模型实战1年半,总结一下在企业落地的三个策略
    节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、今年参加社招和校招面试的同学。针对大模型技术趋势、算法项目落地经验分享、新手如何入门算法岗、该如何准备面试攻略、面试常考点等热门话题进行了深入的讨论。总结链接如下:《大模型面试宝典》(2024......