概述
Pytorch中cat和stack的用法
- 详解cat和stack
- Pytorch学习说明
- ```torch.Cat()```
- ```torch.stack()```
详解cat和stack
本文为原创,仅供交流学习,转载请注明出处,谢谢!
Pytorch学习说明
在这里推荐两份文档:
1.Pytorch中文手册:这就是一本"新华字典"。
2.动手学习深度学习:进一步学习DeepLearning
最近在学习Pytorch,在CNN的部分遇到了torch.stack()
和torch.cat()
两个函数。在网上查阅了很多博客,才搞清楚这两个函数的作用。在这里稍微总结一下。
torch.Cat()
格式说明:
torch.cat(inputs, dimension=0) → Tensor
参数:
inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
dimension (int, optional) – 沿着此维连接张量序列。
这个函数还是比较好理解的。
我们举个栗子:
我们定义了张量A,张量B(tensor类型也叫张量)。
A=torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float)
print("A:",A)
B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float)
print("B:",B)
输出结果:
A: tensor([[1., 2., 3.],
[4., 5., 6.]])
B: tensor([[-1., -2., -3.],
[-4., -5., -6.],
[-7., -8., -9.]])
我们可以看到A的尺寸:(2,3), B的尺寸:(3,3)。这里有个小tip:dtype=torch.float,不然会报错。
1. dim=0的情况
接下来我们在dim=0上执行cat,代码如下:
print("dim=0:",torch.cat((A,B),dim=0))
输出结果:
dim=0: tensor([[ 1., 2., 3.],
[ 4., 5., 6.],
[-1., -2., -3.],
[-4., -5., -6.],
[-7., -8., -9.]])
其实就是将A,B两个张量垂直拼接,思考一下:如果两个张量的列不相同会怎么样,还能在dim=0上做cat操作吗?
我们将B的尺寸改为(3,4)结果如下:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-51-b2487a7f4fe4> in <module>()
4 print("B:",B)
5 print("*********************************")
----> 6 print("dim=0:",torch.cat((A,B),dim=0))
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:711
2. dim=1的情况
类似与dim=0,dim=1其实就是将两个张量水平拼接起来。同样的如果两个张量行号不相同,还是会报错。
A=torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float)
print("A:",A)
B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float)
print("B:",B)
print("*********************************")
print("dim=1:",torch.cat((A,B),dim=1))
A: tensor([[1., 2., 3.],
[4., 5., 6.]])
B: tensor([[-1., -2., -3.],
[-4., -5., -6.],
[-7., -8., -9.]])
*********************************
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-52-a0255954420e> in <module>()
4 print("B:",B)
5 print("*********************************")
----> 6 print("dim=1:",torch.cat((A,B),dim=1))
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 2 and 3 in dimension 0 at /pytorch/aten/src/TH/generic/THTensor.cpp:711
修改下A,B后,演示如下:
A=torch.tensor([[1,2,3],[4,5,6],[7,8,9]],dtype=torch.float)
print("A:",A)
B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float)
print("B:",B)
print("*********************************")
print("dim=1:",torch.cat((A,B),dim=1))
结果如下:
A: tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
B: tensor([[-1., -2., -3.],
[-4., -5., -6.],
[-7., -8., -9.]])
*********************************
dim=1: tensor([[ 1., 2., 3., -1., -2., -3.],
[ 4., 5., 6., -4., -5., -6.],
[ 7., 8., 9., -7., -8., -9.]])
我们可以看出输出的是A,B水平拼接的结果。
torch.stack()
格式说明:
torch.stack(sequence, dim=0)
沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
参数:
sqequence (Sequence) – 待连接的张量序列
dim (int) – 插入的维度。必须介于 0 与 待连接的张量序列数之间。
和cat一样我们,在这里举个例子说明:
定义了A,B,C三个张量,这里要求ABC都为相同形状。
A=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
B=10*A
C=100*A
print("A:",A)
print("B:",B)
print("C:",C)
结果如下:
A: tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
B: tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
C: tensor([[100, 200, 300],
[400, 500, 600],
[700, 800, 900]])
1. dim=0
d0=torch.stack((A,B,C),dim=0)
print("dim=0:",d0)
print("d0[0][0][0]:",d0[0][0][0])
print("d0[0][0]:",d0[0][0])
print("d0[0]:",d0[0])
结果如下:
dim=0: tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[ 10, 20, 30],
[ 40, 50, 60],
[ 70, 80, 90]],
[[100, 200, 300],
[400, 500, 600],
[700, 800, 900]]])
d0[0][0][0]: tensor(1)
d0[0][0]: tensor([1, 2, 3])
d0[0]: tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
我们输出了d1的第2,1,0维元素,我们可以清楚的看到当dim=0
时候,torch.stack
类似于
torch.cat
。我们可以看出,dim=0
的作用其实就是把所有的d[i]
拼接在了一起。简单的说,将A,B,C垂直拼接。
2. dim=1
d1=torch.stack((A,B,C),dim=1)
print("dim=1:",d1)
print("d1[0][0][0]:",d1[0][0][0])
print("d1[0][0]:",d1[0][0])
print("d1[0]:",d1[0])
结果如下:
dim=1: tensor([[[ 1, 2, 3],
[ 10, 20, 30],
[100, 200, 300]],
[[ 4, 5, 6],
[ 40, 50, 60],
[400, 500, 600]],
[[ 7, 8, 9],
[ 70, 80, 90],
[700, 800, 900]]])
d1[0][0][0]: tensor(1)
d1[0][0]: tensor([1, 2, 3])
d1[0]: tensor([[ 1, 2, 3],
[ 10, 20, 30],
[100, 200, 300]])
我们可以看出,dim=1
的作用其实就是把ABC的d[i][i]
拼接在了一起。也就是说将每个张量的相同位置的行向量拼接在一起。如A[1],B[1],C[1]=[4,5,6],[40,50,60],[400,500,600]
3. dim=2
d2=torch.stack((A,B,C),dim=2)
print("dim=2",d2)
print("d2[0][0][0]:",d2[0][0][0])
print("d2[0][0]:",d2[0][0])
print("d2[0]:",d2[0])
结果如下:
dim=2 tensor([[[ 1, 10, 100],
[ 2, 20, 200],
[ 3, 30, 300]],
[[ 4, 40, 400],
[ 5, 50, 500],
[ 6, 60, 600]],
[[ 7, 70, 700],
[ 8, 80, 800],
[ 9, 90, 900]]])
d2[0][0][0]: tensor(1)
d2[0][0]: tensor([ 1, 10, 100])
d2[0]: tensor([[ 1, 10, 100],
[ 2, 20, 200],
[ 3, 30, 300]])
我们可以看出,dim=2
的作用其实就是把ABC的d[i][i][i]
拼接在了一起。即将A,B,C三个张量的相同位置的元素拼接在一起。如A[1][1],B[1][1],C[1][1]=5,50,500
本人pytorch小白,完全为了交流学习。仅为个人理解,如果有错误地方请指出。欢迎一起交流学习。
最后
以上就是等待野狼为你收集整理的Pytorch中cat和stack的用法详解cat和stackPytorch学习说明的全部内容,希望文章能够帮你解决Pytorch中cat和stack的用法详解cat和stackPytorch学习说明所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复