首页 > 其他分享 >常用代码模板自用

常用代码模板自用

时间:2023-12-29 15:27:11浏览次数:33  
标签:代码 state 自用 path model os self 模板 dict

导入库(用于深度学习)

import os
import time
from datetime import timedelta
import json
import yaml
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

 

绘图(用于各种,以信号为例)

plt.figure(1)
plt.subplot(2, 1, 1)
plt.plot(t, sig)  # 绘制信号波形图
plt.subplot(2, 1, 2)
plt.imshow(np.abs(st_Res), origin='lower', extent=(0, len(t), 0, len(t)//2))  # 绘制频谱图、对复数格式取绝对值即可
plt.savefig("./imgs/stockwell-asdvalve09-" + time.strftime("%Y%m%d%H%M", time.localtime()))
plt.show()

深度学习训练模板(torch框架):

import os
import time
from datetime import timedelta
import json
import yaml
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class BaseTrainer(object):
    def __init__(self):
        self.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
        self.model = None
        self.optimizer = None
        self.loss = None

    def __setup_dataloader(self, is_train):
        if is_train:
            self.train_dataset = None
            self.train_loader = DataLoader(self.train_dataset,
                                           batch_size=64,
                                           shuffle=True, num_workers=0)
        # 获取测试数据
        self.valid_dataset = None
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=64,
                                       shuffle=True, num_workers=0)

    def __setup_model(self, is_train):
        self.model = None
        self.model.to(self.device)
        # optimizer & scheduler
        self.loss = .to(self.device)
        if is_train:
            if self.configs.train_conf.enable_amp:
                self.amp_scaler = torch.cuda.amp.GradScaler(init_scale=1024)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=float(self.configs.optimizer_conf.learning_rate))

    def __train_epoch(self, epoch_id):
        pass

    def train(self):
        pass

    def evaluate(self):
        pass

    def __test(self, resume_model=None):
        pass

保存checkpoint、加载checkpoint(torch框架):

    def __load_checkpoint(self, save_model_path, resume_model):
        last_epoch = -1
        best_auc, best_pauc = 0, 0
        last_model_dir = os.path.join(save_model_path,
                                      f'{self.configs.use_model}',
                                      'last_model')
        if resume_model is not None or (os.path.exists(os.path.join(last_model_dir, 'model.pth'))
                                        and os.path.exists(os.path.join(last_model_dir, 'optimizer.pth'))):
            # 自动获取最新保存的模型
            if resume_model is None:
                resume_model = last_model_dir
            assert os.path.exists(os.path.join(resume_model, 'model.pth')), "模型参数文件不存在!"
            assert os.path.exists(os.path.join(resume_model, 'optimizer.pth')), "优化方法参数文件不存在!"
            state_dict = torch.load(os.path.join(resume_model, 'model.pth'))
            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)
            self.optimizer.load_state_dict(torch.load(os.path.join(resume_model, 'optimizer.pth')))
            # 自动混合精度参数
            if self.amp_scaler is not None and os.path.exists(os.path.join(resume_model, 'scaler.pth')):
                self.amp_scaler.load_state_dict(torch.load(os.path.join(resume_model, 'scaler.pth')))
            with open(os.path.join(resume_model, 'model.state'), 'r', encoding='utf-8') as f:
                json_data = json.load(f)
                last_epoch = json_data['last_epoch'] - 1
                best_auc = json_data['best_auc']
                best_pauc = json_data['best_pauc']
            self.logger.info('成功恢复模型参数和优化方法参数:{}'.format(resume_model))
            self.optimizer.step()
            [self.scheduler.step() for _ in range(last_epoch * len(self.train_loader))]
        return last_epoch, best_auc, best_pauc

    def __save_checkpoint(self, save_model_path, epoch_id, best_auc=0., best_pauc=0., best_model=False):
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        if best_model:
            model_path = os.path.join(save_model_path,
                                      f'{self.configs.use_model}',
                                      'best_model')
        else:
            model_path = os.path.join(save_model_path,
                                      f'{self.configs.use_model}',
                                      'epoch_{}'.format(epoch_id))
        os.makedirs(model_path, exist_ok=True)
        torch.save(self.optimizer.state_dict(), os.path.join(model_path, 'optimizer.pth'))
        torch.save(state_dict, os.path.join(model_path, 'model.pth'))
        # 自动混合精度参数
        if self.amp_scaler is not None:
            torch.save(self.amp_scaler.state_dict(), os.path.join(model_path, 'scaler.pth'))
        with open(os.path.join(model_path, 'model.state'), 'w', encoding='utf-8') as f:
            data = {"last_epoch": epoch_id, "best_auc": best_auc, "best_pauc": best_pauc,
                    "version": __version__}
            f.write(json.dumps(data))
        if not best_model:
            last_model_path = os.path.join(save_model_path,
                                           f'{self.configs.use_model}',
                                           'last_model')
            shutil.rmtree(last_model_path, ignore_errors=True)
            shutil.copytree(model_path, last_model_path)
            # 删除旧的模型
            old_model_path = os.path.join(save_model_path,
                                          f'{self.configs.use_model}',
                                          'epoch_{}'.format(epoch_id - 3))
            if os.path.exists(old_model_path):
                shutil.rmtree(old_model_path)
        self.logger.info('已保存模型:{}'.format(model_path))

