概述
文章目录
- 如何理解torch.max中的dim
- pytorch文档解释
如何理解torch.max中的dim
理解torch.max主要是一点,即dim的指定。
其实torch中的dim和numpy中的axis是一个东西,可以点击这个链接查看axis的详细解释,理解了axis后,dim就迎刃而解了。
pytorch文档解释
torch.max = max(...)
max(input) -> Tensor
Returns the maximum value of all elements in the ``input`` tensor.
.. warning::
This function produces deterministic (sub)gradients unlike ``max(dim=0)``
Args:
input (Tensor): the input tensor.
Example::
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763, 0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
.. function:: max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum
value of each row of the :attr:`input` tensor in the given dimension
:attr:`dim`. And ``indices`` is the index location of each maximum value found
(argmax).
.. warning::
``indices`` does not necessarily contain the first occurrence of each
maximal value found, unless it is unique.
The exact implementation details are device-specific.
Do not expect the same result when run on CPU and GPU in general.
For the same reason do not expect the gradients to be deterministic.
If ``keepdim`` is ``True``, the output tensors are of the same size
as ``input`` except in the dimension ``dim`` where they are of size 1.
Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting
in the output tensors having 1 fewer dimension than ``input``.
Args:
input (Tensor): the input tensor.
dim (int): the dimension to reduce.
keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
out (tuple, optional): the result tuple of two output tensors (max, max_indices)
Example::
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
.. function:: max(input, other, out=None) -> Tensor
Each element of the tensor ``input`` is compared with the corresponding
element of the tensor ``other`` and an element-wise maximum is taken.
The shapes of ``input`` and ``other`` don't need to match,
but they must be :ref:`broadcastable <broadcasting-semantics>`.
.. math::
text{out}_i = max(text{tensor}_i, text{other}_i)
.. note:: When the shapes do not match, the shape of the returned output tensor
follows the :ref:`broadcasting rules <broadcasting-semantics>`.
Args:
input (Tensor): the input tensor.
other (Tensor): the second input tensor
out (Tensor, optional): the output tensor.
Example::
>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416, 0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416, 0.2653, -0.1584])
最后
以上就是魁梧金鱼为你收集整理的torch.max用法(指定维度dim)的全部内容,希望文章能够帮你解决torch.max用法(指定维度dim)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复