首页 > 编程语言 >(7-6)行为预测算法:基于Trajectron++模型的行为预测系统

(7-6)行为预测算法:基于Trajectron++模型的行为预测系统

时间:2024-03-28 21:31:50浏览次数:24  
标签:预测 Trajectron self agent dataset ++ dict key 数据

7.6  基于Trajectron++模型的行为预测

Trajectron++是一个用于多目标轨迹预测和规划的深度学习模型,旨在应对自动驾驶和机器人等领域中的挑战,其中多个移动目标需要被准确地预测其未来运动轨迹,以便做出智能决策。

7.6.1  Trajectron++模型的特点

Trajectron++模型的主要特点和功能如下所示。

  1. 多目标轨迹预测:Trajectron++ 的核心任务是预测多个移动目标的未来运动轨迹,这对于自动驾驶车辆、机器人等在复杂交通场景中的行为规划至关重要。
  2. 深度学习架构:Trajectron++模型采用深度学习技术,包括循环神经网络(RNN)和卷积神经网络(CNN),以便有效地处理时间序列和空间信息,从而更好地捕捉目标的运动模式。
  3. 多智能体建模:Trajectron++ 考虑了多个移动目标之间的相互作用和关系。这有助于更准确地预测每个目标的轨迹,因为它们的运动可能受到彼此的影响。
  4. 生成式模型:Trajectron++ 是一个生成式模型,它可以生成可能的未来轨迹的概率分布。这使得它能够更灵活地处理不确定性,对于智能决策非常重要。
  5. 实时性能:Trajectron++ 被设计成具有实时性能,以便在实际应用中能够及时地做出决策。

在实际应用中,Trajectron++模型被主要用在自动驾驶领域,能够通过对轨迹集合建模来捕捉不确定性,用于车辆和行人的轨迹预测,为实体感知和决策提供了强大的支持。

7.6.2  基于Trajectron++模型的行为预测系统

在本项目中,使用 PyTorch Lightning 和 Lyft 提供的 l5kit 工具包实现了一个灵活的数据加载器,支持多智能体训练。展示了配置和使用不同的数据集、rasterizer,并提供了对训练数据批次结构的详细解析,为进一步的模型训练和实验奠定了基础。

实例2-8Trajectron++行为预测系统codes/2/lyft-multi-agent.ipynb

本项目使用的是Lyft 公司提供的自动驾驶车辆运动预测数据集(Lyft Motion Prediction Autonomous Vehicles),这是一个开源数据集。该数据集的目标是帮助研究人员和开发者训练和评估自动驾驶车辆在城市环境中的运动预测能力。数据集Lyft Motion Prediction for Autonomous Vehicles的主要特点和内容如下所示。

  1. 场景和环境:数据集提供了在城市环境中采集的大量传感器数据,涵盖了各种驾驶场景,包括道路、交叉口、人行道等。这使得研究人员能够测试和优化自动驾驶系统在不同复杂环境中的性能。
  2. 传感器数据:数据集包含了来自各种传感器的信息,如激光雷达、摄像头、雷达等。这些传感器数据为车辆周围的环境提供了高分辨率的感知信息。
  3. 运动轨迹和预测:Lyft 数据集中包含了车辆的历史运动轨迹数据,并提供了对未来运动轨迹的预测。这使得研究人员可以训练和评估模型在预测其他车辆或行人行为时的准确性。
  4. 地图信息:数据集可能包括高精度地图信息,以帮助自动驾驶车辆更好地理解和导航城市环境。
  5. 用于研究的挑战性问题:Lyft 数据集通常包含一些挑战性的问题,以促使研究人员开发创新性的算法和模型。这有助于推动自动驾驶技术的发展。
  6. 使用工具包:Lyft 提供了与数据集一起使用的工具包,如 l5kit,以便更轻松地处理和分析数据。

通过使用 Lyft 的自动驾驶车辆运动预测数据集,研究人员可以进行各种实验和测试,以提高自动驾驶系统在复杂城市交通中的性能。请查阅 Lyft 公司的官方文档或数据集页面,以获取更详细和最新的信息。实例文件lyft-multi-agent.ipynb的具体实现流程如下所示。

