First version:
torch.max(input) → Tensor
Returns the maximum value of all elements in the input tensor.
>>> import torch
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
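A minimal sketch of the first version. It assumes a fixed input instead of torch.randn so the result is reproducible (the values are copied from the printed example above). It shows that reducing over all elements yields a 0-dim tensor, which .item() turns into a plain Python number.

```python
import torch

# Fixed values (copied from the printed example) instead of torch.randn
a = torch.tensor([[0.6763, 0.7445, -2.2369]])

m = torch.max(a)   # reduces over every element, no dim given
print(m.dim())     # 0: the result is a 0-dim (scalar) tensor
print(m.item())    # converts the 0-dim tensor to a Python float

# The method form a.max() computes the same value
assert torch.equal(m, a.max())
```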
Second version:
torch.max(input, dim, keepdim=False, *, out=None)
- Returns a namedtuple (values, indices), where values is the maximum value of each row of the input tensor in the given dimension dim.
- indices is the index location of each maximum value found (argmax).
- If keepdim is True, the output tensors are of the same size as input except in the dimension dim, where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.
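To make the namedtuple and keepdim behavior concrete, here is a small sketch with a hand-picked 2×3 tensor (the values are illustrative, not from the original post). It unpacks (values, indices), compares the output shapes with and without keepdim, and uses torch.gather to check that indices really point at the maxima.

```python
import torch

# Illustrative 2x3 input
a = torch.tensor([[1.0, 5.0, 2.0],
                  [7.0, 0.0, 3.0]])

# Reduce along dim=1 (across each row); the result is a namedtuple
result = torch.max(a, dim=1)
values, indices = result               # tuple unpacking works
print(result.values, result.indices)   # attribute access works too

# keepdim=False (default): dim 1 is squeezed away, giving shape (2,)
print(values.shape)

# keepdim=True: dim 1 is kept with size 1, giving shape (2, 1)
kept = torch.max(a, dim=1, keepdim=True)
print(kept.values.shape)

# indices locate the maxima: gathering at them recovers the values
assert torch.equal(a.gather(1, kept.indices), kept.values)
```

keepdim=True is convenient when the result is combined with the original tensor afterwards, since the kept size-1 dimension broadcasts against input.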
If there are multiple maximal values in a reduced row then the indices of the first maximal value are returned.
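A quick check of the tie-breaking rule, written as a sketch that runs on CPU with an illustrative tensor in which each row contains its maximum twice. Following the rule stated above, the expected indices are those of the first occurrence.

```python
import torch

# Each row holds its maximum twice: 3.0 at columns 1 and 2, 4.0 at columns 0 and 3
a = torch.tensor([[1.0, 3.0, 3.0, 0.0],
                  [4.0, 2.0, 2.0, 4.0]])

values, indices = torch.max(a, dim=1)
# Documented rule: the index of the first maximal value in each row is returned,
# so indices should hold 1 for row 0 and 0 for row 1
print(values, indices)
```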
Parameters
- input (Tensor) – the input tensor.
- dim (int) – the dimension to reduce.
- keepdim (bool) – whether the output tensor has dim retained or not. Default: False.
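The dim parameter selects which axis is collapsed. The sketch below, again using an illustrative 2×3 tensor, reduces along dim=0 (one maximum per column) and along dim=1 (one maximum per row), and shows that a negative dim counts from the last dimension, as elsewhere in PyTorch.

```python
import torch

# Illustrative input of shape (2, 3)
a = torch.tensor([[1.0, 5.0, 2.0],
                  [7.0, 0.0, 3.0]])

# dim=0 reduces across rows: one maximum per column, shape (3,)
col_max = torch.max(a, dim=0)
print(col_max.values, col_max.indices)

# dim=1 reduces across columns: one maximum per row, shape (2,)
row_max = torch.max(a, dim=1)

# Negative dims count from the end, so dim=-1 is the same as dim=1 here
last = torch.max(a, dim=-1)
assert torch.equal(last.values, row_max.values)
assert torch.equal(last.indices, row_max.indices)
```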
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
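Because torch.randn gives different numbers on every run, the sketch below types the printed 4×4 tensor in by hand so the call can be reproduced exactly. It also cross-checks the returned indices against torch.argmax, which should agree here because no row contains a tie.

```python
import torch

# The same numbers as the printed example, entered as a fixed tensor
a = torch.tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
                  [ 1.1949, -1.1127, -2.2379, -0.6702],
                  [ 1.5717, -0.9207,  0.1297, -1.8768],
                  [-0.6172,  1.0036, -0.6060, -0.2432]])

values, indices = torch.max(a, 1)
print(values)    # should match values=tensor([0.8475, 1.1949, 1.5717, 1.0036])
print(indices)   # should match indices=tensor([3, 0, 0, 1])

# Cross-check: the indices agree with torch.argmax along the same dim
assert torch.equal(indices, torch.argmax(a, dim=1))
```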
From: https://www.cnblogs.com/zjuhaohaoxuexi/p/16712536.html