- How to understand dim in torch.max
- What the PyTorch documentation says
torch.max = max(...)
max(input) -> Tensor
Returns the maximum value of all elements in the ``input`` tensor.
.. warning::
This function produces deterministic (sub)gradients unlike ``max(dim=0)``
input (Tensor): the input tensor.
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763, 0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
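To make the single-argument form concrete, here is a minimal sketch that reuses the values from the example above as a fixed tensor (the variable name `overall` is only for illustration). With no `dim` given, every element takes part in the reduction and the result is a 0-dimensional tensor.

```python
import torch

# Same values as the documentation example, but fixed instead of random.
a = torch.tensor([[0.6763, 0.7445, -2.2369]])

# No dim argument: the reduction runs over all elements of the tensor.
overall = torch.max(a)

print(overall.shape)   # torch.Size([])  (a 0-dim tensor, not a Python float)
print(overall.item())  # .item() converts the 0-dim tensor to a Python number
```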
.. function:: max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum
value of each row of the :attr:`input` tensor in the given dimension
:attr:`dim`. And ``indices`` is the index location of each maximum value found
(argmax).
.. warning::
``indices`` does not necessarily contain the first occurrence of each
maximal value found, unless it is unique.
The exact implementation details are device-specific.
Do not expect the same result when run on CPU and GPU in general.
For the same reason do not expect the gradients to be deterministic.
If ``keepdim`` is ``True``, the output tensors are of the same size
as ``input`` except in the dimension ``dim`` where they are of size 1.
Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting
in the output tensors having 1 fewer dimension than ``input``.
input (Tensor): the input tensor.
dim (int): the dimension to reduce.
keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
out (tuple, optional): the result tuple of two output tensors (max, max_indices)
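The effect of ``keepdim`` is easiest to see from the output shapes. The sketch below uses a small hand-written tensor `x` (an assumption for illustration, not part of the documentation) and also shows one common reason to set ``keepdim=True``: the result stays broadcastable against the input.

```python
import torch

x = torch.tensor([[1., 5., 3.],
                  [4., 2., 6.]])   # shape (2, 3)

# Default keepdim=False: the reduced dimension is squeezed away.
values, indices = torch.max(x, dim=1)
print(values.shape)        # torch.Size([2])

# keepdim=True: the reduced dimension is kept with size 1.
kept_values, kept_indices = torch.max(x, dim=1, keepdim=True)
print(kept_values.shape)   # torch.Size([2, 1])

# Because kept_values has shape (2, 1), it broadcasts against x (2, 3),
# so each row can be divided by its own maximum.
normalized = x / kept_values
```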
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
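A useful way to read ``dim`` is "the dimension that disappears": the comparison runs along ``dim``, and the output has the input's shape with that dimension removed. The following sketch illustrates this with hand-picked tensors (`x` and `t` are illustrative names, and the values were chosen without ties so the indices are unambiguous).

```python
import torch

x = torch.tensor([[1., 5., 3.],
                  [4., 2., 6.]])   # shape (2, 3)

# dim=0: compare down the rows, one result per column.
col = torch.max(x, dim=0)
# col.values  -> [4., 5., 6.]   col.indices -> [1, 0, 1]   shape (3,)

# dim=1: compare across the columns, one result per row.
row = torch.max(x, dim=1)
# row.values  -> [5., 6.]       row.indices -> [1, 2]      shape (2,)

# Negative dims count from the end, so dim=-1 is the same as dim=1 here.
last = torch.max(x, dim=-1)

# The same rule in 3-D: the output shape is the input shape minus `dim`.
t = torch.arange(24).reshape(2, 3, 4)
shapes = [tuple(torch.max(t, dim=d).values.shape) for d in range(3)]
print(shapes)   # [(3, 4), (2, 4), (2, 3)]
```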
.. function:: max(input, other, out=None) -> Tensor
Each element of the tensor ``input`` is compared with the corresponding
element of the tensor ``other`` and an element-wise maximum is taken.
The shapes of ``input`` and ``other`` don't need to match,
but they must be :ref:`broadcastable <broadcasting-semantics>`.
.. math::
\text{out}_i = \max(\text{tensor}_i, \text{other}_i)
.. note:: When the shapes do not match, the shape of the returned output tensor
follows the :ref:`broadcasting rules <broadcasting-semantics>`.
input (Tensor): the input tensor.
other (Tensor): the second input tensor.
out (Tensor, optional): the output tensor.
>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416, 0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416, 0.2653, -0.1584])
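The documentation example uses two tensors of the same shape. As a small sketch of the broadcasting case, the tensors below (illustrative values, not from the documentation) have shapes (3, 1) and (1, 2), so the element-wise maximum has shape (3, 2). Note that there is no ``dim`` in this form: passing a tensor as the second argument selects the element-wise overload.

```python
import torch

a = torch.tensor([[1.], [5.], [3.]])   # shape (3, 1)
b = torch.tensor([[2., 4.]])           # shape (1, 2)

# Both inputs are broadcast to (3, 2), then compared element by element.
out = torch.max(a, b)
print(out.shape)   # torch.Size([3, 2])
# out -> [[2., 4.],
#         [5., 5.],
#         [3., 4.]]

# Recent PyTorch versions also provide torch.maximum for this element-wise form.
same = torch.maximum(a, b)
```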