使用混合精度导致GNN相关模型训练时出现损失无法下降:
在一次GNN相关的项目中,由于模型训练速度过慢,楼主为了加速开启混合精度。第一天使用时并未出现异常;第二天再次使用,出现了损失函数不下降的问题。经检测,一段包含稀疏矩阵转换而且矩阵计算密集的函数与混合精度发生未知作用,导致该问题。博主关掉混合精度,问题解决了。有没有大佬解释一下。这段代码如下,用于计算图的拉普拉斯矩阵:
A = torch.tensor(coo_matrix( (numpy.ones(num_edges), (edge_index[0].detach().cpu().numpy(), edge_index[1].detach().cpu().numpy())), shape=(num_nodes, num_nodes)).toarray(), dtype=torch.float32) # (N, N) A += torch.eye(len(A), dtype=torch.float32) degree_v = torch.norm(A, dim=1) L = torch.eye(len(A), dtype=torch.float32) - torch.pow(torch.unsqueeze(degree_v, dim=1), -1) * A # (N, N) Lx = L @ x # (N, D)
标签:dtype,模型,torch,混合,num,GNN,精度 From: https://www.cnblogs.com/CEUIFS/p/17703137.html