(1)导入必要的库和模块,设置全局变量,以及检测当前是否在 Kaggle 环境中。其中,l5kit 是 Lyft 公司提供的用于处理自动驾驶车辆运动预测数据集的工具包,而代码中的变量和模块则为后续的数据处理和模型训练做准备。

import bisect
import os
from copy import deepcopy
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pytorch_lightning as pl
from l5kit.data import ChunkedDataset, LocalDataManager
from l5kit.dataset import AgentDataset
from l5kit.rasterization import StubRasterizer, build_rasterizer
from torch.utils.data import DataLoader, Dataset, Subset

is_kaggle = os.path.isdir("")

(2)定义一个名为 CONFIG_DATA 的字典,其中包含了有关模型、光栅化参数以及数据加载器的配置信息。具体而言,它包括了模型的架构(resnet34)、历史和未来运动轨迹的帧数和时间步长、光栅化的相关参数、训练和验证数据加载器的配置等。这个配置字典被用于指定模型和数据的相关参数,以便在后续的训练和评估中使用。

# 配置数据字典,包括模型、光栅化参数以及数据加载器的相关配置信息
CONFIG_DATA = {
    "format_version": 4,
    "model_params": {
        "model_architecture": "resnet34",  # 模型架构
        "history_num_frames": 10,  # 历史运动轨迹的帧数
        "history_step_size": 1,  # 历史轨迹的时间步长
        "history_delta_time": 0.1,  # 历史轨迹的时间步长
        "future_num_frames": 50,  # 未来运动轨迹的帧数
        "future_step_size": 1,  # 未来轨迹的时间步长
        "future_delta_time": 0.1,  # 未来轨迹的时间步长
    },
    "raster_params": {
        "raster_size": [256, 256],  # 光栅化图像大小
        "pixel_size": [0.5, 0.5],  # 像素大小
        "ego_center": [0.25, 0.5],  # 智能驾驶汽车中心相对位置
        "map_type": "py_semantic",  # 地图类型
        "satellite_map_key": "aerial_map/aerial_map.png",  # 卫星地图键
        "semantic_map_key": "semantic_map/semantic_map.pb",  # 语义地图键
        "dataset_meta_key": "meta.json",  # 数据集元数据键
        "filter_agents_threshold": 0.5,  # 过滤智能体的阈值
        "disable_traffic_light_faces": False,  # 是否禁用交通灯的人脸
    },
    "train_dataloader": {
        "key": "scenes/sample.zarr",  # 训练数据集键
        "batch_size": 24,  # 批量大小
        "shuffle": True,  # 是否打乱数据
        "num_workers": 0,  # 数据加载器的工作进程数
    },
    "val_dataloader": {
        "key": "scenes/validate.zarr",  # 验证数据集键
        "batch_size": 24,  # 批量大小
        "shuffle": False,  # 是否打乱数据
        "num_workers": 4,  # 数据加载器的工作进程数
    },
    "test_dataloader": {
        "key": "scenes/test.zarr",  # 测试数据集键
        "batch_size": 24,  # 批量大小
        "shuffle": False,  # 是否打乱数据
        "num_workers": 4,  # 数据加载器的工作进程数
    },
    "train_params": {
        "max_num_steps": 400,  # 最大训练步数
        "eval_every_n_steps": 50,  # 每隔多少步进行一次评估
    },
}

(3)创建一个名为 MultiAgentDataset 的 PyTorch 数据集类,用于组合两个不同的 AgentDataset 数据集以创建一个新的多智能体数据集。该数据集用于训练神经网络等模型,以预测多个智能体(例如车辆)的运动轨迹。

from typing import List, Dict, Any, Tuple

