6.4.3 Loading and Processing the Taichi Dataset
The file taichi_datasets.py implements a Taichi dataset class for loading and processing video data stored as individual frames, specifically frame sequences of Tai Chi performances. It reads video frames from the data directory, samples frames along the time axis, converts the frame data to tensors, and applies data augmentation. Built on torch.utils.data.Dataset and DataLoader, it provides efficient loading and preprocessing of video sequence data, suitable for training and validating deep learning models.
import os
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']

def is_image_file(filename):
    """Check whether a file is an image file."""
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
class Taichi(data.Dataset):
    def __init__(self, configs, transform, temporal_sample=None, train=True):
        """
        Initialize the Taichi dataset class.
        Args:
            configs: dataset configuration.
            transform: transform pipeline applied to each clip.
            temporal_sample: temporal sampling function.
            train: whether the dataset is in training mode.
        """
        self.configs = configs
        self.data_path = configs.data_path
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.frame_interval = self.configs.frame_interval
        self.data_all = self.load_video_frames(self.data_path)
        self.video_num = len(self.data_all)
    def __getitem__(self, index):
        """
        Fetch one sample from the dataset.
        Args:
            index: sample index.
        Returns:
            A dict containing the video tensor.
        """
        vframes = self.data_all[index]
        total_frames = len(vframes)
        # Sample a temporal window; temporal_sample returns a window of
        # length target_video_len * frame_interval when enough frames exist
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
        # Striding the full window by frame_interval yields exactly target_video_len frames
        select_video_frames = vframes[frame_indice[0]: frame_indice[-1] + 1: self.frame_interval]
        # Load each frame as a uint8 tensor of shape (1, H, W, C)
        video_frames = []
        for path in select_video_frames:
            image = Image.open(path).convert('RGB')
            video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
            video_frames.append(video_frame)
        # (T, H, W, C) -> (T, C, H, W), then apply the transform pipeline
        video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
        video_clip = self.transform(video_clip)
        # Taichi is unlabeled, so a placeholder video_name is returned
        return {'video': video_clip, 'video_name': 1}
    def __len__(self):
        """Return the number of videos in the dataset."""
        return self.video_num
    def load_video_frames(self, dataroot):
        """
        Collect the frame paths of every video under the data root.
        Args:
            dataroot: root directory of the dataset.
        Returns:
            A list with one entry per video, each a sorted list of frame paths.
        """
        data_all = []
        for root, _, files in os.walk(dataroot):
            try:
                # Sort frames by the trailing number in the file name
                frames = sorted(files, key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # Skip directories whose file names do not follow the pattern
                print(root, files)
                continue
            frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
            if len(frames) != 0:
                data_all.append(frames)
        return data_all
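load_video_frames assumes a frame-per-file layout where every leaf directory holds the frames of one video and the frame index is the last underscore-separated token of the file name. A hypothetical layout that satisfies this convention:

frames/train/
    taichi_clip_0001/
        frame_0.jpg
        frame_1.jpg
        ...
    taichi_clip_0002/
        frame_0.jpg
        ...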
if __name__ == '__main__':
    import argparse
    import video_transforms
    from torchvision import transforms

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=4)
    parser.add_argument("--load_from_ceph", type=bool, default=True)
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
    config = parser.parse_args()

    target_video_len = config.num_frames
    # The sampling window covers num_frames * frame_interval consecutive frames
    temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
    trans = transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.RandomHorizontalFlipVideo(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    taichi_dataset = Taichi(config, transform=trans, temporal_sample=temporal_sample)
    taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
    for i, video_data in enumerate(taichi_dataloader):
        print(video_data['video'].shape)
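The video_transforms module used above is local to the repository and not shown in this section. As a reference point, here is a minimal sketch of what TemporalRandomCrop could look like, assuming it takes the window size in frames (num_frames * frame_interval) and returns a (start, end) index pair; the assert in __getitem__ then verifies the window is large enough:

import random

class TemporalRandomCrop(object):
    """Pick a random temporal window of `size` consecutive frames."""
    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        # Random start so the window fits inside the video when possible
        rand_end = max(0, total_frames - self.size - 1)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index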
6.4.4 Loading and Processing the UCF101 Dataset
The file ucf101_image_datasets.py implements a data loader for the UCF101 dataset that serves both video frames and still images, mainly for video-related deep learning tasks. It combines temporal sampling of video frames with image loading, and supports preprocessing of both videos and images (cropping, normalization, and so on). It also organizes the data by class, which is convenient for model training and validation.
import os
import json
import random
from typing import Tuple, List, Dict

import decord
import numpy as np
import torch
import torch.utils.data
import torchvision
from PIL import Image
from torchvision import transforms

class_labels_map = None
cls_sample_cnt = None
def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """
    Given a start and an end frame index, sample num_samples frames.
    Args:
        frames (tensor): video frames, shape (num_frames, channels, height, width).
        start_idx (int): index of the first frame.
        end_idx (int): index of the last frame.
        num_samples (int): number of frames to sample.
    Returns:
        frames (tensor): temporally sampled frames, shape (num_samples, channels, height, width).
    """
    index = torch.linspace(start_idx, end_idx, num_samples)
    index = torch.clamp(index, 0, frames.shape[0] - 1).long()
    frames = torch.index_select(frames, 0, index)
    return frames
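A quick sanity check of the sampling behavior on a dummy clip of ten frames:

frames = torch.arange(10).view(10, 1, 1, 1).float()  # 10 dummy 1x1x1 frames
sampled = temporal_sampling(frames, start_idx=0, end_idx=9, num_samples=4)
print(sampled.flatten())  # tensor([0., 3., 6., 9.])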
def get_filelist(file_path):
    """Return a list of every file path under file_path."""
    file_list = []
    for home, dirs, files in os.walk(file_path):
        for filename in files:
            file_list.append(os.path.join(home, filename))
    return file_list
def load_annotation_data(data_file_path):
    """Load an annotation file stored as JSON."""
    with open(data_file_path, 'r') as data_file:
        return json.load(data_file)
def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
    """
    Get the class label mapping and a per-class sample counter.
    Args:
        num_class (int): number of classes.
        anno_pth (str): path to the class map file.
    Returns:
        class_labels_map (dict): mapping from class name to index.
        cls_sample_cnt (dict): per-class sample counter.
    """
    global class_labels_map, cls_sample_cnt
    # The module-level globals cache the mapping so the JSON is parsed only once
    if class_labels_map is not None:
        return class_labels_map, cls_sample_cnt
    else:
        cls_sample_cnt = {}
        class_labels_map = load_annotation_data(anno_pth)
        for cls in class_labels_map:
            cls_sample_cnt[cls] = 0
        return class_labels_map, cls_sample_cnt
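The class map file itself is not part of this listing. Judging from how class_to_idx[class_name] is converted with int() below, it is assumed to be a flat JSON object mapping class names to integer indices; a hypothetical way to generate one (the class names are just examples):

classmap = {"ApplyEyeMakeup": 0, "ApplyLipstick": 1, "Archery": 2}
with open('./k400_classmap.json', 'w') as f:
    json.dump(classmap, f)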
def load_annotations(ann_file, num_class, num_samples_per_cls):
    """
    Load dataset annotations, filtered by class count and per-class sample limit.
    Args:
        ann_file (str): path to the annotation file.
        num_class (int): number of classes to keep.
        num_samples_per_cls (int): maximum number of samples per class.
    Returns:
        dataset (list): list of samples with video path and class index.
    """
    dataset = []
    class_to_idx, cls_sample_cnt = get_class_labels(num_class)
    with open(ann_file, 'r') as fin:
        for line in fin:
            line_split = line.strip().split('\t')
            sample = {}
            idx = 0
            # First column: path of the video / frame directory
            frame_dir = line_split[idx]
            sample['video'] = frame_dir
            idx += 1
            # Remaining columns: exactly one class label is expected
            label = [x for x in line_split[idx:]]
            assert label, f'missing label in annotation line: {line}'
            assert len(label) == 1
            class_name = label[0]
            class_index = int(class_to_idx[class_name])
            # Keep the sample only while the per-class quota is not exhausted
            if class_index < num_class:
                sample['label'] = class_index
                if cls_sample_cnt[class_name] < num_samples_per_cls:
                    dataset.append(sample)
                    cls_sample_cnt[class_name] += 1
    return dataset
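Each annotation line is therefore expected to hold a video path and exactly one class label, tab-separated. Two hypothetical lines (the whitespace stands for a tab character):

/data/ucf101/ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi	ApplyEyeMakeup
/data/ucf101/Archery/v_Archery_g02_c03.avi	Archery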
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """
    Find the class folders in a dataset directory.
    Returns:
        classes (list): sorted class names.
        class_to_idx (dict): mapping from class name to index.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
class DecordInit(object):
    """
    Initialize a video reader with Decord.
    See https://github.com/dmlc/decord for details.
    """
    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """
        Create the video reader.
        Args:
            filename (str): path to the video file.
        Returns:
            reader (VideoReader): the Decord video reader object.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
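DecordInit is a thin wrapper around decord.VideoReader. A short usage sketch (the video path is hypothetical):

v_init = DecordInit()
reader = v_init('/path/to/video.avi')       # decord.VideoReader instance
print(len(reader))                          # total number of frames
batch = reader.get_batch([0, 8, 16])        # decord NDArray, shape (3, H, W, C)
frames = torch.from_numpy(batch.asnumpy())  # hand off to PyTorch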
class UCF101Images(torch.utils.data.Dataset):
    """
    UCF101 dataset loader that yields video frames plus extra still images.
    Args:
        configs: dataset configuration (paths, num_frames, use_image_num, ...).
        transform (callable): transform pipeline applied to the video clip.
        temporal_sample (callable): temporal sampling applied to the frames.
    """
    def __init__(self,
                 configs,
                 transform=None,
                 temporal_sample=None):
        self.configs = configs
        self.data_path = configs.data_path
        self.video_lists = get_filelist(configs.data_path)
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.v_decoder = DecordInit()
        self.classes, self.class_to_idx = find_classes(self.data_path)
        self.video_num = len(self.video_lists)
        # Standalone frame images are listed in a text file and shuffled once at init
        self.frame_data_path = configs.frame_data_path
        self.video_frame_txt = configs.frame_data_txt
        self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
        random.shuffle(self.video_frame_files)
        self.use_image_num = configs.use_image_num
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        self.video_frame_num = len(self.video_frame_files)
    def __getitem__(self, index):
        """Return the sample at the given index."""
        # Pick a video; the index wraps around the video list
        video_index = index % self.video_num
        path = self.video_lists[video_index]
        class_name = path.split('/')[-2]
        class_index = self.class_to_idx[class_name]

        # Decode the video and sample target_video_len frames from a random window
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
        video = vframes[frame_indice]
        video = self.transform(video)

        # Load use_image_num extra still images; on any failure, retry from a random offset
        images = []
        image_names = []
        for i in range(self.use_image_num):
            while True:
                try:
                    video_frame_path = self.video_frame_files[index + i]
                    # Frame files are named like v_<ClassName>_gXX_cXX_..., so the
                    # class name is the second underscore-separated token
                    image_class_name = video_frame_path.split('_')[1]
                    image_class_index = self.class_to_idx[image_class_name]
                    video_frame_path = os.path.join(self.frame_data_path, video_frame_path)
                    image = Image.open(video_frame_path).convert('RGB')
                    image = self.image_transform(image).unsqueeze(0)
                    images.append(image)
                    image_names.append(str(image_class_index))
                    break
                except Exception:
                    index = random.randint(0, self.video_frame_num - self.use_image_num)
        images = torch.cat(images, dim=0)
        assert len(images) == self.use_image_num
        assert len(image_names) == self.use_image_num
        image_names = '====='.join(image_names)

        # Concatenate the clip and the still images along the time dimension
        video_cat = torch.cat([video, images], dim=0)
        return {'video': video_cat,
                'video_name': class_index,
                'image_name': image_names}
    def __len__(self):
        return self.video_frame_num
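Note that __len__ is driven by the number of still images rather than videos. The 'video' entry of each sample stacks the clip and the extra images along the time dimension; with the defaults used below (num_frames=16, use_image_num=4), the shapes work out as:

# video:  (num_frames, 3, H, W)     -> (16, 3, H, W)
# images: (use_image_num, 3, H, W)  -> (4, 3, H, W)
# torch.cat([video, images], dim=0) -> (20, 3, H, W)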
if __name__ == '__main__':
    import argparse
    import video_transforms
    import torch.utils.data as Data
    import torchvision.transforms as transforms

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=3)
    parser.add_argument("--data_path", type=str, default='./data/video/')
    parser.add_argument("--frame_data_path", type=str, default='./data/image')
    parser.add_argument("--frame_data_txt", type=str, default='./data/image.txt')
    parser.add_argument("--use_image_num", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=8)
    args = parser.parse_args()

    temporal_sample = video_transforms.TemporalSampling(video_sample_num=args.num_frames, frame_interval=args.frame_interval)
    transform = transforms.Compose([video_transforms.Resize((112, 112)),
                                    video_transforms.ToTensor()])
    data_set = UCF101Images(args, transform=transform, temporal_sample=temporal_sample)
    loader = Data.DataLoader(data_set, batch_size=args.batch_size)
    for idx, batch in enumerate(loader):
        print(batch['video'].size())