首页 > 其他分享 >(11-3)基于深度学习的实时地图导航:计算交并比+训练模型

(11-3)基于深度学习的实时地图导航:计算交并比+训练模型

时间:2024-10-28 20:15:52浏览次数:6  
标签:11 pred self 实时 label experiment cfg 交并 thresholds

10.5.5  计算交并比

文件metrics.py定义了基于 PyTorch 的交并比(IoU)度量类和 IoU 度量的子类,用于计算预测与标签之间的交并比,并可以根据给定阈值和可见度遮罩进行计算。

class BaseIoUMetric(Metric):
    """
    计算给定阈值下的交并比
    """
    def __init__(self, thresholds=[0.4, 0.5]):
        super().__init__(dist_sync_on_step=False, compute_on_step=False)
 
        thresholds = torch.FloatTensor(thresholds)
 
        self.add_state('thresholds', default=thresholds, dist_reduce_fx='mean')
        self.add_state('tp', default=torch.zeros_like(thresholds), dist_reduce_fx='sum')
        self.add_state('fp', default=torch.zeros_like(thresholds), dist_reduce_fx='sum')
        self.add_state('fn', default=torch.zeros_like(thresholds), dist_reduce_fx='sum')
 
    def update(self, pred, label):
        pred = pred.detach().sigmoid().reshape(-1)
        label = label.detach().bool().reshape(-1)
 
        pred = pred[:, None] >= self.thresholds[None]
        label = label[:, None]
 
        self.tp += (pred & label).sum(0)
        self.fp += (pred & ~label).sum(0)
        self.fn += (~pred & label).sum(0)
 
    def compute(self):
        thresholds = self.thresholds.squeeze(0)
        ious = self.tp / (self.tp + self.fp + self.fn + 1e-7)
 
        return {f'@{t.item():.2f}': i.item() for t, i in zip(thresholds, ious)}
 
 
class IoUMetric(BaseIoUMetric):
    def __init__(self, label_indices: List[List[int]], min_visibility: Optional[int] = None):
        """
        label_indices:
            将标签转换为 (len(labels), h, w) 格式
            有关示例,请参阅 config/experiment/* 目录下的示例文件
 
        min_visibility:
            传递 "None" 将忽略可见度遮罩
            否则,使用可见度值来忽略某些标签
            可见度遮罩的顺序为 "逐渐可见" {1, 2, 3, 4, 255 (默认)}
        """
        super().__init__()
 
        self.label_indices = label_indices
        self.min_visibility = min_visibility
 
    def update(self, pred, batch):
        if isinstance(pred, dict):
            pred = pred['bev']                                                              # b c h w
 
        label = batch['bev']                                                                # b n h w
        label = [label[:, idx].max(1, keepdim=True).values for idx in self.label_indices]
        label = torch.cat(label, 1)                                                         # b c h w
 
        if self.min_visibility is not None:
            mask = batch['visibility'] >= self.min_visibility
            mask = mask[:, None].expand_as(pred)                                            # b c h w
 
            pred = pred[mask]                                                               # m
            label = label[mask]                                                             # m
 
        return super().update(pred, label)

10.6  训练模型

文件train.py是用于训练模型的主程序,通过读取配置文件并设置参数,创建模型、数据模块和可视化函数,然后执行训练过程。此文件支持从之前的检查点恢复训练,并设置了日志记录、模型保存、学习率监控等功能,使用了PyTorch Lightning库来管理训练过程。

CONFIG_PATH = Path.cwd() / 'config'  # 配置文件路径
CONFIG_NAME = 'config.yaml'  # 配置文件名称
 
