对于函数torch.max(tensor, dim, keepdim=False),需要厘清两个概念
一个是torch.max怎么比较的
实际上是取被比较元素对应位置的最大值
- 如果是在一个向量中比较,那每个元素都是【c】的形式,对应位置就是本身比较,宏观来讲就是取该向量最大元素。
- 如果在一个二维矩阵中比较,被比较元素就是每一个向量,对应位置就是向量的对应位置,宏观来讲就是取每一列的最大值。
- 如果是在一个三维矩阵中比较,被比较的就是各个二维矩阵,对应位置就是矩阵,宏观来讲就是取每个矩阵(i,j)位置的最大值。
dim表示和返回的indices
dim表示的是从外到里括号的维度,dim=0(从0计数)就是第1个(从1计数)括号内的个元素,dim=1就是(每一个)第2个括号内的元素,依此类推。
上面括号内的每一个表示从dim=1(如果有)那么要看有几个第二级(dim=1)括号
返回的indices矩阵值为对应位置是哪一个该维度的张量,数值表示返回的最大值张量各个位置的取哪一个元素该位置值。