概述
胶囊神经网络资源转载和Pytorch实现
- 胶囊神经网络的简介, 介绍了CNN的缺点以及为啥么出现胶囊神经网络
- 胶囊神经网络原理,用图像讲解
- 胶囊神经网络原理,有代码实现,但有些错误
- 胶囊神经网络动态路由算法
- 胶囊神经网络Pytorch高赞实现, github
github那个代码是pytorch0.4之前才能用,所以这里附上博主pytorch1.6写的关键部分实现代码:
首先是primary capsule layer
class PrimaryCapsuleLayer(nn.Module):
def __init__(self, in_channels=256, out_channels=32, num_caps=8, kernel_size=9, stride=2):
super().__init__()
self.capsules = nn.ModuleList([
nn.Conv2d(in_channels, out_channels, kernel_size, stride) for _ in range(num_caps)
])
def _squash(self, x, dim=-1):
# x_norm shape [B, C*H*W, 1]
x_norm = torch.norm(x, p=2, dim=dim, keepdim=True) # compute norm in the dim -1, namely across all capsules
scale = x_norm**2 / (1 + x_norm**2)
v = scale * x / x_norm
return v
def forward(self, x):
# each capsule is a conv layer, outputs shape => list[[B, C*H*W, 1]]
outputs = [capsule[x].reshape([x.shape[0], -1, 1]) for capsule in self.capsules]
# shape: [B, C*H*W, num_caps]
outputs = torch.cat(outputs, dim=-1)
return self._squash(outputs)
然后是digit capsule layer
class DigitCapsuleLayer(nn.Module):
def _squash(self, x, dim=-1):
# x_norm shape [B, C*H*W, 1]
x_norm = torch.norm(x, p=2, dim=dim, keepdim=True) # compute norm in the dim -1, namely across all capsules
scale = x_norm**2 / (1 + x_norm**2)
v = scale * x / x_norm
return v
def __init__(self, num_caps, num_route_nodes, in_channels, out_channels, num_iterations=3):
super().__init__()
self.route_weights = nn.Parameter(torch.randn(num_caps, num_route_nodes, in_channels, out_channels))
def forward(self, x):
# shape [num_caps, B, C*H*W->num_route_nodes, 1, out_c]
prior = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
# shape [num_caps, B, num_route_nodes, 1, out_c]
logits = torch.zeros(*prior.shape)
for i in range(self.num_iterations):
# shape [num_caps, B, C*H*W->num_route_nodes, 1, out_c], detach the node!
u_hat = prior.detach()
# shape [num_caps, B, 1, 1, out_c]
probs = torch.softmax(logits, dim=2)
# shape [num_caps, B, 1, 1, out_c]
outputs = self._squash((u_hat * probs).sum(dim=2, keepdim=True))
# shape [num_caps, B, num_route_nodes, 1, 1]
delta_logits = (u_hat * outputs).sum(dim=-1, keepdim=True)
# shape [num_caps, B, num_route_nodes, 1, out_c]
logits += delta_logits
# after iteration, we get the correct logits
probs = torch.softmax(logits, dim=2)
# shape [num_caps, B, 1, 1, out_c]
outputs = self._squash((prior * probs).sum(dim=2, keepdim=True))
return outputs
最后是整个胶囊网络整合,这个部分就和github那个差不多,这里就没全部写上去了
class CapsuleNetwork(nn.Module):
def __init__(self):
NUM_CLASSES = 10
super().__init__()
# basic conv layer to extract fature maps from mnist image
self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
# primary capsule_layer
self.primary_capsules = PrimaryCapsuleLayer(in_channels=256, out_channels=32,
num_caps=8, kernel_size=9, stride=2)
# digit capsule
self.digit_capsules = DigitCapsuleLayer(num_caps=NUM_CLASSES, num_route_nodes=32*6*6,
in_channels=8, out_channels=16, num_iterations=3)
# decoder to reconstruct the digit image
self.decoder = nn.Sequential(
nn.Linear(16 * NUM_CLASSES, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 784), # reconstruction (28 * 28 = 784)
nn.Sigmoid()
)
def forward(self, x, y=None):
pass
最后
以上就是时尚月亮为你收集整理的胶囊神经网络资源转载和Pytorch实现胶囊神经网络资源转载和Pytorch实现的全部内容,希望文章能够帮你解决胶囊神经网络资源转载和Pytorch实现胶囊神经网络资源转载和Pytorch实现所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复