def maybe_resume_training(experiment):
    # 可能恢复训练函数
    save_dir = Path(experiment.save_dir).resolve()  # 保存目录路径
    checkpoints = list(save_dir.glob(f'**/{experiment.uuid}/checkpoints/*.ckpt'))  # 检查点列表
 
    log.info(f'Searching {save_dir}.')  # 记录信息:正在搜索保存目录
 
    if not checkpoints:  # 如果没有检查点
        return None
 
    log.info(f'Found {checkpoints[-1]}.')  # 记录信息:找到最后一个检查点
 
    return checkpoints[-1]  # 返回最后一个检查点路径
 
 
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg):
    setup_config(cfg)  # 设置配置
 
    pl.seed_everything(cfg.experiment.seed, workers=True)  # 设定随机种子
 
    Path(cfg.experiment.save_dir).mkdir(exist_ok=True, parents=False)  # 创建保存目录
 
    # 创建和加载模型/数据
    model_module, data_module, viz_fn = setup_experiment(cfg)  # 设置实验
 
    # 可选地加载模型
    ckpt_path = maybe_resume_training(cfg.experiment)  # 可能恢复训练
 
    if ckpt_path is not None:
        model_module.backbone = load_backbone(ckpt_path)  # 加载模型骨干
 
    # 记录器和回调
    logger = pl.loggers.WandbLogger(project=cfg.experiment.project,
                                    save_dir=cfg.experiment.save_dir,
                                    id=cfg.experiment.uuid)  # 记录器
 
    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),  # 学习率监控器
        ModelCheckpoint(filename='model',
                        every_n_train_steps=cfg.experiment.checkpoint_interval),  # 模型保存回调
 
        VisualizationCallback(viz_fn, cfg.experiment.log_image_interval),  # 可视化回调
        GitDiffCallback(cfg)  # Git差异回调
    ]
 
    # 训练
    trainer = pl.Trainer(logger=logger,
                         callbacks=callbacks,
                         strategy=DDPStrategy(find_unused_parameters=False),
                         **cfg.trainer)  # 训练器
    trainer.fit(model_module, datamodule=data_module, ckpt_path=ckpt_path)  # 执行训练
 
 
if __name__ == '__main__':
    main()

标签:11,pred,self,实时,label,experiment,cfg,交并,thresholds
From: https://blog.csdn.net/asd343442/article/details/143312676

相关文章

  • YOLO11改进 | 卷积模块 | 在主干网络中添加蛇形卷积Dynamic Snake Convolution
    秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转......
  • YOLO11改进 | 卷积模块 | 轻量化卷积模块GSConv【附代码+小白可上手】
     秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转......
  • 量化交易中,如何获取实时市场行情数据?
    炒股自动化:申请官方API接口,散户也可以python炒股自动化(0),申请券商API接口python炒股自动化(1),量化交易接口区别Python炒股自动化(2):获取股票实时数据和历史数据Python炒股自动化(3):分析取回的实时数据和历史数据Python炒股自动化(4):通过接口向交易所发送订单Python炒股自动化(5):......
  • MaskGCT,AI语音克隆大模型本地部署(Windows11),基于Python3.11,TTS,文字转语音
    前几天,又一款非自回归的文字转语音的AI模型:MaskGCT,开放了源码,和同样非自回归的F5-TTS模型一样,MaskGCT模型也是基于10万小时数据集Emilia训练而来的,精通中英日韩法德6种语言的跨语种合成。数据集Emilia是全球最大且最为多样的高质量多语种语音数据集之一。本次分享一下如何在本地......
  • 如何在Windows 10/11中轻松实现PDF到Word的
    PDF到Word的转换在工作场所是常见需求。编辑Word文档比PDF更加方便,因为PDF是只读文件。如果你希望在与他人共享之前对文档进行一些修改,选择Word文档会更合适。本文将介绍如何在Windows10/11中将PDF转换为Word的可行方法。请继续阅读。第1部分:有关如何在W......
  • Springboot世界美食风情展示系统211wo(程序+源码+数据库+调试部署+开发环境)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表用户,美食类别,世界美食,美食攻略,美食订单开题报告内容一、项目背景与意义随着经济的快速发展和网络技术的进步,互联网已经深刻改变了人们的生活方式。电子商务......
  • 摄像机实时接入分析平台视频分析网关AI智能分析智慧营业厅方案
    一、方案背景随着社会对智能化服务需求的增长,传统营业厅面临着转型升级的压力。智慧营业厅方案通过引入视频监控和视频分析技术,旨在实现对营业厅的全面监控和管理,提高服务效率和客户满意度。该方案涵盖了监控布局与设备选型、实时画面采集与传输、人脸识别、行为分析与异常检......
  • 摄像机实时接入分析平台视频分析网关AI智能分析智慧营业厅方案
    一、方案背景随着社会对智能化服务需求的增长,传统营业厅面临着转型升级的压力。智慧营业厅方案通过引入视频监控和视频分析技术,旨在实现对营业厅的全面监控和管理,提高服务效率和客户满意度。该方案涵盖了监控布局与设备选型、实时画面采集与传输、人脸识别、行为分析与异常检测、......
  • 202411实践
    java#include<iostream>#include<vector>classArray{private:std::vector<std::vector<int>>matrix;intsize;public://构造函数Array(intn):size(n){matrix.resize(n,std::vector<int>(n,......
  • 安装wolfram11教程
    前言本案例仅供交流学习文件链接链接:https://pan.baidu.com/s/1GrCQ90nSoSkjP_a36TDDYw?pwd=llll提取码:llll安装解压后打开setup.exe进行安装一路安装,注意修改路径这里选择其他方式激活手动激活记录下MathID,这个ID要输入到下面html的第一个......