class MultiAgentDataset(Dataset):
    def __init__(
        self,
        rast_only_agent_dataset: AgentDataset,
        history_agent_dataset: AgentDataset,
        num_neighbors: int = 10,
    ):
        super().__init__()
        self.rast_only_agent_dataset = rast_only_agent_dataset  # 光栅信息数据集
        self.history_agent_dataset = history_agent_dataset  # 历史信息数据集
        self.num_neighbors = num_neighbors  # 其他智能体数量

    def __len__(self) -> int:
        return len(self.rast_only_agent_dataset)  # 返回数据集长度

    def get_others_dict(
        self, index: int, ego_dict: Dict[str, Any]
    ) -> Tuple[List[Dict[str, Any]], int]:
        agent_index = self.rast_only_agent_dataset.agents_indices[index]  # 获取智能体索引
        frame_index = bisect.bisect_right(
            self.rast_only_agent_dataset.cumulative_sizes_agents, agent_index
        )  # 查找所属帧索引
        frame_indices = self.rast_only_agent_dataset.get_frame_indices(frame_index)
        assert len(frame_indices) >= 1, frame_indices
        frame_indices = frame_indices[frame_indices != index]  # 剔除当前智能体索引

        others_dict = []
        # 当前帧中 AV 的质心在世界参考系中的坐标,单位为米
        for idx, agent in zip(
            frame_indices,
            Subset(self.history_agent_dataset, frame_indices),
        ):
            agent["dataset_idx"] = idx
            agent["dist_to_ego"] = np.linalg.norm(
                agent["centroid"] - ego_dict["centroid"], ord=2
            )  # 计算到当前智能体的距离
            # 在未来版本中,可以通过智能体和智能驾驶汽车的转换矩阵将历史位置转换为归一化版本
            # 并获得标准化的版本
            del agent["image"]
            others_dict.append(agent)

        others_dict = sorted(others_dict, key=itemgetter("dist_to_ego"))
        others_dict = others_dict[: self.num_neighbors]
        others_len = len(others_dict)

        # 必须填充,因为 torch 不支持不规则张量
        # https://github.com/pytorch/pytorch/issues/25032
        length_to_pad = self.num_neighbors - others_len
        pad_item = deepcopy(ego_dict)
        pad_item["dataset_idx"] = index
        pad_item["dist_to_ego"] = np.nan  # 设置为 nan 以防止误用
        del pad_item["image"]
        return (others_dict + [pad_item] * length_to_pad, others_len)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        rast_dict = self.rast_only_agent_dataset[index]
        ego_dict = self.history_agent_dataset[index]
        others_dict, others_len = self.get_others_dict(index, ego_dict)
        ego_dict["image"] = rast_dict["image"]
        return {
            "ego_dict": ego_dict,
            "others_dict": others_dict,
            "others_len": others_len,
        }

(4)定义一个 PyTorch Lightning 的数据模块 LyftAgentDataModule,用于管理 Lyft 自动驾驶车辆运动预测数据集的加载和处理。通过配置信息,创建了训练、验证和测试数据加载器,实现了数据的统一管理和准备,方便在 PyTorch Lightning 中进行模型训练和评估工作。

class LyftAgentDataModule(pl.LightningDataModule):
    def __init__(self, cfg: Dict = CONFIG_DATA, data_root: str = data_root):
        super().__init__()
        self.cfg = cfg
        self.dm = LocalDataManager(data_root)
        self.rast = build_rasterizer(self.cfg, self.dm)

    def chunked_dataset(self, key: str):
        dl_cfg = self.cfg[key]
        dataset_path = self.dm.require(dl_cfg["key"])
        zarr_dataset = ChunkedDataset(dataset_path)
        zarr_dataset.open()
        return zarr_dataset

    def get_dataloader_by_key(
        self, key: str, mask: Optional[np.ndarray] = None
    ) -> DataLoader:
        dl_cfg = self.cfg[key]
        zarr_dataset = self.chunked_dataset(key)
        agent_dataset = AgentDataset(
            self.cfg, zarr_dataset, self.rast, agents_mask=mask
        )
        return DataLoader(
            agent_dataset,
            shuffle=dl_cfg["shuffle"],
            batch_size=dl_cfg["batch_size"],
            num_workers=dl_cfg["num_workers"],
            pin_memory=True,
        )

    def train_dataloader(self):
        key = "train_dataloader"
        return self.get_dataloader_by_key(key)

    def val_dataloader(self):
        key = "val_dataloader"
        return self.get_dataloader_by_key(key)

    def test_dataloader(self):
        key = "test_dataloader"
        test_mask = np.load(f"{data_root}/scenes/mask.npz")["arr_0"]
        return self.get_dataloader_by_key(key, mask=test_mask)

