In config:
parser.add_argument('--device', type=str, default='mps')
In main:
device = torch.device(cfg['device'])
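If the same script may also run on a machine without Apple silicon, a fallback can avoid a hard failure at this step. The following is only a sketch under that assumption; `resolve_device` is a hypothetical helper name, not part of the original code.

```python
import torch


def resolve_device(name: str) -> torch.device:
    """Return the requested device, falling back to CPU when it is unavailable."""
    # torch.backends.mps.is_available() reports whether the MPS backend can be used.
    if name.startswith("mps") and not torch.backends.mps.is_available():
        return torch.device("cpu")
    # Same idea for CUDA, which is not available on Apple-silicon Macs.
    if name.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(name)


# In main this would take the place of torch.device(cfg['device']).
device = resolve_device("mps")
```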
In train:
x_batch = x_batch.astype('float32')
y_batch = y_batch.astype('float32')
aux_batch = aux_batch.astype('float32')
x_batch = torch.from_numpy(x_batch).to(device)
aux_batch = torch.from_numpy(aux_batch).to(device)
y_batch = torch.from_numpy(y_batch).to(device)
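With these changes the code runs normally. The float32 cast matters because the MPS backend does not support float64, which is NumPy's default floating-point dtype.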
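The three cast-and-move steps in train can be folded into one function. This is a minimal sketch, not the original author's code: `to_float_tensor` is a hypothetical name, and the example uses the CPU device and random data only so that it runs on any machine.

```python
import numpy as np
import torch


def to_float_tensor(array: np.ndarray, device: torch.device) -> torch.Tensor:
    """Cast a NumPy array to float32 and move it to the given device."""
    # Cast on the NumPy side first, so the tensor is already float32
    # when it is transferred to the device.
    return torch.from_numpy(array.astype("float32")).to(device)


# Stand-in values for illustration; in train, device comes from main
# and x_batch from the data loader.
device = torch.device("cpu")
x_batch = np.random.rand(4, 3)  # float64, NumPy's default float dtype
x_tensor = to_float_tensor(x_batch, device)
```

The same call would then be applied to `y_batch` and `aux_batch`.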
Tags: chip, torch, batch, astype, mac, cuda, device, aux, numpy
From: https://www.cnblogs.com/xinxuann/p/17694222.html