概述
DilatedConv、GroupConv,膨胀卷积、组卷积,源码:torch.nn.Conv2d
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
当dilation > 1时,卷积核不紧凑
import torch
import torch.nn as nn
import torch.nn.functional as F
a = torch.randn(7, 7)
print(f"a: {a}")
a[0:3, 0:3] # dilation=1
a[0:5:2, 0:5:2] # dilation=2
a[0:7:3, 0:7:3] # dilation=3
# a[0:dx2+1:d]
当group>1时,分组卷积,再进行合
# group convolution
in_channel, out_channel = 2, 4
# kernel:2x4,8个卷积核
groups = 2
# 每组有2个卷积组,一个4个卷积核,输入和输出大小不变
sub_in_channel, sub_out_channel = 1, 2
# 把结果拼起来,通道融合并不充分,只需要在每个group内进行融合,最后拼接
# 再使用1x1卷积,进行通道融合 1x1 point-wise convolution
Convolution with DilatedConv and GroupConv:
import torch
import torch.nn as nn
import torch.nn.functional as F
def matrix_multiplication_for_conv2d_final(input, kernel, bias=0, stride=1, padding=0, dilation=1, groups=1):
if padding > 0:
# 从里到外,width、height、channel、batch
input = F.pad(input, (padding, padding, padding, padding, 0, 0, 0, 0))
bs, in_channel, input_h, input_w = input.shape
# kernel一共4维,包含通道融合的功能
out_channel, _, kernel_h, kernel_w = kernel.shape
assert out_channel%groups==0 and in_channel%groups==0, "groups必须要同时被输入通道和输出通道数整除!"
input = input.reshape((bs, groups, in_channel//groups, input_h, input_w))
kernel = kernel.reshape((groups, out_channel//groups, in_channel//groups, kernel_h, kernel_w))
kernel_h = (kernel_h-1)*(dilation-1) + kernel_h # 例如k=3,d=2,new_k=2x1+3=5
kernel_w = (kernel_w-1)*(dilation-1) + kernel_w
if bias is None:
bias = torch.zeros(out_channel)
# 向下取整floor, 直接pad到input,不用padding
output_h = (input_h - kernel_h) // stride + 1 # 卷积输出的高度
output_w = (input_w - kernel_w) // stride + 1 # 卷积输出的宽度
output = torch.zeros((bs, groups, out_channel//groups, output_h, output_w)) # 初始化输出矩阵
for ind in range(bs): # 对batchsize进行遍历
for g in range(groups): # 对群组进行遍历
for oc in range(out_channel//groups): # 对分组后的输出通道进行遍历
for ic in range(in_channel//groups): # 对分组后的输入通道进行遍历
for i in range(0, input_h-kernel_h+1, stride): # 对高度维进行遍历,input_h已经包括padding
for j in range(0, input_w-kernel_w+1, stride): # 对宽度度维进行遍历
region = input[ind, g, ic, i:i+kernel_h:dilation, j:j+kernel_w:dilation]
# 点乘,并且赋值输出位置的元素
output[ind, g, oc, i//stride, j//stride] += torch.sum(region * kernel[g, oc, ic])
output[ind, g, oc] += bias[g*(out_channel//groups) + oc]
output = output.reshape((bs, out_channel, output_h, output_w))
return output
# 以下为验证和测试的代码,验证与函数PyTorch API结果是否一致
bs, in_channel, input_h, input_w = 2, 2, 5, 5
kernel_size = 3
out_channel = 4
groups, dilation, stride, padding = 2, 2, 2, 1
input = torch.randn((bs, in_channel, input_h, input_w))
kernel = torch.randn((out_channel, in_channel//groups, kernel_size, kernel_size))
bias = torch.randn(out_channel)
# PyTorch的官方API
pytorch_conv2d_api_output = F.conv2d(input, kernel, bias=bias, padding=padding,
stride=stride, dilation=dilation, groups=groups)
mm_conv2d_final_output = matrix_multiplication_for_conv2d_final(input, kernel, bias=bias, padding=padding,
stride=stride, dilation=dilation, groups=groups)
print(f"pytorch_conv2d_api_output: {pytorch_conv2d_api_output}")
print(f"mm_conv2d_final_output: {mm_conv2d_final_output}")
flag = torch.allclose(pytorch_conv2d_api_output, mm_conv2d_final_output)
print(f"flag: {flag}")
最后
以上就是光亮大象为你收集整理的PyTorch笔记 - Convolution卷积运算的原理 (5)的全部内容,希望文章能够帮你解决PyTorch笔记 - Convolution卷积运算的原理 (5)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复