概述
张量Tensor
tensor是pytorch的一种特殊的数据格式,它表示多维数组概括了所有数学意义和计算机意义上的向量形式。
Rank/shape概念
- Rank:表示我们需要多少个索引来访问或引用张量数据结构中包含的元素,即代表维度数
- Shape(size):告诉我们每个轴的长度,即每个轴上有多少个数据
Rank=len(shape) Shape是很重要的东西,因为它包含了rank,size的所有东西,一般只会用它来分析
dtype
创建一个张量的时候必须声明类型,如果你有多个卡还需要声明device
code
# array转tensor
torch.tensor(data,dtype=torch.int32)
# tensor转python
torch.item()//获取单个值
torch.tolist()//转列表
torch.numpy()//转np
# torch 内置
torch.zeros(2,2)
torch.ones(2,2)
torch.rand(2,2)
张量的运算
reshape
你可以将一个张量的shape改变成任意形状,前提只要它们的乘积相同,-1表示让Pytorch自动计算最后一个位置。
squeeze
去掉所有维数为1的的维度,对不为1的维度没有影响,不需要指定维度
t = torch.tensor([[1,2,3]],dtype=torch.float32)
t.squeeze()
#tensor([1., 2., 3.])
unsqueeze
对数据维度进行扩充。给指定位置加上维数为一的维度,需要指定维度(一定是会增加一个1维度)
t = torch.tensor([[1,2,3]],dtype=torch.float32)
t.unsqueeze(dim=2).shape
# tensor([[[1.],[2.],[3.]]])
# torch.Size([1, 3, 1])
concat
- 沿着已存在的轴连接多个tensor,把对应维度X所代表的张量进行合并
- 所有的tensor大小一致,除了需要连接那个维度,tensor不能为空
# dim=1的时候相当于
# t[0][0] 合并t[0][0] t[1][0]合并t[1][0] ...
t = torch.tensor([[1,2,3],[1,2,3]],dtype=torch.float32)
t1 = torch.tensor([[1,2,3],[1,2,3]],dtype=torch.float32)
torch.cat((t,t1),dim=1).shape
# 合并后相当于其他维不同,指定维为数量相加和[2,6]
stack
- 增加新的维度连接多个tensor
- 会先将原始数据维度扩展一维(unsqueeze),然后再按照维度进行拼接,具体拼接操作同torch.cat类似
t = torch.tensor([[1,2,3],[1,2,3]],dtype=torch.float32)
t1 = torch.tensor([[1,2,3],[1,2,3]],dtype=torch.float32)
torch.stack((t,t1),dim=1).shape
# [2,2,3] 相当于([2,3],[2,3])=>([2,1,3],[2,1,3])=>([2,2,3])
张量广播
两个张量从尾部的维度开始进行比对,维度尺寸必须满足以下一个条件方可广播:
- 或者相等,
- 或者其中一个张量的维度尺寸为 1,
- 或者其中一个张量不存在这个维度。
# 示例1:相同形状的张量总是可广播的,因为总能满足以上规则。
x = torch.empty(5, 7, 3)
y = torch.empty(5, 7, 3)
# 示例2:不可广播( a 不满足第一条规则)。
a = torch.empty((0,))
b = torch.empty(2, 2)
# 示例3:m 和 n 可广播:
m = torch.empty(5, 3, 4, 1)
n = torch.empty( 3, 1, 1)
# 倒数第一个维度:两者的尺寸均为1
# 倒数第二个维度:n尺寸为1
# 倒数第三个维度:两者尺寸相同
# 倒数第四个维度:n该维度不存在
# 示例4:不可广播,因为倒数第三个维度:2 != 3
p = torch.empty(5, 2, 4, 1)
q = torch.empty( 3, 1, 1)
张量自身元素运算
max函数就是把指定的Xdim合并为1然后去除,合并元素为最大值,保留其他维不变
t.max(dim=1)
# torch.return_types.max(values=tensor([3., 3.]),indices=tensor([2, 2]))
t.argmax(dim=1) 常用
# tensor([2, 2])
t.sum(dim=1)
# tensor([6., 6.])
t.std()
# tensor(0.8944)
整理来源
李小伟:torch的广播机制(broadcast mechanism)zhuanlan.zhihu.com最后
以上就是精明曲奇为你收集整理的pytorch 指定卡1_【科研之路】pytorch 张量的全部内容,希望文章能够帮你解决pytorch 指定卡1_【科研之路】pytorch 张量所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复