矩阵乘法求导
pyotrch中只能是标量对矩阵求导,所以矩阵乘法结束后加个sum
\[L = sum(\bm{WX}) \]其中,\(\bm{W}\)和\(\bm{X}\)都是矩阵,那么
\[\frac{\partial L}{\partial\bm{W}}_{\cdot i}=\sum\bm{X}_{i\cdot} \]梯度和W的形状相同,梯度中每列都是相同的,只要是第i列,梯度值就是\(\bm{X}\)的第i行的和。
用公式不太好表示,我们用pytorch代码来描述一下:
>>> w = torch.arange(0,50,dtype=torch.float32, requires_grad=True)
>>> nw = w.view(10, 5)
>>> x = torch.arange(50,100, dtype=torch.float32, requires_grad=True)
>>> nx = x.view(5,10)
>>> loss = torch.matmul(nw, nx)
>>> sum_loss = loss.sum()
>>> sum_loss.backward()
>>> print(w.grad.view_as(nw),'\n', nx.detach().sum(dim=1))
tensor([[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.]])
tensor([545., 645., 745., 845., 945.])
这个推导过程也不是很复杂,可以自己举个例子试试。
标签:845,745,sum,945,矩阵,645,545,求导,乘法 From: https://www.cnblogs.com/wangbingbing/p/17481291.html