10.5.5 计算交并比
文件metrics.py基于 PyTorch 定义了交并比(IoU)度量基类及其子类,用于计算预测与标签之间的交并比,并可根据给定阈值和可见度遮罩筛选参与计算的像素。
class BaseIoUMetric(Metric):
    """
    Intersection-over-union (IoU) at a fixed set of score thresholds.

    Accumulates true-positive / false-positive / false-negative counts per
    threshold across `update` calls; `compute` converts them into IoU values.
    """
    def __init__(self, thresholds=(0.4, 0.5)):
        # FIX: the default used to be a mutable list ([0.4, 0.5]); a tuple is
        # equivalent here and avoids the shared-mutable-default pitfall.
        super().__init__(dist_sync_on_step=False, compute_on_step=False)

        thresholds = torch.FloatTensor(thresholds)

        # 'thresholds' is identical on every rank, so the 'mean' reduction is
        # effectively a no-op; the count states are summed across ranks.
        self.add_state('thresholds', default=thresholds, dist_reduce_fx='mean')
        self.add_state('tp', default=torch.zeros_like(thresholds), dist_reduce_fx='sum')
        self.add_state('fp', default=torch.zeros_like(thresholds), dist_reduce_fx='sum')
        self.add_state('fn', default=torch.zeros_like(thresholds), dist_reduce_fx='sum')

    def update(self, pred, label):
        """
        Accumulate counts for one batch.

        pred: raw logits (any shape) — passed through sigmoid here.
        label: boolean-castable ground truth, same number of elements as pred.
        """
        pred = pred.detach().sigmoid().reshape(-1)
        label = label.detach().bool().reshape(-1)

        # Broadcast to (num_elements, num_thresholds) so every threshold is
        # evaluated in a single vectorized pass.
        pred = pred[:, None] >= self.thresholds[None]
        label = label[:, None]

        self.tp += (pred & label).sum(0)
        self.fp += (pred & ~label).sum(0)
        self.fn += (~pred & label).sum(0)

    def compute(self):
        """Return {'@<threshold>': iou} for each configured threshold."""
        thresholds = self.thresholds.squeeze(0)
        # Epsilon guards against 0/0 when a class never appears.
        ious = self.tp / (self.tp + self.fp + self.fn + 1e-7)

        return {f'@{t.item():.2f}': i.item() for t, i in zip(thresholds, ious)}
class IoUMetric(BaseIoUMetric):
    def __init__(self, label_indices: List[List[int]], min_visibility: Optional[int] = None):
        """
        Args:
            label_indices: groups of channel indices; each group is collapsed
                (channel-wise max) into a single output channel, yielding
                labels of shape (len(label_indices), h, w).
                See config/experiment/* for examples.
            min_visibility: if None, the visibility mask is ignored; otherwise
                pixels whose visibility value is below this threshold are
                excluded. Visibility values are ordered "increasingly
                visible": {1, 2, 3, 4, 255 (default)}.
        """
        super().__init__()

        self.label_indices = label_indices
        self.min_visibility = min_visibility

    def update(self, pred, batch):
        if isinstance(pred, dict):
            pred = pred['bev']                                      # b c h w

        # Collapse each group of raw label channels into one channel.
        grouped = [batch['bev'][:, idx].max(1, keepdim=True).values
                   for idx in self.label_indices]
        label = torch.cat(grouped, 1)                               # b c h w

        if self.min_visibility is not None:
            visible = batch['visibility'] >= self.min_visibility
            visible = visible[:, None].expand_as(pred)              # b c h w

            pred = pred[visible]                                    # m
            label = label[visible]                                  # m

        return super().update(pred, label)
10.6 训练模型
文件train.py是用于训练模型的主程序,通过读取配置文件并设置参数,创建模型、数据模块和可视化函数,然后执行训练过程。此文件支持从之前的检查点恢复训练,并设置了日志记录、模型保存、学习率监控等功能,使用了PyTorch Lightning库来管理训练过程。
CONFIG_PATH = Path.cwd() / 'config'  # directory holding the hydra config tree
CONFIG_NAME = 'config.yaml'  # name of the top-level config file
def maybe_resume_training(experiment):
    """
    Search the experiment's save directory for existing checkpoints and
    return the most recent one, or None if there is none (fresh run).

    experiment: config node providing `save_dir` and `uuid`.
    """
    save_dir = Path(experiment.save_dir).resolve()
    checkpoints = list(save_dir.glob(f'**/{experiment.uuid}/checkpoints/*.ckpt'))

    log.info(f'Searching {save_dir}.')

    if not checkpoints:
        return None

    # FIX: Path.glob yields entries in unspecified filesystem order, so
    # checkpoints[-1] was an arbitrary file. Sort by modification time so we
    # resume from the latest checkpoint.
    checkpoints.sort(key=lambda ckpt: ckpt.stat().st_mtime)

    log.info(f'Found {checkpoints[-1]}.')

    return checkpoints[-1]
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg):
    """
    Training entry point: builds model/data/visualization from the hydra
    config, optionally resumes from a previous checkpoint, and runs training
    with PyTorch Lightning under DDP.
    """
    setup_config(cfg)

    pl.seed_everything(cfg.experiment.seed, workers=True)  # reproducibility
    Path(cfg.experiment.save_dir).mkdir(exist_ok=True, parents=False)

    # Create and load model / data / visualization helpers.
    model_module, data_module, viz_fn = setup_experiment(cfg)

    # Optionally resume: reuse backbone weights from the latest checkpoint
    # of the same experiment uuid, if one exists.
    ckpt_path = maybe_resume_training(cfg.experiment)

    if ckpt_path is not None:
        model_module.backbone = load_backbone(ckpt_path)

    # Loggers and callbacks (wandb run id ties logging to the experiment uuid).
    logger = pl.loggers.WandbLogger(project=cfg.experiment.project,
                                    save_dir=cfg.experiment.save_dir,
                                    id=cfg.experiment.uuid)
    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),  # track LR per epoch
        ModelCheckpoint(filename='model',
                        every_n_train_steps=cfg.experiment.checkpoint_interval),  # periodic save
        VisualizationCallback(viz_fn, cfg.experiment.log_image_interval),  # log images
        GitDiffCallback(cfg)  # record repo diff for reproducibility
    ]

    # Train; ckpt_path also restores optimizer/trainer state when resuming.
    trainer = pl.Trainer(logger=logger,
                         callbacks=callbacks,
                         strategy=DDPStrategy(find_unused_parameters=False),
                         **cfg.trainer)
    trainer.fit(model_module, datamodule=data_module, ckpt_path=ckpt_path)
if __name__ == '__main__':
    main()  # script entry point (hydra parses CLI overrides)
标签:11,pred,self,实时,label,experiment,cfg,交并,thresholds
From: https://blog.csdn.net/asd343442/article/details/143312676