首页 > 其他分享 > bg_weight是0.1

bg_weight是0.1

时间:2022-12-29 13:34:26浏览次数:34  
标签:weight 0.1 joints self hms res np output


def __call__(self, joints, sgm, ct_sgm, bg_weight=1.0, num_body_joints=17):
    """Render per-joint Gaussian-window heatmaps and a per-pixel weight mask.

    For every visible, in-bounds joint a square window around it is filled
    with ``self.get_heat_val(sigma, sx, sy, x, y)`` and merged into that
    joint's heatmap channel by element-wise maximum.

    Args:
        joints: Array indexed as ``joints[person][joint] -> (x, y, vis)``.
            ``joints.shape[1]`` must equal ``self.num_joints_with_center``.
            Joints with ``vis`` not > 0, or with ``x``/``y`` outside
            ``[0, self.output_res)``, are skipped.
        sgm: Sigma used for joints whose index is < ``num_body_joints``.
        ct_sgm: Sigma used for the remaining joints (index >=
            ``num_body_joints``), presumably the center channel(s).
        bg_weight: Value written into the mask for pixels not covered by
            any rendered joint window.
        num_body_joints: Number of leading joint channels rendered with
            ``sgm``. Defaults to 17, the value that was previously
            hard-coded (NOTE(review): presumably COCO keypoints; confirm).

    Returns:
        A list ``[hms, ignored_hms]`` of two float32 arrays, each of shape
        ``(num_joints_with_center, output_res, output_res)``.
        ``hms`` holds the max-merged heat values. ``ignored_hms`` is 1.0
        inside every rendered joint window and ``bg_weight`` elsewhere.

    Raises:
        AssertionError: if ``joints.shape[1]`` does not equal
            ``self.num_joints_with_center``.
    """
    num_channels = self.num_joints_with_center
    if joints.shape[1] != num_channels:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        # AssertionError is kept so callers see the same exception type as before.
        raise AssertionError('the number of joints should be %d' % num_channels)

    res = self.output_res
    hms = np.zeros((num_channels, res, res), dtype=np.float32)
    # 2 is a sentinel meaning "not covered by any joint window".
    # It is replaced by bg_weight once all joints have been rendered.
    ignored_hms = 2 * np.ones((num_channels, res, res), dtype=np.float32)

    for person in joints:
        for idx, pt in enumerate(person):
            # ``not > 0`` (rather than ``<= 0``) keeps NaN visibility skipped,
            # exactly as the original ``if pt[2] > 0`` did.
            if not pt[2] > 0:
                continue
            x, y = pt[0], pt[1]
            if x < 0 or y < 0 or x >= res or y >= res:
                continue
            sigma = sgm if idx < num_body_joints else ct_sgm

            # Window spanning roughly 3*sigma around the joint, clipped to the map.
            x0 = max(0, int(np.floor(x - 3 * sigma - 1)))
            x1 = min(int(np.ceil(x + 3 * sigma + 2)), res)
            y0 = max(0, int(np.floor(y - 3 * sigma - 1)))
            y1 = min(int(np.ceil(y + 3 * sigma + 2)), res)

            joint_rg = np.zeros((y1 - y0, x1 - x0))
            for sy in range(y0, y1):
                for sx in range(x0, x1):
                    joint_rg[sy - y0, sx - x0] = self.get_heat_val(sigma, sx, sy, x, y)

            # Overlapping people in the same channel keep the per-pixel maximum.
            hms[idx, y0:y1, x0:x1] = np.maximum(hms[idx, y0:y1, x0:x1], joint_rg)
            ignored_hms[idx, y0:y1, x0:x1] = 1.

    ignored_hms[ignored_hms == 2] = bg_weight

    return [hms, ignored_hms]

标签:weight,0.1,joints,self,hms,res,np,output
From: https://www.cnblogs.com/hahaah/p/17012310.html

相关文章