import torch

# Inclusive bounds of the signed 8-bit integer range.
_INT8_MIN = -128
_INT8_MAX = 127


def _fill_int8_(param):
    """Overwrite ``param`` in place with random integers in [-128, 127].

    ``torch.randint_like`` treats ``high`` as exclusive, so ``_INT8_MAX + 1``
    is passed to make 127 reachable. The tensor keeps its own dtype (e.g.
    float32); only the values are integer-valued.
    """
    with torch.no_grad():
        param.copy_(torch.randint_like(param, low=_INT8_MIN, high=_INT8_MAX + 1))


def weight_init(m):
    """Initialize a module's parameters with random values from the int8 range.

    Meant to be passed to ``torch.nn.Module.apply`` so that it visits every
    submodule, e.g.::

        model.apply(weight_init)

    Per module type:

    - ``Conv2d`` / ``Conv3d``: weight and (if present) bias are filled with
      random integers in ``[-128, 127]``.
    - ``BatchNorm3d``: weight set to 1 and bias set to 0 (when present).
    - ``Linear``: weight filled with random integers in ``[-128, 127]``;
      bias (if present) set to 0.

    Modules of any other type are left unchanged. Parameter dtypes are not
    changed; no quantization happens here.

    Args:
        m: The ``torch.nn.Module`` to initialize in place.
    """
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Conv3d)):
        _fill_int8_(m.weight)
        if m.bias is not None:
            _fill_int8_(m.bias)
    elif isinstance(m, torch.nn.BatchNorm3d):
        # With affine=False, BatchNorm's weight and bias are None.
        if m.weight is not None:
            torch.nn.init.ones_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, torch.nn.Linear):
        _fill_int8_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
# Tags: initialization, weights, weight, kernel, model, torch, bias, data, size
# Source: https://www.cnblogs.com/LuckCoder/p/17247332.html