我是靠谱客的博主 时尚月亮,最近开发中收集的这篇文章主要介绍胶囊神经网络资源转载和Pytorch实现胶囊神经网络资源转载和Pytorch实现,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

胶囊神经网络资源转载和Pytorch实现

  • 胶囊神经网络的简介, 介绍了CNN的缺点以及为啥么出现胶囊神经网络
  • 胶囊神经网络原理,用图像讲解
  • 胶囊神经网络原理,有代码实现,但有些错误
  • 胶囊神经网络动态路由算法
  • 胶囊神经网络Pytorch高赞实现, github

github那个代码是pytorch0.4之前才能用,所以这里附上博主pytorch1.6写的关键部分实现代码
首先是primary capsule layer

class PrimaryCapsuleLayer(nn.Module):
   
   def __init__(self, in_channels=256, out_channels=32, num_caps=8, kernel_size=9, stride=2):
       super().__init__()
       self.capsules = nn.ModuleList([
           nn.Conv2d(in_channels, out_channels, kernel_size, stride) for _ in range(num_caps)
       ])
       
   def _squash(self, x, dim=-1):
       # x_norm shape [B, C*H*W, 1]
       x_norm = torch.norm(x, p=2, dim=dim, keepdim=True)  # compute norm in the dim -1, namely across all capsules
       scale = x_norm**2 / (1 + x_norm**2)
       v = scale * x / x_norm
       return v
       
   def forward(self, x):
       # each capsule is a conv layer, outputs shape => list[[B, C*H*W, 1]]
       outputs = [capsule[x].reshape([x.shape[0], -1, 1]) for capsule in self.capsules]
       # shape: [B, C*H*W, num_caps]
       outputs = torch.cat(outputs, dim=-1)
       return self._squash(outputs)

然后是digit capsule layer

class DigitCapsuleLayer(nn.Module):
    
    def _squash(self, x, dim=-1):
        # x_norm shape [B, C*H*W, 1]
        x_norm = torch.norm(x, p=2, dim=dim, keepdim=True)  # compute norm in the dim -1, namely across all capsules
        scale = x_norm**2 / (1 + x_norm**2)
        v = scale * x / x_norm
        return v
    
    def __init__(self, num_caps, num_route_nodes, in_channels, out_channels, num_iterations=3):
        super().__init__()
        self.route_weights = nn.Parameter(torch.randn(num_caps, num_route_nodes, in_channels, out_channels))
        
    def forward(self, x):
        # shape [num_caps, B, C*H*W->num_route_nodes, 1, out_c]
        prior = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
        # shape [num_caps, B, num_route_nodes, 1, out_c]
        logits = torch.zeros(*prior.shape)
        
        for i in range(self.num_iterations):
            # shape [num_caps, B, C*H*W->num_route_nodes, 1, out_c], detach the node!
            u_hat = prior.detach()
            # shape [num_caps, B, 1, 1, out_c]
            probs = torch.softmax(logits, dim=2)
            # shape [num_caps, B, 1, 1, out_c]
            outputs = self._squash((u_hat * probs).sum(dim=2, keepdim=True))
            # shape [num_caps, B, num_route_nodes, 1, 1]
            delta_logits = (u_hat * outputs).sum(dim=-1, keepdim=True)
            # shape [num_caps, B, num_route_nodes, 1, out_c]
            logits += delta_logits
        # after iteration, we get the correct logits
        probs = torch.softmax(logits, dim=2)
        # shape [num_caps, B, 1, 1, out_c]
        outputs = self._squash((prior * probs).sum(dim=2, keepdim=True))
        return outputs

最后是整个胶囊网络整合,这个部分就和github那个差不多,这里就没全部写上去了

class CapsuleNetwork(nn.Module):
    
    def __init__(self):
        NUM_CLASSES = 10
        super().__init__()
        # basic conv layer to extract fature maps from mnist image
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        # primary capsule_layer
        self.primary_capsules = PrimaryCapsuleLayer(in_channels=256, out_channels=32,
                                                    num_caps=8, kernel_size=9, stride=2)
        # digit capsule
        self.digit_capsules = DigitCapsuleLayer(num_caps=NUM_CLASSES, num_route_nodes=32*6*6,
                                                in_channels=8, out_channels=16, num_iterations=3)
        
        # decoder to reconstruct the digit image
        self.decoder = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),  # reconstruction (28 * 28 = 784)
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        pass

最后

以上就是时尚月亮为你收集整理的胶囊神经网络资源转载和Pytorch实现胶囊神经网络资源转载和Pytorch实现的全部内容,希望文章能够帮你解决胶囊神经网络资源转载和Pytorch实现胶囊神经网络资源转载和Pytorch实现所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(57)

评论列表共有 0 条评论

立即
投稿
返回
顶部