上述代码的实现流程如下:

  1. 首先,创建类LyftAgentDataModule,继承自 PyTorch Lightning 的 类LightningDataModule,用于管理 Lyft 自动驾驶车辆运动预测数据集的加载和数据处理。
  2. 然后,在初始化函数 __init__ 中,配置信息 cfg 和数据根目录 data_root 被传递并存储。LocalDataManager 被用于管理本地数据路径,而 build_rasterizer 函数被用于创建光栅化器。
  3. 接着,chunked_dataset 函数用于获取指定键(key)对应的数据集。该函数根据配置信息中的键获取数据集路径,使用 ChunkedDataset 打开并返回。
  4. 接下来,get_dataloader_by_key 函数通过指定的键获取对应的数据加载器。首先,使用 chunked_dataset 函数获取数据集,然后使用 AgentDataset 类构建代理数据集,传入配置信息、数据集、光栅化器和智能体的掩码(如果有的话)。最后,使用 PyTorch 的 DataLoader 创建数据加载器,配置包括是否打乱数据、批量大小、工作进程数等。
  5. 最后,train_dataloader、val_dataloader 和 test_dataloader 函数分别用于获取训练、验证和测试数据加载器。这些函数调用了 get_dataloader_by_key,传入相应的键,同时在测试数据加载器中还传入了智能体的掩码。这样,整个 LyftAgentDataModule 类提供了训练、验证和测试数据加载器的统一接口,方便在 PyTorch Lightning 中进行训练和评估。

(5)定义一个名为 MultiAgentDataModule 的 PyTorch Lightning 数据模块,继承自 LyftAgentDataModule。通过对智能体数据集进行定制化配置,创建了一个用于多智能体训练的数据加载器。

from pprint import pprint
for item in datamodule.train_dataloader():
    pprint(item.keys())
    print('ego_dict keys')
    pprint(item['ego_dict'].keys())
    pprint(len(item['others_dict']))
    pprint(item['others_dict'][0].keys())
    pprint(item['others_len'])
    break

对上述代码的具体说明如下所示:

  1. 首先,定义类MultiAgentDataModule,继承自 LyftAgentDataModule。在初始化函数中调用父类的初始化,同时创建了一个用于调试的 StubRasterizer。
  2. 然后,通过 get_dataloader_by_key 函数获取训练数据加载器。该函数使用 AgentDataset 类构建了两个数据集:一个只包含光栅信息的智能体数据集和一个使用 StubRasterizer 的包含历史信息的智能体数据集。
  3. 接着,通过创建 MultiAgentDataset 实例,将上述两个数据集传递给 PyTorch 的 DataLoader,配置了是否打乱数据、批量大小、工作进程数等参数,以便用于模型的训练。
  4. 最后,通过创建 MultiAgentDataModule 实例,完成了整个数据模块的配置和准备,方便在 PyTorch Lightning 中进行多智能体训练。

(6)通过训练数据加载器获取了一个批次的数据,并使用函数 pprint打印输出了该批次数据的结构和内容信息。首先,展示了整个批次数据的键;然后,详细列出了 'ego_dict' 中的键和信息;接着,显示了 'others_dict' 列表的长度以及第一个元素的键和信息;最后展示了 'others_len' 的值,提供了对数据批次中智能驾驶汽车和其他智能体信息的详尽了解。

dict_keys(['ego_dict', 'others_dict', 'others_len'])
ego_dict keys
dict_keys(['image', 'target_positions', 'target_yaws', 'target_availabilities', 'history_positions', 'history_yaws', 'history_availabilities', 'world_to_image', 'raster_from_world', 'raster_from_agent', 'agent_from_world', 'world_from_agent', 'track_id', 'timestamp', 'centroid', 'yaw', 'extent'])
10
dict_keys(['target_positions', 'target_yaws', 'target_availabilities', 'history_positions', 'history_yaws', 'history_availabilities', 'world_to_image', 'raster_from_world', 'raster_from_agent', 'agent_from_world', 'world_from_agent', 'track_id', 'timestamp', 'centroid', 'yaw', 'extent', 'dataset_idx', 'dist_to_ego'])
tensor([10,  7,  7,  2,  5,  9,  4,  4, 10,  5,  6,  4,  5, 10,  6, 10, 10,  5,
         9,  9,  1, 10,  5,  3])

