DilatedConv、GroupConv,膨胀卷积、组卷积,源码:torch.nn.Conv2d
复制代码
1
2torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
当dilation > 1时,卷积核不紧凑
复制代码
1
2
3
4
5
6
7
8
9
10
11import torch import torch.nn as nn import torch.nn.functional as F a = torch.randn(7, 7) print(f"a: {a}") a[0:3, 0:3] # dilation=1 a[0:5:2, 0:5:2] # dilation=2 a[0:7:3, 0:7:3] # dilation=3 # a[0:dx2+1:d]
当group>1时,分组卷积,再进行合
复制代码
1
2
3
4
5
6
7
8
9# group convolution in_channel, out_channel = 2, 4 # kernel:2x4,8个卷积核 groups = 2 # 每组有2个卷积组,一个4个卷积核,输入和输出大小不变 sub_in_channel, sub_out_channel = 1, 2 # 把结果拼起来,通道融合并不充分,只需要在每个group内进行融合,最后拼接 # 再使用1x1卷积,进行通道融合 1x1 point-wise convolution
Convolution with DilatedConv and GroupConv:
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62import torch import torch.nn as nn import torch.nn.functional as F def matrix_multiplication_for_conv2d_final(input, kernel, bias=0, stride=1, padding=0, dilation=1, groups=1): if padding > 0: # 从里到外,width、height、channel、batch input = F.pad(input, (padding, padding, padding, padding, 0, 0, 0, 0)) bs, in_channel, input_h, input_w = input.shape # kernel一共4维,包含通道融合的功能 out_channel, _, kernel_h, kernel_w = kernel.shape assert out_channel%groups==0 and in_channel%groups==0, "groups必须要同时被输入通道和输出通道数整除!" input = input.reshape((bs, groups, in_channel//groups, input_h, input_w)) kernel = kernel.reshape((groups, out_channel//groups, in_channel//groups, kernel_h, kernel_w)) kernel_h = (kernel_h-1)*(dilation-1) + kernel_h # 例如k=3,d=2,new_k=2x1+3=5 kernel_w = (kernel_w-1)*(dilation-1) + kernel_w if bias is None: bias = torch.zeros(out_channel) # 向下取整floor, 直接pad到input,不用padding output_h = (input_h - kernel_h) // stride + 1 # 卷积输出的高度 output_w = (input_w - kernel_w) // stride + 1 # 卷积输出的宽度 output = torch.zeros((bs, groups, out_channel//groups, output_h, output_w)) # 初始化输出矩阵 for ind in range(bs): # 对batchsize进行遍历 for g in range(groups): # 对群组进行遍历 for oc in range(out_channel//groups): # 对分组后的输出通道进行遍历 for ic in range(in_channel//groups): # 对分组后的输入通道进行遍历 for i in range(0, input_h-kernel_h+1, stride): # 对高度维进行遍历,input_h已经包括padding for j in range(0, input_w-kernel_w+1, stride): # 对宽度度维进行遍历 region = input[ind, g, ic, i:i+kernel_h:dilation, j:j+kernel_w:dilation] # 点乘,并且赋值输出位置的元素 output[ind, g, oc, i//stride, j//stride] += torch.sum(region * kernel[g, oc, ic]) output[ind, g, oc] += bias[g*(out_channel//groups) + oc] output = output.reshape((bs, out_channel, output_h, output_w)) return output # 以下为验证和测试的代码,验证与函数PyTorch API结果是否一致 bs, in_channel, input_h, input_w = 2, 2, 5, 5 kernel_size = 3 out_channel = 4 groups, dilation, stride, padding = 2, 2, 2, 1 input = torch.randn((bs, in_channel, input_h, input_w)) kernel = torch.randn((out_channel, in_channel//groups, kernel_size, kernel_size)) bias = torch.randn(out_channel) # PyTorch的官方API pytorch_conv2d_api_output = F.conv2d(input, kernel, bias=bias, padding=padding, stride=stride, dilation=dilation, groups=groups) mm_conv2d_final_output = matrix_multiplication_for_conv2d_final(input, kernel, bias=bias, padding=padding, stride=stride, dilation=dilation, groups=groups) print(f"pytorch_conv2d_api_output: {pytorch_conv2d_api_output}") print(f"mm_conv2d_final_output: {mm_conv2d_final_output}") flag = torch.allclose(pytorch_conv2d_api_output, mm_conv2d_final_output) print(f"flag: {flag}")
最后
以上就是光亮大象最近收集整理的关于PyTorch笔记 - Convolution卷积运算的原理 (5)的全部内容,更多相关PyTorch笔记内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复