概述
有关矩阵处理的函数实在是太多了,在这里写一下,方便以后回忆
torch.clamp
clamp(input,min,max,out=None)-> Tensor,将input中的元素限制在[min,max]范围内并返回一个Tensor
可以看一下栗子:
a = torch.randn(2,2)
print(a)
c = a.clamp(min = 0)
print(c)
输出。将所有小于min的数全部替换成min
tensor([[ 0.6887, -2.4910],
[-1.1766, -0.9142]])
tensor([[0.6887, 0.0000],
[0.0000, 0.0000]])
当把min替换成max后
c = a.clamp(max = 0)
看一下输出
tensor([[0.1676, 1.2737],
[0.8410, 0.2963]])
tensor([[0., 0.],
[0., 0.]])
torch.diag
取矩阵的对角元素,组成一个新的tensor,输出
b =a.diag()
输出:
tensor([[ 0.8293, -1.3060],
[-0.4743, 0.8271]])
tensor([0.8293, 0.8271])
torch.eye
torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
可以返回参数维度的单位矩阵
d = torch.eye(3)
print(d)
输出:
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
torch.max
按照维度取最大值,返回一个tuple(tensor,dtype)
vvv = torch.randn(2,2)
print(vvv)
vv = vvv.max(1)
print(vv)
等于1,是按照行取最大;等于0,按照列取最大,输出:
tensor([[ 0.5618, 0.7352],
[-0.0405, -1.2378]])
(tensor([ 0.7352, -0.0405]), tensor([1, 0]))
所以取max(1)[0]就是取元祖的第一个元素
torch.masked_fill_(mask,value)
在mask值为1的地方用value填充
print(d)
cost_im = d.masked_fill_(d, 3)
print(cost_im)
输出:
tensor([[1, 0],
[0, 1]], dtype=torch.uint8)
tensor([[3, 0],
[0, 3]], dtype=torch.uint8)
另外的矩阵调用
a = torch.randn(2,2)
print(a)
输出:
tensor([[-0.4149, -1.3434],
[ 0.2611, -0.4930]])
print(d)
cost_im = a.masked_fill_(d, 3)
print(cost_im)
将原矩阵为1的地方用3代替,其余用a中的对应元素填充,输出:
tensor([[1, 0],
[0, 1]], dtype=torch.uint8)
tensor([[ 3.0000, -1.3434],
[ 0.2611, 3.0000]])
np.argsort(d, axis=1)
argsort()函数的作用是将数组按照从小到大的顺序排序,并按照对应的索引值输出。
argsort()函数中,当axis=0时,按列排列;当axis=1时,按行排列。如果省略默认按行排列。
inds = np.argsort(d, axis=1)
print(inds)
[[11 5]
[25 11]]
[[1 0]
[1 0]]
np.where(condition,x,y)和np.where()
满足条件(condition),输出x,不满足输出y。
只有条件 (condition),没有x和y,则输出满足条件 (即非0) 元素的坐标 (等价于numpy.nonzero)。这里的坐标以tuple的形式给出,通常原数组有多少维,输出的tuple中就包含几个数组,分别对应符合条件元素的各维坐标。
eg:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
>>> np.where(a > 5)
(array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
array([2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2]),
array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]))
每个数都用三维表示,例如符合条件的6,7,8,坐标就分别是(0,2,0),(0,2,1),(0,2,2);可以竖着看,每一个坐标是一个点
矩阵定义
为什么这么简单的一直记不住!!np.array()!!!
最后
以上就是顺心战斗机为你收集整理的矩阵处理的全部内容,希望文章能够帮你解决矩阵处理所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复