首页 > 其他分享 >(14-3-02)基于Latent Diffusion Transformer的文生视频系统:数据集处理(2)加载并处理Taichi数据集+加载并处理UCF101数据集

(14-3-02)基于Latent Diffusion Transformer的文生视频系统:数据集处理(2)加载并处理Taichi数据集+加载并处理UCF101数据集

时间:2025-01-16 16:32:57浏览次数:3  
标签:处理 data self class num video 数据 frame 加载

6.4.3  加载并处理Taichi数据集

文件taichi_datasets.py实现了一个 Taichi 数据集类,用于加载和处理分帧存储的视频数据,特别是太极表演相关的帧序列。它包括从数据目录中读取视频帧、按时间进行帧采样、将帧数据转换为张量并应用数据增强等功能。代码通过 torch.utils.data.Dataset 和 DataLoader 提供了对视频序列数据的高效加载和预处理,适合用于深度学习模型的训练和验证。

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']

def is_image_file(filename):
    """检查文件是否为图像文件"""
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

class Taichi(data.Dataset):
    def __init__(self, configs, transform, temporal_sample=None, train=True):
        """
        初始化Taichi数据集类。
        
        参数:
            configs: 数据集的配置信息。
            transform: 数据转换流程。
            temporal_sample: 时间采样函数。
            train: 是否为训练模式。
        """
        self.configs = configs
        self.data_path = configs.data_path
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.frame_interval = self.configs.frame_interval
        self.data_all = self.load_video_frames(self.data_path)
        self.video_num = len(self.data_all)

    def __getitem__(self, index):
        """
        获取数据集中的一个样本。
        
        参数:
            index: 数据索引。
        返回:
            包含视频张量的字典。
        """
        vframes = self.data_all[index]
        total_frames = len(vframes)

        # 对视频帧进行采样
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
        select_video_frames = vframes[frame_indice[0]: frame_indice[-1] + 1: self.frame_interval]

        # 加载并转换视频帧
        video_frames = []
        for path in select_video_frames:
            image = Image.open(path).convert('RGB')
            video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
            video_frames.append(video_frame)
        video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
        video_clip = self.transform(video_clip)

        return {'video': video_clip, 'video_name': 1}

    def __len__(self):
        """返回数据集的大小"""
        return self.video_num

    def load_video_frames(self, dataroot):
        """
        加载视频帧数据。
        
        参数:
            dataroot: 数据根目录。
        返回:
            包含所有视频帧路径的列表。
        """
        data_all = []
        frame_list = os.walk(dataroot)
        for _, meta in enumerate(frame_list):
            root = meta[0]
            try:
                frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except:
                print(meta[0], meta[2])
            frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
            if len(frames) != 0:
                data_all.append(frames)
        return data_all

if __name__ == '__main__':
    import argparse
    import torchvision
    import video_transforms
    import torch.utils.data as data

    from torchvision import transforms
    from torchvision.utils import save_image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=4)
    parser.add_argument("--load_fron_ceph", type=bool, default=True)
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
    config = parser.parse_args()

    target_video_len = config.num_frames

    temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
    trans = transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.RandomHorizontalFlipVideo(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    taichi_dataset = Taichi(config, transform=trans, temporal_sample=temporal_sample)
    taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)

    for i, video_data in enumerate(taichi_dataloader):
        print(video_data['video'].shape)

6.4.4  加载并处理UCF101数据集

文件ucf101_image_datasets.py实现了一个数据加载器,用于加载 UCF101 数据集的帧数据和图像,主要用于视频相关的深度学习任务。它结合了视频帧的时间采样与图像加载功能,并支持对视频与图像的预处理(如裁剪、归一化等)。此外,它还能按类别对数据进行组织与处理,方便模型训练和验证。

class_labels_map = None
cls_sample_cnt = None


def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """
    给定起始帧和结束帧索引,采样指定数量的帧。
    参数:
        frames (tensor): 视频帧的张量,维度为 `视频帧数` x `通道数` x `高度` x `宽度`。
        start_idx (int): 起始帧的索引。
        end_idx (int): 结束帧的索引。
        num_samples (int): 需要采样的帧数。
    返回:
        frames (tensor): 时间采样后的视频帧张量,维度为 `采样帧数` x `通道数` x `高度` x `宽度`。
    """
    index = torch.linspace(start_idx, end_idx, num_samples)
    index = torch.clamp(index, 0, frames.shape[0] - 1).long()
    frames = torch.index_select(frames, 0, index)
    return frames


def get_filelist(file_path):
    """
    获取指定路径下的所有文件列表。
    """
    Filelist = []
    for home, dirs, files in os.walk(file_path):
        for filename in files:
            Filelist.append(os.path.join(home, filename))
    return Filelist


def load_annotation_data(data_file_path):
    """
    加载注释数据文件。
    """
    with open(data_file_path, 'r') as data_file:
        return json.load(data_file)


