Reproducing Dense Extreme Inception Network (PyTorch)

Overview

GitHub: https://github.com/xavysp/DexiNed/tree/master/DexiNed-Pytorch
Paper: https://arxiv.org/abs/1909.01955
Dataset: https://www.kaggle.com/xavysp/biped

Contents
Abstract
Model Architecture
Experimental Results
Code

Abstract


This paper presents a deep-learning-based edge detection algorithm inspired by HED (Holistically-Nested Edge Detection) and the Xception network. The method produces the kind of thin edge maps a human observer would draw, and it can be applied to any edge detection task without a long training or fine-tuning process.

The paper's main contribution is a robust CNN architecture for edge detection, DexiNed: Dense Extreme Inception Network for Edge Detection. The model is trained from scratch, without pretrained weights.

Model Architecture

(The architecture diagram from the paper is omitted here.)

Experimental Results

(The result figures from the paper are omitted here.)

Code
config.py

class Config(object):
    # dataset
    mean_pixel_values = [104.00699, 116.66877, 122.67892, 137.86]
    img_width = 400
    img_height = 400
    train_root = 'data/BIPED/edges/imgs/train/rgbr/real'
    # note: BIPEDDataset derives label paths by replacing 'imgs' with
    # 'edge_maps', so valid_root should also point at an imgs/ directory
    valid_root = 'data/BIPED/edges/edge_maps/train/rgbr/real'
    valid_output_dir = 'valid_temp'
    # hyper-parameters
    batch_size = 2
    num_workers = 0
    num_epochs = 25

    model_output = 'result'



dataset.py

from torch.utils.data import DataLoader, Dataset
import torch
import cv2 as cv
import numpy as np
import os


class BIPEDDataset(Dataset):
    def __init__(self, img_root, mode='train', config=None):
        self.img_root = img_root
        self.mode = mode
        self.imgList = os.listdir(img_root)
        self.config = config
        self.mean_bgr = (config.mean_pixel_values[0:3]
                         if len(config.mean_pixel_values) == 4
                         else config.mean_pixel_values)

    def __len__(self):
        return len(self.imgList)

    def __getitem__(self, idx):
        file_name = self.imgList[idx].split('.')[0]
        imgPath = os.path.join(self.img_root, self.imgList[idx])
        labelPath = imgPath.replace('imgs', 'edge_maps').replace('jpg', 'png')

        #load data
        image = cv.imread(imgPath, cv.IMREAD_COLOR)
        label = cv.imread(labelPath, cv.IMREAD_GRAYSCALE)
        image_shape = [image.shape[0], image.shape[1]]
        image, label = self.transform(img=image, gt=label)
        return dict(images=image, labels=label, file_name=file_name, image_shape=image_shape)

    def transform(self, img, gt):

        gt = np.array(gt, dtype=np.float32)
        if len(gt.shape) == 3:
            gt = gt[:, :, 0]
        # gt[gt< 51] = 0 # test without gt discrimination
        gt /= 255.
        # if self.yita is not None:
        #     gt[gt >= self.yita] = 1
        img = np.array(img, dtype=np.float32)
        # if self.rgb:
        #     img = img[:, :, ::-1]  # RGB->BGR
        img -= self.mean_bgr
        # data = []
        # if self.scale is not None:
        #     for scl in self.scale:
        #         img_scale = cv.resize(img, None, fx=scl, fy=scl, interpolation=cv.INTER_LINEAR)
        #         data.append(torch.from_numpy(img_scale.transpose((2, 0, 1))).float())
        #     return data, gt

        img = cv.resize(img, dsize=(self.config.img_width, self.config.img_height))
        gt = cv.resize(gt, dsize=(self.config.img_width, self.config.img_height))
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img.copy()).float()
        gt = torch.from_numpy(np.array([gt])).float()
        return img, gt


if __name__=='__main__':
    from config import Config
    cfg = Config()
    root = 'data/BIPED/edges/imgs/train/rgbr/real'
    train_dataset = BIPEDDataset(root, config=cfg)
    train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0)
    for data_batch in train_loader:
        img, label = data_batch['images'], data_batch['labels']
        print(img.size(), label.size(),  data_batch['file_name'])
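
To sanity-check that images and edge maps stay aligned after the resize, a quick visual round trip helps. The snippet below is a minimal sketch of my own (not part of the original post); it assumes the Config and BIPEDDataset above and simply undoes the mean subtraction for display.

# sanity_check.py -- an illustrative sketch, not from the original post
import cv2 as cv
import numpy as np
from config import Config
from dataset import BIPEDDataset

cfg = Config()
ds = BIPEDDataset(cfg.train_root, config=cfg)
sample = ds[0]

# undo the preprocessing: CxHxW float -> HxWxC uint8 (BGR, mean added back)
img = sample['images'].numpy().transpose((1, 2, 0)) + cfg.mean_pixel_values[0:3]
img = np.clip(img, 0, 255).astype(np.uint8)
gt = (sample['labels'].numpy()[0] * 255).astype(np.uint8)

cv.imwrite('check_img.png', img)  # should look like the source photo
cv.imwrite('check_gt.png', gt)    # should be the matching edge map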



losses.py

import torch
import torch.nn.functional as F

def _weighted_cross_entropy_loss(preds, edges):
    """ Calculate sum of weighted cross entropy loss.

    Note: the boolean-indexing assignments below only broadcast when the
    batch size is 1; weighted_cross_entropy_loss below is the general form.
    """
    # Reference:
    #   hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    #   https://github.com/s9xie/hed/issues/7
    mask = (edges > 0.5).float()
    b, c, h, w = mask.shape
    num_pos = torch.sum(mask, dim=[1, 2, 3]).float()  # Shape: [b,].
    num_neg = c * h * w - num_pos                     # Shape: [b,].
    weight = torch.zeros_like(mask)
    weight[edges > 0.5]  = num_neg / (num_pos + num_neg)
    weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
    # Calculate loss.
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    loss = torch.sum(losses) / b
    return loss

def weighted_cross_entropy_loss(preds, edges):
    """ Calculate sum of weighted cross entropy loss. """
    # Reference:
    #   hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    #   https://github.com/s9xie/hed/issues/7
    mask = (edges > 0.5).float()
    b, c, h, w = mask.shape
    num_pos = torch.sum(mask, dim=[1, 2, 3], keepdim=True).float()  # Shape: [b, 1, 1, 1].
    num_neg = c * h * w - num_pos                                   # Shape: [b, 1, 1, 1].
    weight = torch.zeros_like(mask)
    #weight[edges > 0.5]  = num_neg / (num_pos + num_neg)
    #weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
    weight.masked_scatter_(edges > 0.5,
        torch.ones_like(edges) * num_neg / (num_pos + num_neg))
    weight.masked_scatter_(edges <= 0.5,
        torch.ones_like(edges) * num_pos / (num_pos + num_neg))
    # Calculate loss.
    # preds=torch.sigmoid(preds)
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    loss = torch.sum(losses) / b
    return loss
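
The weighting scheme gives edge pixels a weight equal to the negative-pixel fraction and background pixels the positive-pixel fraction, so the rare edge class is strongly up-weighted. A tiny numeric check (my own sketch, not from the post) makes this concrete:

# quick numeric check of the class-balancing weights -- an illustrative
# sketch, not part of the original post
import torch
from losses import weighted_cross_entropy_loss

edges = torch.zeros(1, 1, 4, 4)
edges[0, 0, 1, 1] = 1.0          # a single positive pixel out of 16
preds = torch.zeros(1, 1, 4, 4)  # raw logits

# positives get weight 15/16 and negatives 1/16, so the lone edge pixel
# contributes as much to the loss as the 15 background pixels combined
print(weighted_cross_entropy_loss(preds, edges).item())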


model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class _DenseLayer(nn.Sequential):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()
        # self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(input_features, out_features,
                        kernel_size=1, stride=1, bias=True))
        self.add_module('norm1', nn.BatchNorm2d(out_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(out_features, out_features,
                        kernel_size=3, stride=1, padding=1, bias=True))
        self.add_module('norm2', nn.BatchNorm2d(out_features))
        # double check the norm1 comment if necessary and put norm after conv2

    def forward(self, x):
        x1, x2 = x
        # maybe I should put here a RELU
        new_features = super(_DenseLayer, self).forward(x1) # F.relu()
        return 0.5 * (new_features + x2), x2

class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_features, out_features)
            self.add_module('denselayer%d' % (i + 1), layer)
            input_features = out_features

class UpConvBlock(nn.Module):
    def __init__(self, in_features, up_scale, mode='deconv'):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16

        layers = None
        if mode == 'deconv':
            layers = self.make_deconv_layers(in_features, up_scale)
        elif mode == 'pixel_shuffle':
            layers = self.make_pixel_shuffle_layers(in_features, up_scale)
        assert layers is not None, layers
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        for i in range(up_scale):
            kernel_size = 2 ** up_scale
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.ConvTranspose2d(
                out_features, out_features, kernel_size, stride=2))
            in_features = out_features
        return layers

    def make_pixel_shuffle_layers(self, in_features, up_scale):
        layers = []
        for i in range(up_scale):
            kernel_size = 2 ** (i + 1)
            out_features = self.compute_out_features(i, up_scale)
            in_features = int(in_features / (self.up_factor ** 2))
            layers.append(nn.PixelShuffle(self.up_factor))
            layers.append(nn.Conv2d(in_features, out_features, 1))
            if i < up_scale:
                layers.append(nn.ReLU(inplace=True))
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        return self.features(x)

class SingleConvBlock(nn.Module):
    def __init__(self, in_features, out_features, stride):
        super(SingleConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        return self.bn(self.conv(x))

class DoubleConvBlock(nn.Module):
    def __init__(self, in_features, mid_features, out_features=None, stride=1):
        super(DoubleConvBlock, self).__init__()
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(
            in_features, mid_features, 3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x

class DexiNet(nn.Module):
    """ Definition of the DXtrem network. """
    def __init__(self):
        super(DexiNet, self).__init__()
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2)
        self.block_2 = DoubleConvBlock(64, 128)
        self.dblock_3 = _DenseBlock(2, 128, 256)
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)

        self.pre_dense_2 = SingleConvBlock(128, 256, 2) # by me, for left skip block4
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5_0 = SingleConvBlock(256, 512, 2)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)

        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        self.block_cat = nn.Conv2d(6, 1, kernel_size=1)

    def slice(self, tensor, slice_shape):
        height, width = slice_shape
        return tensor[..., :height, :width]

    def forward(self, x):
        assert len(x.shape) == 4, x.shape
        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)

        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)

        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)

        # Block 4
        block_4_pre_dense_256 = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(block_4_pre_dense_256 + block_3_down)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)

        # Block 5
        block_5_pre_dense_512 = self.pre_dense_5_0(block_4_pre_dense_256)
        block_5_pre_dense = self.pre_dense_5(block_5_pre_dense_512 + block_4_down )
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side
#        block_5_side = self.side_5(block_5_add)

        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
#        block_5_pre_dense_256 = self.pre_dense_6(block_5_add) # if error uncomment
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])

        # upsampling blocks
        height, width = x.shape[-2:]
        slice_shape = (height, width)
        out_1 = self.slice(self.up_block_1(block_1), slice_shape)
        out_2 = self.slice(self.up_block_2(block_2), slice_shape)
        out_3 = self.slice(self.up_block_3(block_3), slice_shape)
        out_4 = self.slice(self.up_block_4(block_4), slice_shape)
        out_5 = self.slice(self.up_block_5(block_5), slice_shape)
        out_6 = self.slice(self.up_block_6(block_6), slice_shape)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]
        # print(out_1.shape)
        # concatenate multiscale outputs
        block_cat = torch.cat(results, dim=1)  # Bx6xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW

        # return results
        results.append(block_cat)
        return results
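
A quick forward pass confirms that the six side outputs plus the fused map all come back at the input resolution. This is an illustrative sketch of my own, assuming the DexiNet class above:

# forward-pass shape check -- an illustrative sketch, not from the post
import torch
from model import DexiNet

net = DexiNet().eval()
with torch.no_grad():
    outs = net(torch.rand(1, 3, 400, 400))  # Bx3xHxW, the config size
for i, o in enumerate(outs):
    # six side outputs plus the fused map, each (1, 1, 400, 400)
    print(i, tuple(o.shape))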


main.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from dataset import BIPEDDataset
from losses import *
from config import Config
from cyclicLR import CyclicCosAnnealingLR
from model import DexiNet
import torchgeometry as tgm
import numpy as np
import time
import os
import cv2 as cv
import tqdm


def weight_init(m):
    if isinstance(m, (nn.Conv2d,)):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.weight.data.shape[1] == 1:  # single-input-channel convs
            torch.nn.init.normal_(m.weight, mean=0.0)
        # fusion layer: the 1x1 conv that merges the 6 side outputs
        if m.weight.data.shape == torch.Size([1, 6, 1, 1]):
            torch.nn.init.constant_(m.weight, 0.2)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    if isinstance(m, (nn.ConvTranspose2d,)):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.weight.data.shape[1] == 1:  # single-channel deconvs
            torch.nn.init.normal_(m.weight, std=0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = DexiNet().to(self.device).apply(weight_init)
        self.criterion = weighted_cross_entropy_loss
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.003, weight_decay=0.0001)
        milestones = [5 + x * 30 for x in range(5)]
        self.scheduler = CyclicCosAnnealingLR(self.optimizer, milestones=milestones, eta_min=5e-5)
        mkdir(cfg.model_output)

    def build_loader(self):
        train_dataset = BIPEDDataset(self.cfg.train_root, config=self.cfg)
        valid_dataset = BIPEDDataset(self.cfg.valid_root, config=self.cfg)

        train_loader = DataLoader(train_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=False)
        return train_loader, valid_loader

    def train_one_epoch(self, epoch, dataloader):
        self.model.train()
        for batch_id, sample_batched in tqdm.tqdm(enumerate(dataloader)):
            images = sample_batched['images'].to(self.device)  # BxCxHxW
            labels = sample_batched['labels'].to(self.device)  # BxHxW

            preds_list = self.model(images)
            loss = sum([self.criterion(preds, labels) for preds in preds_list])
            loss /= images.shape[0]  # the batch size

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            print(time.ctime(), 'training, Epoch: {0} Sample {1}/{2} Loss: {3}'
                  .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')

    def validation(self, epoch, dataloader):
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch_id, sample_batched in enumerate(dataloader):
                images = sample_batched['images'].to(self.device)  # BxCxHxW
                labels = sample_batched['labels'].to(self.device)  # Bx1xHxW
                file_name = sample_batched['file_name']

                preds_list = self.model(images)
                loss = sum([self.criterion(preds, labels) for preds in preds_list])
                loss /= images.shape[0]  # the batch size
                total_loss += loss.item()

                print(time.ctime(), 'validation, Epoch: {0} Sample {1}/{2} Loss: {3}'
                      .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')

                self.save_image_batch_to_disk(preds_list[-1], file_name)
        return total_loss / len(dataloader)

    def save_image_batch_to_disk(self, tensor, file_names):
        output_dir = self.cfg.valid_output_dir
        mkdir(output_dir)
        assert len(tensor.shape) == 4, tensor.shape
        for tensor_image, file_name in zip(tensor, file_names):
            image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor_image))[..., 0]
            image_vis = (255.0 * (1.0 - image_vis)).astype(np.uint8)  #
            output_file_name = os.path.join(output_dir, f"{file_name}.png")
            cv.imwrite(output_file_name, image_vis)

    def train(self):
        train_loader, valid_loader = self.build_loader()
        best_loss = 1000000
        for epoch in range(self.cfg.num_epochs):
            self.scheduler.step(epoch)

            self.train_one_epoch(epoch, train_loader)

            valid_loss = self.validation(epoch, valid_loader)
            if valid_loss < best_loss:
                torch.save(self.model, os.path.join(self.cfg.model_output, f'epoch{epoch}_model.pth'))
                print(f'found a better model, loss {best_loss} ==> {valid_loss}')
                best_loss = valid_loss


if __name__=='__main__':
    config = Config()
    trainer = Trainer(config)
    trainer.train()
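
The cyclicLR module providing CyclicCosAnnealingLR is not listed in the post. If it is unavailable, PyTorch's built-in CosineAnnealingWarmRestarts gives a similar cyclic cosine schedule and can be swapped in for self.scheduler in Trainer.__init__; this substitution is my assumption, not the original code:

# possible stand-in for the missing cyclicLR module (my substitution,
# not the original scheduler)
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from model import DexiNet

model = DexiNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=0.0001)
# restart the cosine cycle every 30 epochs (matching the milestone spacing
# above) and decay to the same eta_min the post uses
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=30, eta_min=5e-5)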


predict.py

#!/usr/bin/env python
# coding: utf-8
from PIL import Image
import cv2
from path import Path
from eval.dataset import SasDataset
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
import os
import tqdm
import numpy as np
import argparse
import shutil
import glob
import gdal
from ndvi import *
from skimage import morphology
from scipy import ndimage as ndi
from config_eval import Config
from pytorch_toolbelt.inference import tta
from skimage.morphology import dilation, erosion, closing, square
import torchgeometry as tgm

# a temp folder is needed to store the intermediate tiles
path = './temp_pic/'

device = 'cuda:0'
dark = [0,0,0]

cfg = Config()

# horizontal flip
def flip_horizontal_tensor(batch):
    columns = batch.data.size()[-1]
    return batch.index_select(-1, torch.LongTensor(list(reversed(range(columns)))).cuda())


# vertical flip
def flip_vertical_tensor(batch):
    rows = batch.data.size()[-2]
    return batch.index_select(-2, torch.LongTensor(list(reversed(range(rows)))).cuda())

def input_and_output(pic_path, model, loader=None, generate_data=True):
    """
    args:
        pic_path : path of the picture to predict on
        model    : dict of models used for prediction
    note:
        step one : cut the large picture into overlapping tiles
        step two : predict on the tiles generated in step one
    """
    image_size = cfg.crop_size

    data = gdal.Open(pic_path)
    lastChannel = data.RasterCount + 1
    arr = [data.GetRasterBand(idx).ReadAsArray() for idx in range(1, lastChannel)]
    data = np.dstack(arr)

    raw_h, raw_w = data.shape[:2]

    b = cfg.padding_size
    row = raw_h // image_size + 1
    col = raw_w // image_size + 1
    radius_h = row * image_size - raw_h
    radius_w = col * image_size - raw_w
    image = cv2.copyMakeBorder(data, 0, radius_h, 0, radius_w, cv2.BORDER_REFLECT)

    image = cv2.copyMakeBorder(image, b, b, b, b, cv2.BORDER_REFLECT)

    h, w = image.shape[:2]

    padding_img = image[:, :, :]

    padding_img = np.array(padding_img)
    mask_whole = np.zeros((row*image_size, col*image_size), dtype=np.uint8)
    if not generate_data:
        print('starting prediction')
        result = []
        for batch in tqdm.tqdm(loader):
            images = batch['img'].to(device, dtype=torch.float)
            temp = 0
            for keys in model.keys():
                # model[keys].eval()
                net = model[keys]
                if cfg.TTA:
                    outputs = tta.fliplr_image2label(net, images)
                else:
                    outputs = net(images)
                tensor = outputs[-1][0, ...]
                image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor))[..., 0]
                image_vis = (255.0 * (1.0 - image_vis)).astype(np.uint8)  #
                # outputs = torch.squeeze(outputs).detach().cpu().numpy()
                temp += image_vis
            preds = temp / len(model)
            # preds = torch.from_numpy(preds)
            result.append(preds)
        map_list = [str(i.name) for i in Path('temp_pic').files()]
    for i in tqdm.tqdm(range(row)):
        for j in range(col):
            if generate_data:
                crop_img = redundancy_crop(padding_img, i, j, image_size)
                ch,cw,_ = crop_img.shape
                cv2.imwrite(f'temp_pic/{i}_{j}.tif', crop_img)
            else:
                temp = result[map_list.index(f'{i}_{j}.tif')]
                temp = redundancy_crop2(temp)
                mask_whole[i*image_size:i*image_size+image_size, j*image_size:j*image_size+image_size] = temp
    return mask_whole[:raw_h, :raw_w]


def redundancy_crop(img, i, j, targetSize):
    if len(img.shape)>2:
        temp_img = img[i*targetSize:i*targetSize+targetSize+2*cfg.padding_size, j*targetSize:j*targetSize+targetSize+2*cfg.padding_size, :]
    else:
        temp_img = img[i*targetSize:i*targetSize+targetSize+2*cfg.padding_size, j*targetSize:j*targetSize+targetSize+2*cfg.padding_size]
    return temp_img


def redundancy_crop2(img):
    h = img.shape[0]
    w = img.shape[1]
    temp_img = img[cfg.padding_size:h-cfg.padding_size, cfg.padding_size:w-cfg.padding_size]
    return temp_img


def get_dataset_loaders():
    batch_size = 1

    test_dataset = SasDataset(
        "./temp_pic",
        mode='test'
    )

    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0)
    return test_loader


def get_labels():
    """Load the mapping that associates pascal classes with label colors

    Returns:
        np.ndarray with dimensions (2, 3)
    """
    return np.asarray(
        [
            [0, 0, 0],
            [255, 255, 255]
        ]
    )


def decode_segmap(label_mask, n_classes):
    """Decode segmentation class labels into a color image

    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
          the class label at each spatial location.
        plot (bool, optional): whether to show the resulting color image
          in a figure.

    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    label_colours = get_labels()
    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r
    rgb[:, :, 1] = g
    rgb[:, :, 2] = b
    return rgb


def mkdir(path):
    if not os.path.exists(path):
        os.mkdir(path)


def predict():
    model_groups = glob.glob(cfg.model_path+'/*.pth')
    imgList = glob.glob(cfg.data_path + '/*.tif')
    num = len(imgList)

    # predict on more model
    print('loading models')
    models = {}
    for index, item in enumerate(model_groups):

        models[item] = torch.load(item, map_location='cuda:0')
        # models[item] = torch.load(item, map_location='cuda:0')["ema_state_dict"]

    # model = torch.load(f'./results_{args.model_name}/{args.model_name}_weights_best.pth')["model_state"]

    for i in tqdm.tqdm(range(num)):
        if not os.path.exists('temp_pic'):
            os.makedirs('temp_pic')

        input_and_output(imgList[i], models, generate_data=True)
        name = os.path.split(imgList[i])[-1].split(".")[0]
        test_loader = get_dataset_loaders()
        mask_result = input_and_output(imgList[i], models, loader=test_loader, generate_data=False)
        # recursively delete the temp folder
        try:
            shutil.rmtree('temp_pic')
        except:
            pass

        # mask_result = closing(mask_result, square(20))
        mask_result = mask_result.astype(np.uint8)

        mkdir(cfg.save_path)
        cv2.imwrite(os.path.join(cfg.save_path, name + '_predict.tif'), mask_result)


if __name__ == '__main__':
    predict()
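
The tiling bookkeeping in input_and_output pads the image up to a multiple of the crop size, adds a reflected border of padding_size so every tile carries context, and redundancy_crop2 trims that border back off each prediction. A dummy round trip of the arithmetic (my own sketch; crop_size=256 and padding_size=32 are hypothetical values):

# illustrative round trip of the tiling arithmetic (not from the post);
# crop_size=256 and padding_size=32 are hypothetical values
import numpy as np
import cv2

crop_size, pad = 256, 32
raw = np.zeros((600, 900, 3), np.uint8)

row = raw.shape[0] // crop_size + 1  # 3 tile rows
col = raw.shape[1] // crop_size + 1  # 4 tile cols
img = cv2.copyMakeBorder(raw, 0, row * crop_size - raw.shape[0],
                         0, col * crop_size - raw.shape[1], cv2.BORDER_REFLECT)
img = cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_REFLECT)

# each crop carries pad pixels of context on every side ...
tile = img[0:crop_size + 2 * pad, 0:crop_size + 2 * pad]
# ... which redundancy_crop2 trims off again after prediction
core = tile[pad:-pad, pad:-pad]
assert core.shape[:2] == (crop_size, crop_size)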


Remote sensing image tiles (figure omitted)
Prediction results (figure omitted)
Original post: https://blog.csdn.net/weixin_42990464/article/details/107243620