加载预训练模型(torch):

    def __load_pretrained(self, pretrained_model):
        # 加载预训练模型
        if pretrained_model is not None:
            if os.path.isdir(pretrained_model):
                pretrained_model = os.path.join(pretrained_model, 'model.pth')
            assert os.path.exists(pretrained_model), f"{pretrained_model} 模型不存在!"
            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
                model_dict = self.model.module.state_dict()
            else:
                model_dict = self.model.state_dict()
            model_state_dict = torch.load(pretrained_model)
            # 过滤不存在的参数
            for name, weight in model_dict.items():
                if name in model_state_dict.keys():
                    if list(weight.shape) != list(model_state_dict[name].shape):
                        self.logger.warning('{} not used, shape {} unmatched with {} in model.'.
                                       format(name, list(model_state_dict[name].shape), list(weight.shape)))
                        model_state_dict.pop(name, None)
                else:
                    self.logger.warning('Lack weight: {}'.format(name))
            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
                self.model.module.load_state_dict(model_state_dict, strict=False)
            else:
                self.model.load_state_dict(model_state_dict, strict=False)
            self.logger.info('成功加载预训练模型:{}'.format(pretrained_model))

 

标签:代码,state,自用,path,model,os,self,模板,dict
From: https://www.cnblogs.com/zhaoke271828/p/17934924.html

相关文章

  • day01 代码随想录算法训练营 27. 移除元素
    题目:27.移除元素感悟:用快慢指针。本题是要原地删除。而删除这个行为在真实的计算机的数组里,是覆盖。所以,就用两个指针,(人)一个跑的快,一个跑的慢。他们身上带了个对讲机。跑的快的那个人负责检测后面的数字符合要求不,比如,要不等于3的,遇到一个2,告诉跑的慢的说2符合要求。遇......
  • 代码整洁之道:边界、单元测试、类
    来源:博客园(作者-BNDong)边界边界上的代码需要清晰的分割和定义了期望的测试。应该避免我们的代码过多地了解第三方代码中的特定信息。依靠你能控制的东西,好过依靠你控制不了的东西,免得日后受它控制。单元测试TDD三定律在编写不能通过的单元测试前,不可编写生成代码......
  • 阅读笔记:《代码大全》
    当谈到软件开发的艺术和科学时,SteveMcConnell的《代码大全》是无可争议的经典之作。它是一本旨在为软件工程师和程序员提供深入洞察的指南,旨在帮助他们提升编程技能、编写高质量代码以及有效管理整个软件开发周期。这本书不仅提供了广泛的理论知识,还结合了大量实用的案例和建议,下......
  • Rocky Linux 9 x86_64 OVF (sysin) - VMware 虚拟机模板
    RockyLinux9x86_64OVF(sysin)-VMware虚拟机模板以社区方式驱动的企业Linux作者主页:sysin.orgRockyLinux9.3(5.14.0-362.8.1.el9_3.x86_64)RockyLinux9.0(5.14.0-70.13.1.el9_0.x86_64)以社区方式驱动的企业LinuxRockyLinux是一个开源的企业级操作系统,旨在与Red......
  • WPF基本布局代码
    <Windowx:Class="WpfApp2.MainWindow"xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d="http://schemas.microsoft.c......
  • Linux下netcore调用java代码
    代码备份,仅供参考自述文件#CSharpCallJavaC#invokeJavaviaC++asawraper.C#invokeC++viaP/invoke.C++startsaJVMtoruntheJavacode.C#codeshouldbecompiledin.NETcore2.0YoushouldedittheMakefiletosetthePathofJavaSDKexpor......
  • Git-统计每天特定时间区间代码提交次数-非上班时间代码提交
    git-code-specific-time-of-day.sh#!/bin/bashtotal_count=0#获取最早的提交日期first_commit_date=$(gitlog--pretty=format:'%ad'--date=format:'%Y-%m-%d'|sort|head-n1)#计算当前日期current_date=$(date+%Y-%m-%d)#遍历从最早提交日期到当前日期的所......
  • 在Python中,如果你想查找特定的SQLite数据库文件(例如'mydatabase.db'),你可以使用os模块
    这是Python中os.walk()函数的常见用法¹²⁴⁵⁶。os.walk()函数用于递归遍历指定目录及其子目录,并返回一个生成器,每次迭代都会返回一个包含三个元素的元组:当前目录的路径、当前目录下所有子目录的列表和当前目录下所有文件的列表¹²⁴⁵⁶。在fordirpath,dirnames,filenamesi......
  • Git-代码量行数提交次数统计脚本
    git-code-user-commit-stats.sh#!/bin/sh#请在unix终端或git-bash中运行此脚本printf"\n1.项目成员数量:";gitlog--pretty='%aN'|sort-u|wc-lprintf"\n\n2.按用户名统计代码提交次数:\n\n"printf"%10s%s\n""次数"&qu......
  • (实用)解决csdn登录后才能复制代码以及关注博主后才能浏览全文的问题
      在面向百度编程的过程中,我们总是会上网找答案,而CSDN就是其中一个答案聚居地,但是它总是会有下面的问题:问题一:问题二: 解决方法如下:问题一:解决登录后才能复制代码1.在需要复制的代码处,右键选择检查 2. 3.F12,在console中输入神秘代码:document.body.contentEditable......