谁能向我解释一下列表
self.data
如何使用numpy索引数组
selec_idx
进行索引/过滤?
完整的课程可在
https://github.com/kaidic/LDAM-DRW/blob/master/imbalance_cifar.py
def gen_imbalanced_data(self, img_num_per_cls):
new_data = []
new_targets = []
targets_np = np.array(self.targets, dtype=np.int64)
classes = np.unique(targets_np)
# np.random.shuffle(classes)
self.num_per_cls_dict = dict()
for the_class, the_img_num in zip(classes, img_num_per_cls):
self.num_per_cls_dict[the_class] = the_img_num
idx = np.where(targets_np == the_class)[0]
np.random.shuffle(idx)
selec_idx = idx[:the_img_num]
new_data.append(self.data[selec_idx, ...])
new_targets.extend([the_class, ] * the_img_num)
new_data = np.vstack(new_data)
self.data = new_data
self.targets = new_targets
我在网上搜索,找不到任何适合这段代码的解释。我发现您可以使用索引列表对 np.array 进行索引,但反之则不然。
代码段
self.data[selec_idx, ...]
使用了 NumPy 数组索引,它也称为
花式索引
。 这使你可以使用另一个数组(在本例中为
selec_idx
)提取原始数组(
self.data
)的特定元素。 让我们来分解一下:
-
selec_idx
:这是一个 NumPy 数组,其中包含要从self.data
中选择的元素的索引。 这些索引是通过之前代码中针对每个类的切片和随机排序操作生成的。 -
self.data[selec_idx, ...]
: 这是花式索引的魔力所在。 它使用selec_idx
中的索引来从self.data
中选择相应的行。...
是一个切片符号,表示选择所有列。
举例说明
假设
self.data
是一个类似于以下内容的 NumPy 数组:
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]]
并且
selec_idx
是一个数组
[0, 2]
。
然后,
self.data[selec_idx, ...]
会产生以下结果:
[[1, 2, 3],
[7, 8, 9]]
这是因为
selec_idx
包含索引
0
和
2
,它们对应于
self.data
的第一行和第三行。
简而言之,此代码段使用
selec_idx
中的索引有效地过滤了
self.data
,仅保留
selec_idx
中指定的索引处存在的行。
该操作对于基于特定条件或类别从数据集中选择子集非常有用,如你提供的代码中的不平衡数据集生成所示。