导入库(用于深度学习)
import os import time from datetime import timedelta import json import yaml from tqdm import tqdm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader
绘图(适用于各种数据,以信号为例)
# Figure 1 holds two stacked panels: the waveform and the spectrum image.
plt.figure(1)

# Top panel: the signal waveform.
plt.subplot(2, 1, 1)
plt.plot(t, sig)

# Bottom panel: the spectrum; the transform result is complex-valued, so its
# absolute value is what gets displayed.
n_samples = len(t)
plt.subplot(2, 1, 2)
plt.imshow(np.abs(st_Res), origin='lower', extent=(0, n_samples, 0, n_samples // 2))

# Save under a minute-resolution timestamp, then show the figure.
stamp = time.strftime("%Y%m%d%H%M", time.localtime())
plt.savefig("./imgs/stockwell-asdvalve09-" + stamp)
plt.show()
深度学习训练模板(torch框架):
import os
import time
from datetime import timedelta
import json

import yaml
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class BaseTrainer(object):
    """Skeleton trainer for PyTorch projects.

    Subclasses provide the dataset, model and loss through the ``_build_*``
    hooks and fill in the ``train`` / ``evaluate`` stubs.
    """

    def __init__(self, configs=None):
        """Initialise the attributes the setup helpers rely on.

        Args:
            configs: Configuration object. ``__setup_model`` reads
                ``configs.train_conf.enable_amp`` and
                ``configs.optimizer_conf.learning_rate`` from it when
                ``is_train`` is true. Defaults to ``None``.
        """
        # Always a torch.device, on GPU and CPU alike.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.configs = configs
        self.model = None
        self.optimizer = None
        self.loss = None
        # Stays None unless AMP is switched on in __setup_model, so code that
        # tests `self.amp_scaler is not None` never hits a missing attribute.
        self.amp_scaler = None

    def _build_dataset(self, is_train):
        """Return the training (``is_train=True``) or validation dataset."""
        raise NotImplementedError("override _build_dataset() to return a dataset")

    def _build_model(self):
        """Return the ``nn.Module`` to train."""
        raise NotImplementedError("override _build_model() to return the model")

    def _build_loss(self):
        """Return the loss module."""
        raise NotImplementedError("override _build_loss() to return the loss module")

    def __setup_dataloader(self, is_train):
        """Create the validation loader and, when training, the train loader."""
        if is_train:
            self.train_dataset = self._build_dataset(is_train=True)
            self.train_loader = DataLoader(self.train_dataset, batch_size=64, shuffle=True, num_workers=0)
        # Validation data is prepared in both modes.
        self.valid_dataset = self._build_dataset(is_train=False)
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=64, shuffle=True, num_workers=0)

    def __setup_model(self, is_train):
        """Move model and loss to the device; when training, build the optimizer."""
        self.model = self._build_model()
        self.model.to(self.device)
        self.loss = self._build_loss().to(self.device)
        if is_train:
            if self.configs.train_conf.enable_amp:
                self.amp_scaler = torch.cuda.amp.GradScaler(init_scale=1024)
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=float(self.configs.optimizer_conf.learning_rate))

    def __train_epoch(self, epoch_id):
        """Run one training epoch (template stub)."""
        pass

    def train(self):
        """Run the full training loop (template stub)."""
        pass

    def evaluate(self):
        """Evaluate the model (template stub)."""
        pass

    def __test(self, resume_model=None):
        """Test the model, optionally from ``resume_model`` (template stub)."""
        pass