标签:预测,Trajectron,self,agent,dataset,++,dict,key,数据
From: https://blog.csdn.net/asd343442/article/details/137124878

相关文章

  • DBO优化GRNN回归预测(matlab代码)
    DBO-GRNN回归预测matlab代码蜣螂优化算法(DungBeetleOptimizer,DBO)是一种新型的群智能优化算法,在2022年底提出,主要是受蜣螂的的滚球、跳舞、觅食、偷窃和繁殖行为的启发。数据为Excel股票预测数据。数据集划分为训练集、验证集、测试集,比例为8:1:1模块化结构:代码按照功......
  • C++17 一些新特性的简单描述
    其实很多17的官方新特性早就被很多非官方的库支持,反复验证完善后被官方收录。1、std::optionalstd::optional<vector<int>>list={}/std::nullopt/{{}};不就是表示一个值存在与否是可选的吗注意下{{}}和nullopt的区别,笔者偶尔遇见过相关bug,毕竟通信行业,信息内容中空列表......
  • C++ 字符串完全指南:学习基础知识到掌握高级应用技巧
    C++字符串字符串用于存储文本。一个字符串变量包含由双引号括起来的一组字符:示例创建一个string类型的变量并为其赋值:stringgreeting="Hello";C++字符串连接字符串连接可以使用+运算符来实现,生成一个新的字符串。示例:stringfirstName="John";stringlastN......
  • 安装 Visual C++ 可再发行组件包的简单方法
    安装VisualC++RedistributablePackages的最佳方法安装对Wampserver(以及许多其他软件)至关重要的VC++可再发行组件的最简单、最简单、最不容易出错、最快的方法是使用一个程序,该程序通过单个可执行文件安装所需的所有内容。不,这不是乌托邦!它存在,它是名为VisualCppRedistA......
  • 设计算法判断一棵树是否为完全二叉树--c++
    【题目要求】设计算法判断一棵树是否为完全二叉树。【提示】根据完全二叉树的定义可知:1)如果一个结点有右孩子而没有左孩子,那么这棵树一定不是完全二叉树。2)如果一个结点有左孩子,而没有右孩子,那么按照层序遍历的结果,这个结点之后的所有结点都是叶子结点,这棵树才是完全二叉......
  • 关于C++的跨平台性
    0前言C++作为一种编译型语言,我们常常认为他是不能跨平台的。但是实际上c++就是为了跨平台而设计的。1大人,时代变了C/C++就是为了跨平台而设计的,那个时代的跨平台指的是:一次编写,到处编译。源代码写好了,我放到哪个平台都可以编译出可执行程序。因为早期各个系统都有各自的编......
  • C++重载操作符
    在C++中,重载操作符<和重载函数调用操作符()各自适用于不同的情况,它们的使用取决于你的具体需求。比较<和()重载操作符<排序和比较:当你需要定义一个类或结构体的对象如何进行排序或比较时,你会重载操作符<。这在使用标准库中的排序函数(如std::sort)、集合(如std::set......
  • C/C++ 语言中的 ​if...else if...else 语句
    C/C++语言中的​if...elseif...else语句1.`if`statement2.`if...else`statement3.`if...elseif...else`statementReferences1.ifstatementThesyntaxoftheifstatementis:if(condition){//bodyofifstatement}Thecodeins......
  • 19、C++的指针基础
    1、指针的基本概念(1)变量的地址变量是内存变量的简称,在C++中,每定义一个变量,系统就会给变量分配一块内存,内存是有地址的。C++用运算符&获取变量在内存中的起始地址。语法:&变量名(2)指针变量指针变量简称指针,它是一种特殊的变量,专用于存放变量在内存中的起始地址。语法:数据......
  • C++_基础内容复习-跟着代码学
    二进制文件读写ios_base::out 以写入方式打开文件。ios_base::binary 以二进制模式打开文件std::ofstreamofs(FILE_PATH,ios_base::app);//以追加的形式打开文件//写入学生数量intnumStudents=students.size();ofs.write(reinterpret_cast<constcha......