def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
    """
    获取类别标签映射和每个类别的样本计数。
    参数:
        num_class (int): 类别数。
        anno_pth (str): 类别映射文件的路径。
    返回:
        class_labels_map (dict): 类别名称到索引的映射。
        cls_sample_cnt (dict): 每个类别的样本计数。
    """
    global class_labels_map, cls_sample_cnt
    
    if class_labels_map is not None:
        return class_labels_map, cls_sample_cnt
    else:
        cls_sample_cnt = {}
        class_labels_map = load_annotation_data(anno_pth)
        for cls in class_labels_map:
            cls_sample_cnt[cls] = 0
        return class_labels_map, cls_sample_cnt


def load_annotations(ann_file, num_class, num_samples_per_cls):
    """
    加载数据集注释信息,并根据类别和样本限制进行筛选。
    参数:
        ann_file (str): 注释文件路径。
        num_class (int): 选择的类别数。
        num_samples_per_cls (int): 每个类别的最大样本数。
    返回:
        dataset (list): 包含视频路径和类别信息的数据集列表。
    """
    dataset = []
    class_to_idx, cls_sample_cnt = get_class_labels(num_class)
    with open(ann_file, 'r') as fin:
        for line in fin:
            line_split = line.strip().split('\t')
            sample = {}
            idx = 0
            frame_dir = line_split[idx]
            sample['video'] = frame_dir
            idx += 1
            label = [x for x in line_split[idx:]]
            assert label, f'注释缺少标签: {line}'
            assert len(label) == 1
            class_name = label[0]
            class_index = int(class_to_idx[class_name])
            if class_index < num_class:
                sample['label'] = class_index
                if cls_sample_cnt[class_name] < num_samples_per_cls:
                    dataset.append(sample)
                    cls_sample_cnt[class_name] += 1
    return dataset


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """
    在数据集目录中找到类别文件夹。
    返回:
        classes (list): 类别名称列表。
        class_to_idx (dict): 类别到索引的映射。
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"在 {directory} 找不到任何类别文件夹。")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


class DecordInit(object):
    """
    使用 Decord 初始化视频读取器。
    详情参考: https://github.com/dmlc/decord
    """

    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)
        
    def __call__(self, filename):
        """
        初始化视频读取器。
        参数:
            filename (str): 视频文件路径。
        返回:
            reader (VideoReader): 视频读取器对象。
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'sr={self.sr},'
                    f'num_threads={self.num_threads})')
        return repr_str


class UCF101Images(torch.utils.data.Dataset):
    """
    UCF101 数据集加载器,用于加载视频帧和图像。
    参数:
        target_video_len (int): 加载的视频帧数量。
        align_transform (callable): 对视频进行对齐处理。
        temporal_sample (callable): 对视频帧进行时间采样。
    """

    def __init__(self,
                 configs,
                 transform=None,
                 temporal_sample=None):
        self.configs = configs
        self.data_path = configs.data_path
        self.video_lists = get_filelist(configs.data_path)
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.v_decoder = DecordInit()
        self.classes, self.class_to_idx = find_classes(self.data_path)
        self.video_num = len(self.video_lists)

        self.frame_data_path = configs.frame_data_path
        self.video_frame_txt = configs.frame_data_txt
        self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
        random.shuffle(self.video_frame_files)
        self.use_image_num = configs.use_image_num
        self.image_tranform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        self.video_frame_num = len(self.video_frame_files)

    def __getitem__(self, index):
        """
        获取指定索引处的数据。
        """
        video_index = index % self.video_num
        path = self.video_lists[video_index]
        class_name = path.split('/')[-2]
        class_index = self.class_to_idx[class_name]
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
        video = vframes[frame_indice]
        video = self.transform(video)
        images = []
        image_names = []
        for i in range(self.use_image_num):
            while True:
                try:      
                    video_frame_path = self.video_frame_files[index+i]
                    image_class_name = video_frame_path.split('_')[1]
                    image_class_index = self.class_to_idx[image_class_name]
                    video_frame_path = os.path.join(self.frame_data_path, video_frame_path)
                    image = Image.open(video_frame_path).convert('RGB')
                    image = self.image_tranform(image).unsqueeze(0)
                    images.append(image)
                    image_names.append(str(image_class_index))
                    break
                except Exception as e:
                    index = random.randint(0, self.video_frame_num - self.use_image_num)
        images =  torch.cat(images, dim=0)
        assert len(images) == self.use_image_num
        assert len(image_names) == self.use_image_num
        image_names = '====='.join(image_names)
        video_cat = torch.cat([video, images], dim=0)
        return {'video': video_cat, 
                'video_name': class_index, 
                'image_name': image_names}

    def __len__(self):
        return self.video_frame_num