保存checkpoint、加载checkpoint(torch框架):
def __load_checkpoint(self, save_model_path, resume_model):
    """Restore model, optimizer and (optionally) AMP-scaler state.

    Args:
        save_model_path: Root directory under which checkpoints are stored,
            in a sub-directory named after ``self.configs.use_model``.
        resume_model: Checkpoint directory to resume from. When ``None``, the
            ``last_model`` directory is used if it holds both ``model.pth``
            and ``optimizer.pth``.

    Returns:
        Tuple ``(last_epoch, best_auc, best_pauc)``. ``(-1, 0, 0)`` when
        there is nothing to resume from; otherwise ``last_epoch`` is the
        stored ``last_epoch`` minus one.
    """
    last_epoch = -1
    best_auc, best_pauc = 0, 0
    last_model_dir = os.path.join(save_model_path, f'{self.configs.use_model}', 'last_model')
    has_last_model = (os.path.exists(os.path.join(last_model_dir, 'model.pth'))
                      and os.path.exists(os.path.join(last_model_dir, 'optimizer.pth')))
    if resume_model is None and not has_last_model:
        return last_epoch, best_auc, best_pauc

    # Fall back to the most recently saved checkpoint.
    if resume_model is None:
        resume_model = last_model_dir
    assert os.path.exists(os.path.join(resume_model, 'model.pth')), "模型参数文件不存在!"
    assert os.path.exists(os.path.join(resume_model, 'optimizer.pth')), "优化方法参数文件不存在!"

    # map_location='cpu' lets a checkpoint written on a GPU machine load on a
    # CPU-only host; load_state_dict copies the values into the live objects.
    state_dict = torch.load(os.path.join(resume_model, 'model.pth'), map_location='cpu')
    if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
        self.model.module.load_state_dict(state_dict)
    else:
        self.model.load_state_dict(state_dict)
    self.optimizer.load_state_dict(
        torch.load(os.path.join(resume_model, 'optimizer.pth'), map_location='cpu'))

    # Automatic-mixed-precision state; the attribute may be absent when AMP
    # was never set up, so read it defensively.
    amp_scaler = getattr(self, 'amp_scaler', None)
    scaler_path = os.path.join(resume_model, 'scaler.pth')
    if amp_scaler is not None and os.path.exists(scaler_path):
        amp_scaler.load_state_dict(torch.load(scaler_path, map_location='cpu'))

    with open(os.path.join(resume_model, 'model.state'), 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    last_epoch = json_data['last_epoch'] - 1
    best_auc = json_data['best_auc']
    best_pauc = json_data['best_pauc']
    self.logger.info('成功恢复模型参数和优化方法参数:{}'.format(resume_model))

    # Step the optimizer once before the scheduler, then replay one scheduler
    # step per batch of the epochs counted by last_epoch.
    self.optimizer.step()
    for _ in range(last_epoch * len(self.train_loader)):
        self.scheduler.step()
    return last_epoch, best_auc, best_pauc


def __save_checkpoint(self, save_model_path, epoch_id, best_auc=0., best_pauc=0., best_model=False):
    """Write model, optimizer, AMP-scaler state and metadata to disk.

    Args:
        save_model_path: Root directory for checkpoints.
        epoch_id: Epoch index stored as ``last_epoch`` and used in the
            ``epoch_{id}`` directory name.
        best_auc: Value stored under ``best_auc`` in ``model.state``.
        best_pauc: Value stored under ``best_pauc`` in ``model.state``.
        best_model: When true, save into ``best_model``; otherwise save into
            ``epoch_{epoch_id}``, mirror it to ``last_model`` and delete
            ``epoch_{epoch_id - 3}`` if it exists.
    """
    # Local import: the file's import list does not include shutil.
    import shutil

    if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
        state_dict = self.model.module.state_dict()
    else:
        state_dict = self.model.state_dict()

    model_root = os.path.join(save_model_path, f'{self.configs.use_model}')
    if best_model:
        model_path = os.path.join(model_root, 'best_model')
    else:
        model_path = os.path.join(model_root, 'epoch_{}'.format(epoch_id))
    os.makedirs(model_path, exist_ok=True)

    torch.save(self.optimizer.state_dict(), os.path.join(model_path, 'optimizer.pth'))
    torch.save(state_dict, os.path.join(model_path, 'model.pth'))
    # Automatic-mixed-precision state (attribute may be absent without AMP).
    amp_scaler = getattr(self, 'amp_scaler', None)
    if amp_scaler is not None:
        torch.save(amp_scaler.state_dict(), os.path.join(model_path, 'scaler.pth'))

    # __version__ is expected to be defined at module level elsewhere.
    with open(os.path.join(model_path, 'model.state'), 'w', encoding='utf-8') as f:
        data = {"last_epoch": epoch_id, "best_auc": best_auc, "best_pauc": best_pauc,
                "version": __version__}
        f.write(json.dumps(data))

    if not best_model:
        # Keep `last_model` as a copy of the newest epoch directory.
        last_model_path = os.path.join(model_root, 'last_model')
        shutil.rmtree(last_model_path, ignore_errors=True)
        shutil.copytree(model_path, last_model_path)
        # Remove the old checkpoint from three epochs back.
        old_model_path = os.path.join(model_root, 'epoch_{}'.format(epoch_id - 3))
        if os.path.exists(old_model_path):
            shutil.rmtree(old_model_path)
    self.logger.info('已保存模型:{}'.format(model_path))
加载预训练模型(torch):
def __load_pretrained(self, pretrained_model):
    """Load pretrained weights into ``self.model``, skipping mismatched ones.

    Args:
        pretrained_model: Path to a weights file, or a directory containing
            ``model.pth``. ``None`` makes the call a no-op.

    Parameters whose shape differs from the current model are dropped from
    the loaded state dict (with a warning). Parameters missing from the file
    are reported and left at their current values, since loading is done
    with ``strict=False``.
    """
    if pretrained_model is None:
        return
    if os.path.isdir(pretrained_model):
        pretrained_model = os.path.join(pretrained_model, 'model.pth')
    assert os.path.exists(pretrained_model), f"{pretrained_model} 模型不存在!"

    # Unwrap DistributedDataParallel once so the same module is used for both
    # reading the reference shapes and loading the weights.
    if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
        target = self.model.module
    else:
        target = self.model
    model_dict = target.state_dict()

    # map_location='cpu' lets GPU-saved weights load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    model_state_dict = torch.load(pretrained_model, map_location='cpu')

    # Drop entries whose shape does not match the current model.
    for name, weight in model_dict.items():
        if name in model_state_dict:
            if list(weight.shape) != list(model_state_dict[name].shape):
                self.logger.warning('{} not used, shape {} unmatched with {} in model.'.
                                    format(name, list(model_state_dict[name].shape), list(weight.shape)))
                model_state_dict.pop(name, None)
        else:
            self.logger.warning('Lack weight: {}'.format(name))

    target.load_state_dict(model_state_dict, strict=False)
    self.logger.info('成功加载预训练模型:{}'.format(pretrained_model))
标签:代码,state,自用,path,model,os,self,模板,dict From: https://www.cnblogs.com/zhaoke271828/p/17934924.html