首页 > 其他分享 >深度学习代码实践_train.py文件内容(识别数字0-9)

深度学习代码实践_train.py文件内容(识别数字0-9)

时间:2023-05-18 19:12:05浏览次数:36  
标签:torch pred py label train test 识别 data

import cv2
from MLP import MLP
from Cnn import save_model
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
import torch.utils.data as Data
import numpy as np
import matplotlib.pyplot as plt
from data_process import get_features
from sklearn.preprocessing import StandardScaler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_losses = []
train_accuracy = []
def train_model(net, data, label, lr, batch_size, epoch):
    print(net)
    data = torch.Tensor(data)
    data = data.unsqueeze(1)
    label = torch.Tensor(label).long()

    #resnet建议使用cuda
    data = data.to(device)
    label = label.to(device)
    # 训练集和测试集7:3
    train_data, test_data, train_label, test_label = train_test_split(data, label, test_size=0.3, random_state=0)

    # 学习率
    LR = lr
    # 每次投入训练数据大小
    BATCH_SIZE = batch_size
    # 训练模型次数
    EPOCH = epoch

    optimizer = torch.optim.Adam(net.parameters(), lr=LR)

    train_dataset = Data.TensorDataset(train_data, train_label)
    train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )

    test_dataset = Data.TensorDataset(test_data, test_label)
    test_loader = Data.DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, LR, epochs=EPOCH, steps_per_epoch=len(train_loader))


    for epoch in range(EPOCH):
        running_loss = 0
        for step, (batch_data, batch_label) in enumerate(train_loader):
            print('Epoch:', epoch + 1, '/', EPOCH, 'Step:', step)
            prediction = net(batch_data)
            loss = F.cross_entropy(prediction, batch_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()

            _, pred = torch.max(prediction, 1)

            accuracy = torch.sum(pred == batch_label).item() / len(pred)
            print('Epoch', epoch + 1, '| train loss:%.4f' % loss, '| accuracy:%.4f' % accuracy)
        train_losses.append(running_loss / len(train_loader))
        train_accuracy.append(accuracy)

    return net


def test_model(net, data, label):
    net.eval()
    data = torch.Tensor(data)
    data = data.unsqueeze(1)
    label = torch.Tensor(label).long()
    data = data.to(device)
    label = label.to(device)
    # 训练集和测试集7:3
    train_data, test_data, train_label, test_label = train_test_split(data, label, test_size=0.3, random_state=0)

    test_dataset = Data.TensorDataset(test_data, test_label)
    test_loader = Data.DataLoader(
        dataset=test_dataset,
        batch_size=32,
        shuffle=True,
    )

    y_true = []
    y_pred = []
    for stp, (test_x, test_y) in enumerate(test_loader):
        test_output = net(test_x)
        _, pred_y = torch.max(test_output, 1)
        y_true.extend(test_y)
        y_pred.extend(pred_y)
    y_true = torch.tensor(y_true, device='cpu')
    y_pred = torch.tensor(y_pred, device='cpu')
    print("Accuracy:", accuracy_score(y_true, y_pred))
    print("Precision_score:", precision_score(y_true, y_pred, average='macro'))
    print("Recall_score:", recall_score(y_true, y_pred, average='macro'))
    print("F1_score", f1_score(y_true, y_pred, average='macro'))


def predict(model, file):
    spect = get_features(file)
    data = torch.Tensor(spect)
    data = data.unsqueeze(0)
    data = data.unsqueeze(0)
    data = data.to(device)

    output = model(data)
    confidence, pred_y = torch.max(output, 1)
    print("识别结果为:", pred_y.cpu().numpy())


if __name__ == '__main__':
    data = np.load("data.npy")
    label = np.load("label.npy")

    #cnn = ResNet(ResidualBlock, [2, 2, 2]).to(device)
    #cnn = CNN().to(device)
    cnn = MLP().to(device)
    cnn = train_model(cnn, data, label, lr=0.01, batch_size=512, epoch=300)
    save_model(cnn, "mlp.pt")

    test_model(cnn, data, label)

    np.save("dropout_train_losses.npy", train_losses)
    np.save("dropout_train_accuracy.npy", train_accuracy)

    plt.plot(train_losses, label='Training loss')
    plt.plot(train_accuracy, label='Accuracy')
    plt.xlabel("epoch")
    plt.ylabel("loss/accuracy")
    plt.legend()
    plt.show()





标签:torch,pred,py,label,train,test,识别,data
From: https://www.cnblogs.com/fly-smart/p/17413052.html

相关文章

  • 【python】dumpall工具使用
    dumpall:一款信息泄漏利用工具,适用于.git/.svn/.DS_Store泄漏和目录列出  git地址:https://github.com/0xHJK/dumpall  安装使用:#unzipdumpall-master.zip#cddumpall-master#python37dumpall.py--version#查看版本#python37dumpall.py-uhttps:......
  • 2023最佳python编辑器和IDE
    IDE没有统一的标准,自己习惯就是最好的。本文列出一些较常用的IDE,供大家参考。一般而言,WingIDE、PyCharm、Spyder、Vim是比较常用的IDE。SpyderSpyder是Python(x,y)的作者为它开发的一个简单的集成开发环境。和其他的Python开发环境相比,它最大的优点就是模仿MATLAB的"工作空间"......
  • C# 判别系统版本以及Win10的识别办法
    我们都知道在C#中可以通过Environment.OSVersion来判断当前操作系统,下面是操作系统和主次版本的对应关系:操作系统主版本.次版本Windows1010.0*WindowsServer2016TechnicalPreview10.0*Windows8.16.3*WindowsServer2012R26.3*Windows86.2......
  • 利用python解析log日志,json文件,配置文件。
    对于喜欢偷懒的我来说,重复同样的工作是很令人头疼的事情,总想找到一条捷径,最好是一劳永逸。本次跟大家分享的是对log日志,json文件以及配置文件的解析,读取。首先是log日志的读写:读取数据:f=open("spring05注意事项.txt",mode='r',encoding='utf-8')line=f.readline()whileline......
  • postgres 错误duplicate key value violates unique constraint 解决方案
    出错代码tortoise.exceptions.IntegrityError:duplicatekeyvalueviolatesuniqueconstraint"word_bank2_pkey"原文连接分析bugpostgres出现该问题着实没仔细看数据表序列ID,、出现的原因是:以word_bank2表为列子.id是唯一的且id在数据库中是自增的.而现在数据库中存......
  • Python字符串替换的3种方法
    Python字符串替换笔记主要展示了如何在Python中替换字符串。Python中有以下几种替换字符串的方法,本文主要介绍前三种。replace方法(常用)translate方法re.sub方法字符串切片(根据Python字符串切片方法替换字符)1.replace方法Pythonreplace方法把字符串中的old(旧字符串)替换成......
  • Pytest根据命令行参数使用动态数据进行参数话数据驱动
    Python中有一个重要的特性是,装饰器、类属性、模块变量都是模块加载时立即执行的。因此在使用@pytest.mark.parametrize进行参数话的时候,数据一般是确定的,如下例:importpytestDATA=["a.txt","b.txt","c.txt",]@pytest.mark.parametrize('filepath',DATA)......
  • 如何安装python
    在Linux和MacOS系统中,Python通常已经预装了,可以通过以下命令检查Python是否已经安装:python--version如果Python已经安装,则会显示Python的版本号。如果Python没有安装,则可以通过以下命令安装:在Ubuntu和Debian系统中,可以使用以下命令安装Python:sudoapt-getupdates......
  • 如何安装python
    在Linux和MacOS系统中,Python通常已经预装了,可以通过以下命令检查Python是否已经安装:python--version如果Python已经安装,则会显示Python的版本号。如果Python没有安装,则可以通过以下命令安装:在Ubuntu和Debian系统中,可以使用以下命令安装Python:sudoapt-getupdates......
  • 记一次排查:接口返回值写入excel后,从单元格copy出来的数据会带有多重引号的问题
    在项目里刚好有3个服务,同一个网关内层的3个服务,两个php的,一个golang的,为了提高负载以及进行分流,部分客户的接口调用会被网关自动分配到go服务。恰好为了测试,我写了一个全量用户的生产、测试环境调用接口返回结果进行对比的脚本,于是发现了题中的问题:两个php服务里的接口返回值写入......