if __name__ == '__main__':
    import argparse
    import video_transforms
    import torch.utils.data as Data
    import torchvision.transforms as transforms
    
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=3)
    parser.add_argument("--data_path", type=str, default='./data/video/')
    parser.add_argument("--frame_data_path", type=str, default='./data/image')
    parser.add_argument("--frame_data_txt", type=str, default='./data/image.txt')
    parser.add_argument("--use_image_num", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=8)
    args = parser.parse_args()
    
    temporal_sample = video_transforms.TemporalSampling(video_sample_num=args.num_frames, frame_interval=args.frame_interval)
    transform = transforms.Compose([video_transforms.Resize((112, 112)),
                                    video_transforms.ToTensor()])
    data_set = UCF101Images(args, transform=transform, temporal_sample=temporal_sample)
    loader = Data.DataLoader(data_set, batch_size=args.batch_size)
    
    for idx, data in enumerate(loader):
        print(data['video'].size())

标签:处理,data,self,class,num,video,数据,frame,加载
From: https://blog.csdn.net/asd343442/article/details/145168536

相关文章

  • Redis动态热点数据缓存策略设计
    Redis动态热点数据缓存策略设计1.热点数据识别机制1.1计数器方式@ServicepublicclassHotDataCounter{@AutowiredprivateRedisTemplate<String,Object>redisTemplate;//访问计数publicvoidincrementCounter(Stringkey){Strin......
  • 大模型书籍推荐:Transformer自然语言处理: 构建语言应用,附409页pdf免费下载
    今天给大家推荐一本Transformer大模型书籍《Transformer自然语言处理:构建语言应用》Transformers已经被用来编写真实的新闻故事,改进谷歌搜索查询,甚至创造出讲笑话的聊天机器人。在本指南中,作者LewisTunstall、LeandrovonWerra和ThomasWolf(拥抱Transformers的创始......
  • Linux驱动开发:处理空指针错误,ERR_PTR、IS_ERR、PTR_ERR用法
    免责声明:本文内容摘自《Linux设备驱动开发》一书,作者为JohnMadieu,译者为袁鹏飞、刘寿永,由人民邮电出版社出版。本文仅为分享知识和讨论之用,非商业用途。书籍版权归原作者及出版社所有。本人及本博客不对因使用或误用本文内容而产生的任何后果负责。请读者尊重版权,合理使用内容。......
  • 如何在红旗系统安装PGSQL数据库
    红旗系统安装PGSQL教程一、下载pgsql源码二、创建pgsql用户三、创建pgsql目录四、解压源码五、配置构建环境六、编译和安装七、创建数据库目录八、初始化数据库集群九、启动数据库十、添加环境变量十一、连接数据库十二、创建数据库用户十三、外部连接工具访问设置一......
  • MYSQL数据类型
    数据类型结构化数据、例如关系型数据库半结构化数据、HTML、XML、JSON非结构化数据SQL(结构化查询语言)命令关系型数据库擅长处理结构化数据、可以通过结构化查询语言对数据进行CRUD(增删改查)DDL(数据定义语言):主要包含的命令有create、drop、a......
  • springboot环境下的rokectMQ多数据源实现
    业务原因,需要在一个项目中与多方MQ进行业务通信;步骤一,复制一份RocketMQProperties配置文件,避免与原来的冲突packagecom.heit.road.web.config;importorg.apache.rocketmq.common.topic.TopicValidator;importjava.util.HashMap;importjava.util.Map;publicclassMu......
  • pg数据库下 关于时间日期的取值
    --century世纪selectdate_part('century',now()::TIMESTAMP);--day天selectdate_part('day',now()::TIMESTAMP);--decade十年,即年份除以10selectdate_part('decade',now()::TIMESTAMP);--dow星期(星期天0,星期六6)selectdate_part('do......
  • elasticsearch之DSL查询结果处理
    搜索的结果可以按照用户指定的方式去处理或展示。排序分页搜索关键词高亮排序elasticsearch默认是根据相关度算分(_score)来排序,但是也支持自定义方式对搜索结果排序。可以排序字段类型有:keyword类型、数值类型、地理坐标类型、日期类型等。普通字段排序keyword、数值、日......
  • NLP意图识别数据集处理流程
    NLP意图识别数据集处理流程引言自然语言处理(NLP)技术近年来发展迅速,尤其是在对话系统和聊天机器人领域。意图识别作为其中的一个关键任务,旨在理解用户输入背后的意图,并据此作出适当的响应。为了训练高效的意图识别模型,我们需要一个精心准备的数据集。本博客将介绍处理NLP意......
  • Pandas数据重命名:列名与索引为标题
    目录一、引言二、Pandasrename方法简介三、列名重命名3.1使用字典进行列名重命名3.2使用函数进行列名重命名四、索引重命名4.1使用字典进行索引重命名4.2使用函数进行索引重命名五、同时重命名列名和索引六、原地修改与返回新对象七、处理MultiIndex(多级索引)......