首页 > 其他分享 >DL 基础:PyTorch 常用代码存档

DL 基础:PyTorch 常用代码存档

时间:2023-03-02 15:56:05浏览次数:36  
标签:loss DL 存档 torch PyTorch pred test model MSE

1 pandas 读 csv

import torch
from torch import nn
import numpy as np
import pandas as pd
from copy import deepcopy
device = "cuda" if torch.cuda.is_available() else "cpu"

# 读 csv
data_all = pd.read_csv('./CFD_data/record_data0.csv')
# 提取某一列
colume = np.array(data_all[['colume_name']], dtype=np.float32).reshape(-1, 1)
# 提取某一个值
value = data[data['食物种类']=='主食']['卡路里'].item()
# 数据操作
c = np.concatenate([a[1:], b[:-1]], axis=1)
c = torch.cat([a, b], axis=1)
# 存 csv
c.to_csv('./CFD_data/flow_rate.csv', index=False)

2 NN 的搭建、训练与评估

搭建:使用 nn.Sequential

# model
NN_model = nn.Sequential(
    nn.Linear(6, 256), 
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
)
# 优化器
optimizer = torch.optim.Adam(NN_model.parameters(), lr=0.001)

训练:

def NN_train(train_x, train_y, model, loss_fn, optimizer, epoches, batch_size, save_path):
    """
    训练网络
    输入:
        train_x, train_y:   训练集
        model:              网络模型
        loss_fn:            损失函数
        optimizer:          优化器
        epoches:            epoches 个数
        batch_size:         mini batch 大小
        save_path:          模型保存路径
    """
    # 切换到train模式
    model.train()
    losses = []
    for epoch in range(epoches):
        batch_loss = []
        for start in range(0, len(train_x), batch_size): # mini batch
            end = start + batch_size if start + batch_size < len(train_x) else len(train_x)
            xx = torch.tensor(train_x[start:end], dtype=torch.float, requires_grad=True)
            yy = torch.tensor(train_y[start:end], dtype=torch.float, requires_grad=True)
            xx, yy = xx.to(device), yy.to(device) # 加载到 device
            pred = model(xx) # 输入数据到模型里得到输出
            loss = loss_fn(pred, yy) # 计算输出和标签的 loss           
            optimizer.zero_grad() # 清零
            loss.backward() # 反向推导
            optimizer.step() # 步进优化器
            batch_loss.append(loss.data.numpy())
        if epoch % max(1, epoches//8) == 0:
            print(f"Training Error in epoch {epoch}: {np.mean(batch_loss):>8f}")
    torch.save(model.state_dict(), save_path) # 保存模型

测试:

def NN_test(test_x, test_y, model, save_path, loss_fn):
    """
    测试网络
    输入:
        test_x, test_y:     测试集
        model:              网络模型
        loss_fn:            损失函数
        save_path:          模型保存路径
    """
    model.load_state_dict(torch.load(save_path)) # 加载模型  
    model.eval() # 切换到测试模型
    MSE_loss_fn = nn.MSELoss() # MSE loss function
    test_loss, MSE = 0, 0 # 记录 loss 和 MSE
    # 梯度截断
    with torch.no_grad():
        test_x, test_y = torch.tensor(test_x).to(device), torch.tensor(test_y).to(device) # 加载到 device
        pred = model(test_x) # 输入数据到模型里得到输出
        test_loss = loss_fn(pred, test_y).item() # 计算输出和标签的 loss
        MSE = MSE_loss_fn(pred, test_y).item() # MSE
    print(f"Test Error: \n  Avg loss: {test_loss:>8f}, MSE: {MSE:>8f}\n")
    print(f"Test Result: \n  Prediction: {pred[:5]}, \n  Y: {test_y[:5]}, \n  diff: {test_y[:5]-pred[:5]}\n")

测试 ensemble model(平均值):

def NN_test_ensemble(test_x, test_y, loaded_model_list, loss_fn):
    for model in loaded_model_list:
        model.eval() # 切换到测试模型
    MSE_loss_fn = nn.MSELoss() # MSE loss function
    test_loss, MSE = 0, 0 # 记录 loss 和 MSE
    # 梯度截断
    with torch.no_grad():
        test_x, test_y = torch.tensor(test_x).to(device), torch.tensor(test_y).to(device) # 加载到 device
        pred = torch.zeros(test_y.shape)
        for model in loaded_model_list:
            pred += model(test_x) # 输入数据到模型里得到输出
        pred /= len(loaded_model_list)
        test_loss = loss_fn(pred, test_y).item() # 计算输出和标签的 loss
        MSE = MSE_loss_fn(pred, test_y).item() # MSE
    print(f"Test Error: \n  Avg loss: {test_loss:>8f}, MSE: {MSE:>8f}\n")
    print(f"Test Result: \n  Prediction: {pred[:5]}, \n  Y: {test_y[:5]}, \n  diff: {test_y[:5]-pred[:5]}\n")

标签:loss,DL,存档,torch,PyTorch,pred,test,model,MSE
From: https://www.cnblogs.com/moonout/p/17172065.html

相关文章

  • 什么是Bundle ID​
    登录成功后我们可以看到弹出的消息提示“您账号未支付688给apple,只能创建开发证书,无法提交上传发布,无法使用apple登录,支付,推送功能”,简单来说就是只能使用此款软件进行内测,......
  • Fiddler 对真机(Android 系统)上 App 抓包图文详解 (超全)
    作为测试或开发经常需要抓取手机App的HTTP/HTTPS的数据包,通过查看App发出的HTTP请求和响应数据来协助开发去修复bug。对于测试而言,通过抓包+分析,去定位bug的前后端归属问题......
  • LinkedList集合应用:实现队列
    LinkedList集合应用:实现队列题目:使用LinkedList类实现一个Queue(队列)类。Queue类应该具有以下功能:void enqueue(Eelement):将给定的元素添加到队列的末尾。Edeque......
  • 什么是Bundle ID​
    什么是BundleID​登录成功后我们可以看到弹出的消息提示“您账号未支付688给apple,只能创建开发证书,无法提交上传发布,无法使用apple登录,支付,推送功能”,简单来说就是只能使......
  • R语言分布滞后非线性模型(DLNM)研究发病率,死亡率和空气污染示例|附代码数据
    全文下载链接:http://tecdat.cn/?p=21317最近我们被客户要求撰写关于分布滞后非线性模型(DLNM)的研究报告,包括一些图形和统计输出。本文提供了运行分布滞后非线性模型的示例......
  • [ChatGPT 勘误]SAP ABAP 里 CL_WB_ED_ENHANCEMENT_HANDLER 的用途介绍
    以下是ChatGPT关于CL_WB_ED_ENHANCEMENT_HANDLER的介绍:在ABAP中,CL_WB_ED_ENHANCEMENT_HANDLER是一个用于管理ABAP代码增强(CodeEnhancement)的类。ABAP代码增强......
  • pytorch和pyG安装
    操作系统:windows10显卡:GTX1650CUDA版本11.1下载安装CUDAToolkit11.1.0新建conda环境,python3.8condacreate-nGNNpython=3.8。激活condaactivateGNNwin下CUDA......
  • java LinkedList 源码
    概述底层数据结构是双向链表(jdk1.6是双向循环,1.7开始不循环了),所以新增/删除效率高,查询/修改效率相对较低全能冠军:既是一个顺序容器,也是队列,还可以作为栈使用未实现Ran......
  • flutter doctor错误:“Unable to find bundled Java version.”
    实际是找不到/Applications/AndroidStudio.app/Contents/jre/jdk执行如下命令:cd/Applications/AndroidStudio.app/Contents/ln-sjbrjrecdjreln-sContents......
  • ThreadLocal 详解
    ThreadLocal概述ThreadLocal类用来提供线程内部的局部变量,不同的线程之间不会相互干扰这种变量在多线程环境下访问(通过get和set方法访问)时能保证各个线程的变量相对独立......