概述
理解ConvTranspose2d
操作
文章目录
- 理解`ConvTranspose2d`操作
- 输入与输出
- 计算过程详解
ConvTranspose2d
是一种常用的可以对图像进行上采样的方法,可以用于扩大图像尺寸。其本质上也是一个卷积操作,目的是恢复对应的卷积参数下,卷积前的原始图像大小。文章ConvTranspose2d原理,深度网络如何进行上采样?搭配了动图介绍了其计算过程,比较直观。这里通过代码的方式对该模块的输入输出以及计算过程进行解释。 首先,文章分析了该模块的输入输出,然后通过自定义的ConvTranspose2d模块,解释了算法的计算过程。
输入与输出
对于输入通道为 ( N , C i n , H i n , W i n ) (N, C_{in}, H_{in}, W_{in}) (N,Cin,Hin,Win)的图像,模型的输出为 ( N , C o u t , H o u t , W o u t ) (N, C_{out}, H_{out}, W_{out}) (N,Cout,Hout,Wout),
其中:
H
o
u
t
=
(
H
i
n
−
1
)
×
stride
[
0
]
−
2
×
padding
[
0
]
+
dilation
[
0
]
×
(
kernel_size
[
0
]
−
1
)
+
output_padding
[
0
]
+
1
H_{out} = (H_{in} - 1) times text{stride}[0] - 2 times text{padding}[0] + text{dilation}[0] times (text{kernel_size}[0] - 1) + text{output_padding}[0] + 1
Hout=(Hin−1)×stride[0]−2×padding[0]+dilation[0]×(kernel_size[0]−1)+output_padding[0]+1
W o u t = ( W i n − 1 ) × stride [ 1 ] − 2 × padding [ 1 ] + dilation [ 1 ] × ( kernel_size [ 1 ] − 1 ) + output_padding [ 1 ] + 1 W_{out} = (W_{in} - 1) times text{stride}[1] - 2 times text{padding}[1] + text{dilation}[1] times (text{kernel_size}[1] - 1) + text{output_padding}[1] + 1 Wout=(Win−1)×stride[1]−2×padding[1]+dilation[1]×(kernel_size[1]−1)+output_padding[1]+1
这里,为了使得模型更加简洁,我们不关心dialation
和output_shadding
参数。因而可以写作
H
o
u
t
=
(
H
i
n
−
1
)
×
stride
[
0
]
−
2
×
padding
[
0
]
+
kernel_size
[
0
]
H_{out} = (H_{in} - 1) times text{stride}[0] - 2 times text{padding}[0] + text{kernel_size}[0]
Hout=(Hin−1)×stride[0]−2×padding[0]+kernel_size[0]
W o u t = ( W i n − 1 ) × stride [ 1 ] − 2 × padding [ 1 ] + kernel_size [ 1 ] W_{out} = (W_{in} - 1) times text{stride}[1] - 2 times text{padding}[1] + text{kernel_size}[1] Wout=(Win−1)×stride[1]−2×padding[1]+kernel_size[1]
可以联想一下,卷积操作的输入和输出关系
H
o
u
t
=
[
H
i
n
−
k
e
r
n
e
l
_
s
i
z
e
[
0
]
+
2
∗
p
a
d
d
l
i
n
g
[
0
]
s
t
r
i
d
e
[
0
]
]
+
1
H_{out} = left[frac{H_{in}-kernel_size[0] + 2*paddling[0]}{stride[0]}right] + 1
Hout=[stride[0]Hin−kernel_size[0]+2∗paddling[0]]+1
可以看出,二者的大小形状完全是可逆的关系,通过代码进行说明
import torch
from torch import nn
conv = nn.Conv2d(3, 5, 5, padding=1)
tconv = nn.ConvTranspose2d(5, 3, 5, padding=1)
input = torch.randn((1, 3, 7, 9))
output = conv(input)
print('输入的维度', input.shape) # (1, 3, 7, 9)
print('卷积后的维度', output.shape) # 进行变换 (1, 5, 5, 7)
tinput = tconv(output)
print('经过逆卷积后的维度', tinput.shape) #->(1, 3, 7, 9) 与原始图像input的维度相同
计算过程详解
ConvTranspose2d原理,深度网络如何进行上采样?搭配了动图解释计算过程,比较直观。这里通过自写的模块Mytranspose2d
来具体说明计算过程,并与标准模块的计算结果进行了对比,可以搭配着看,更好的理解。
模块定义如下:
class MyTranspose2d(nn.Module):
# 这里为了简单起见,kernel_size限制为int 类型,对应的kernel大小为(kernel_size, kernel_size)
# 模型只是为了说明前馈计算的流程,不考虑效率和易用性
def __init__(self, inchannel, outchannel, kernel_size, padding=0, stride=1):
super().__init__()
self.weight = nn.Parameter(torch.zeros((inchannel, outchannel, kernel_size, kernel_size)))
self.bias = nn.Parameter(torch.zeros((outchannel, )))
self.padding = padding
self.stride = stride
self.outchannel = outchannel
self.inchannel = inchannel
self.F = kernel_size
def forward(self, input):
# input的维度为 N, C, H, W
N, C, H, W = input.shape
assert(C == self.inchannel)
Co, Ho, Wo = self.outchannel, (H-1)*self.stride-2*self.padding + self.F, (W-1)*self.stride-2*self.padding + self.F
output = torch.zeros((N, Co, Ho, Wo))
# 对输入进行补0,方便后续进行卷积操作
padding_input = torch.zeros((N, C, (H-1)*self.stride-1+2*self.F-2*self.padding, (W-1)*self.stride-1+2*self.F-2*self.padding))
for i in range(N):
for c in range(C):
for j in range(H):
for k in range(W):
jp = self.F-1-self.padding + j * self.stride
kp = self.F-1-self.padding + k * self.stride
padding_input[i, c, jp, kp] = input[i, c, j, k]
for i in range(N):
for c in range(Co):
for j in range(Ho):
for k in range(Wo):
jf = j + self.F
kf = k + self.F
# 由于是卷积,所以这里要Flip一下
output[i, c, j, k] = (self.weight[:, c, :, :].flip([1, 2]) * padding_input[i, :, j:jf, k:kf]).sum() + self.bias[c]
return output
计算结果的比较
# 标准的数据输入
inchannel = 1
outchanel = 1
kernel_size = 5
padding = 2
stride = 2
x = torch.randn((2, inchannel, 2, 2)) # 输入
convt = nn.ConvTranspose2d(inchannel, outchanel, kernel_size, padding=padding, stride=stride)
convm = MyTranspose2d(inchannel, outchanel, kernel_size, padding=padding, stride=stride)
# 确保二者的权重参数一致
convm.weight.data = convt.weight.data.clone().detach()
convm.bias.data = convt.bias.data.clone().detach()
# 计算模型的输出
out1 = convt(x)
out2 = convm(x)
err = out1 - out2
print(err.abs().max()) # 2.98e-8
与标准库模型的计算结果最大误差为2.98e-8
,说明计算结果是正确的。
最后
以上就是坚定大神为你收集整理的理解ConvTranspose2d操作理解ConvTranspose2d操作的全部内容,希望文章能够帮你解决理解ConvTranspose2d操作理解ConvTranspose2d操作所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复