概述
github地址:https://github.com/xavysp/DexiNed/tree/master/DexiNed-Pytorch
论文地址:https://arxiv.org/abs/1909.01955
数据集:https://www.kaggle.com/xavysp/biped
文章目录
摘要
模型结构
论文实验结果
相关代码
摘要
这篇论文提出了一种基于深度学习的边缘检测算法,其设计受到 HED(Holistically-Nested Edge Detection)和 Xception 网络的启发。该方法能生成接近人眼感知的细边缘图,可用于任何边缘检测任务,无需长时间的训练或微调过程。
该论文的主要贡献:提出了一种鲁棒的CNN边缘检测架构,简称为DexiNed:Dense Extreme Inception Network for Edge Detection。这个模型是从头开始训练的,没有预先训练过的权重。
模型结构
论文实验结果
相关代码
config.py
class Config(object):
    """Training configuration for DexiNet on the BIPED dataset.

    All settings are class attributes, so they can be read from the class
    itself or from an instance.
    """

    # ---- dataset ----
    # Per-channel means subtracted from the input image. BIPEDDataset uses only
    # the first three (BGR) values when four are given.
    mean_pixel_values = [104.00699, 116.66877, 122.67892, 137.86]
    img_width = 400
    img_height = 400
    # Directories of *input images*. BIPEDDataset finds each label by replacing
    # 'imgs' with 'edge_maps' in the path, so both roots must live under 'imgs'.
    # Pointing a root into 'edge_maps' would feed the ground-truth maps in as
    # the network input.
    train_root = 'data/BIPED/edges/imgs/train/rgbr/real'
    valid_root = 'data/BIPED/edges/imgs/test/rgbr'
    valid_output_dir = 'valid_temp'

    # ---- hyper parameters ----
    batch_size = 2
    num_workers = 0
    num_epochs = 25
    model_output = 'result'
dataset.py
from torch.utils.data import DataLoader, Dataset
import torch
import cv2 as cv
import numpy as np
import os
class BIPEDDataset(Dataset):
    """BIPED edge-detection dataset.

    Every file in ``img_root`` is treated as an input image. Its ground-truth
    edge map is located by replacing ``imgs`` with ``edge_maps`` in the
    directory part of the path and ``jpg`` with ``png`` in the extension
    (see ``_label_path``).

    Each item is a dict with keys:
        images: 3xHxW float tensor, BGR with ``mean_bgr`` subtracted.
        labels: 1xHxW float tensor with values scaled from [0, 255] to [0, 1].
        file_name: image file name without its extension.
        image_shape: [height, width] of the image before resizing.
    H and W are ``config.img_height`` and ``config.img_width``.
    """

    def __init__(self, img_root, mode='train', config=None):
        if config is None:
            raise ValueError('BIPEDDataset requires a config object')
        self.img_root = img_root
        self.mode = mode
        # Sorted so sample order is reproducible (os.listdir order is arbitrary).
        self.imgList = sorted(os.listdir(img_root))
        self.config = config
        # A 4th mean value, if present, does not apply to the 3-channel image.
        means = config.mean_pixel_values
        self.mean_bgr = means[0:3] if len(means) == 4 else means

    def __len__(self):
        return len(self.imgList)

    @staticmethod
    def _label_path(img_path):
        """Return the edge-map path that corresponds to ``img_path``.

        Only the directory part has ``imgs`` replaced and only the extension
        has ``jpg`` replaced, so file stems containing those substrings are
        left untouched.
        """
        directory, base = os.path.split(img_path)
        stem, ext = os.path.splitext(base)
        return os.path.join(directory.replace('imgs', 'edge_maps'),
                            stem + ext.replace('jpg', 'png'))

    def __getitem__(self, idx):
        img_name = self.imgList[idx]
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b").
        file_name = os.path.splitext(img_name)[0]
        imgPath = os.path.join(self.img_root, img_name)
        labelPath = self._label_path(imgPath)
        # cv.imread returns None (instead of raising) for missing/unreadable files.
        image = cv.imread(imgPath, cv.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f'cannot read image: {imgPath}')
        label = cv.imread(labelPath, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError(f'cannot read label: {labelPath}')
        image_shape = [image.shape[0], image.shape[1]]
        image, label = self.transform(img=image, gt=label)
        return dict(images=image, labels=label, file_name=file_name,
                    image_shape=image_shape)

    def transform(self, img, gt):
        """Normalize, resize and convert an (image, edge map) pair to tensors.

        Returns:
            (img, gt): float tensors of shape 3xHxW and 1xHxW.
        """
        gt = np.array(gt, dtype=np.float32)
        if gt.ndim == 3:
            gt = gt[:, :, 0]
        gt /= 255.
        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr
        # cv.resize takes dsize as (width, height).
        size = (self.config.img_width, self.config.img_height)
        img = cv.resize(img, dsize=size)
        gt = cv.resize(gt, dsize=size)
        img = torch.from_numpy(img.transpose((2, 0, 1)).copy()).float()
        gt = torch.from_numpy(np.array([gt])).float()
        return img, gt
if __name__ == '__main__':
    # Smoke test: walk the BIPED training split and print batch shapes/names.
    from config import Config

    demo_loader = DataLoader(
        BIPEDDataset('data/BIPED/edges/imgs/train/rgbr/real', config=Config()),
        batch_size=2,
        num_workers=0,
    )
    for batch in demo_loader:
        print(batch['images'].size(), batch['labels'].size(), batch['file_name'])
loss.py
import torch
import torch.nn.functional as F
def _weighted_cross_entropy_loss(preds, edges):
""" Calculate sum of weighted cross entropy loss. """
# Reference:
# hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
# https://github.com/s9xie/hed/issues/7
mask = (edges > 0.5).float()
b, c, h, w = mask.shape
num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,].
num_neg = c * h * w - num_pos # Shape: [b,].
weight = torch.zeros_like(mask)
weight[edges > 0.5] = num_neg / (num_pos + num_neg)
weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
# Calculate loss.
losses = F.binary_cross_entropy_with_logits(
preds.float(), edges.float(), weight=weight, reduction='none')
loss = torch.sum(losses) / b
return loss
def weighted_cross_entropy_loss(preds, edges):
    """Class-balanced binary cross-entropy (HED style), summed then averaged over the batch.

    Args:
        preds: raw logits, shape (B, C, H, W).
        edges: ground truth of the same shape; pixels > 0.5 count as edges.

    Returns:
        Scalar tensor: the weighted BCE summed over all pixels, divided by B.

    Within each sample, edge pixels are weighted by that sample's fraction of
    non-edge pixels and non-edge pixels by its fraction of edge pixels.
    """
    # Reference:
    # hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    # https://github.com/s9xie/hed/issues/7
    mask = edges > 0.5
    b, c, h, w = mask.shape
    num_pos = mask.float().sum(dim=[1, 2, 3], keepdim=True)  # (B, 1, 1, 1)
    num_neg = c * h * w - num_pos                              # (B, 1, 1, 1)
    total = num_pos + num_neg
    # torch.where selects positionally with broadcasting. The previous
    # masked_scatter_ copied source values sequentially, so for B > 1 pixels
    # received other samples' weights.
    weight = torch.where(mask, num_neg / total, num_pos / total)
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    return torch.sum(losses) / b
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class _DenseLayer(nn.Sequential):
def __init__(self, input_features, out_features):
super(_DenseLayer, self).__init__()
# self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module('conv1', nn.Conv2d(input_features, out_features,
kernel_size=1, stride=1, bias=True)),
self.add_module('norm1', nn.BatchNorm2d(out_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module('conv2', nn.Conv2d(out_features, out_features,
kernel_size=3, stride=1, padding=1, bias=True)),
self.add_module('norm2', nn.BatchNorm2d(out_features))
# double check the norm1 comment if necessary and put norm after conv2
def forward(self, x):
x1, x2 = x
# maybe I should put here a RELU
new_features = super(_DenseLayer, self).forward(x1) # F.relu()
return 0.5 * (new_features + x2), x2
class _DenseBlock(nn.Sequential):
    """Stack of ``num_layers`` _DenseLayer modules named denselayer1..N.

    Only the first layer receives ``input_features`` channels; every later
    layer maps ``out_features`` to ``out_features``. The forward pass is
    inherited from nn.Sequential, so the block takes and returns the
    ``(x, skip)`` pair each _DenseLayer works on.
    """

    def __init__(self, num_layers, input_features, out_features):
        super().__init__()
        for index in range(1, num_layers + 1):
            in_channels = input_features if index == 1 else out_features
            self.add_module(f'denselayer{index}',
                            _DenseLayer(in_channels, out_features))
class UpConvBlock(nn.Module):
    """Upsample a feature map by ``2 ** up_scale`` and project it to one channel.

    Args:
        in_features: channels of the input feature map.
        up_scale: number of x2 upsampling stages.
        mode: 'deconv' (1x1 conv + ReLU + stride-2 transposed conv per stage)
            or 'pixel_shuffle' (PixelShuffle + 1x1 conv per stage).

    Raises:
        ValueError: for an unknown ``mode``.

    Intermediate stages use ``constant_features`` (16) channels; the last stage
    outputs a single channel, which DexiNet uses as an edge logit map.
    """

    def __init__(self, in_features, up_scale, mode='deconv'):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16
        if mode == 'deconv':
            layers = self.make_deconv_layers(in_features, up_scale)
        elif mode == 'pixel_shuffle':
            layers = self.make_pixel_shuffle_layers(in_features, up_scale)
        else:
            raise ValueError(f'unknown UpConvBlock mode: {mode!r}')
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        """Build ``up_scale`` stages of 1x1 conv, ReLU and stride-2 transposed conv."""
        layers = []
        # NOTE(review): the transposed convs use no padding, so each stage
        # outputs (H - 1) * 2 + kernel_size rows. DexiNet crops the surplus
        # from the bottom/right rather than centering it.
        kernel_size = 2 ** up_scale
        for i in range(up_scale):
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.ConvTranspose2d(
                out_features, out_features, kernel_size, stride=2))
            in_features = out_features
        return layers

    def make_pixel_shuffle_layers(self, in_features, up_scale):
        """Build ``up_scale`` stages of PixelShuffle and 1x1 conv."""
        layers = []
        for i in range(up_scale):
            out_features = self.compute_out_features(i, up_scale)
            # PixelShuffle trades up_factor**2 channels for resolution.
            in_features = in_features // (self.up_factor ** 2)
            layers.append(nn.PixelShuffle(self.up_factor))
            layers.append(nn.Conv2d(in_features, out_features, 1))
            # No ReLU after the last stage: its output is a logit map and must
            # be able to go negative (the old `i < up_scale` was always true).
            if i < up_scale - 1:
                layers.append(nn.ReLU(inplace=True))
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        """Return 1 for the last stage and ``constant_features`` otherwise."""
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        return self.features(x)
class SingleConvBlock(nn.Module):
    """1x1 convolution (with the given stride) followed by batch normalization.

    No activation is applied.
    """

    def __init__(self, in_features, out_features, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size=1,
                              stride=stride)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        projected = self.conv(x)
        normalized = self.bn(projected)
        return normalized
class DoubleConvBlock(nn.Module):
    """Two (3x3 conv -> BN -> ReLU) stages.

    Args:
        in_features: input channels.
        mid_features: channels produced by the first conv.
        out_features: channels produced by the second conv; defaults to
            ``mid_features``.
        stride: stride of the first conv (the second always uses stride 1).
    """

    def __init__(self, in_features, mid_features, out_features=None, stride=1):
        super().__init__()
        out_features = mid_features if out_features is None else out_features
        self.conv1 = nn.Conv2d(in_features, mid_features, 3,
                               padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # The single ReLU module is reused after both stages.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(bn(conv(x)))
        return x
class DexiNet(nn.Module):
    """DexiNed edge detector (Dense Extreme Inception Network).

    Six feature blocks run at decreasing resolution: roughly 1/2, 1/2, 1/4,
    1/8, 1/16 and 1/16 of the input, via block_1's stride-2 conv and the shared
    stride-2 max-pool. Each block's output is upsampled by an UpConvBlock to a
    1-channel map and cropped to the input size. A 1x1 conv then fuses the six
    maps.

    forward(x) requires a 4-D tensor (asserted). block_1 takes 3 input
    channels, so presumably x is (B, 3, H, W).

    It returns a list of seven tensors: the six side outputs followed by the
    fused output, each Bx1xHxW. ``slice`` only crops, so this assumes every
    upsampled map is at least H x W. The outputs are unnormalized; the training
    code applies sigmoid / BCE-with-logits to them.
    """
    def __init__(self):
        super(DexiNet, self).__init__()
        # Main path. block_1 halves the resolution (stride 2).
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2)
        self.block_2 = DoubleConvBlock(64, 128)
        self.dblock_3 = _DenseBlock(2, 128, 256)
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        # Shared stride-2 pooling between blocks.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 1x1 "side" projections of each block's output, added to the next
        # block's pooled output (stride 2 where that output was pooled).
        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        # NOTE: side_5 is built (and stored in the state_dict), but its use in
        # forward() is commented out.
        self.side_5 = SingleConvBlock(512, 256, 1)
        # 1x1 projections producing the second (skip) input of each dense block.
        self.pre_dense_2 = SingleConvBlock(128, 256, 2) # by me, for left skip block4
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5_0 = SingleConvBlock(256, 512, 2)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)
        # Upsampling heads: up_scale k means a x2**k enlargement, matching the
        # resolution of the block each head is applied to.
        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        # Fuses the six 1-channel side outputs into one map.
        self.block_cat = nn.Conv2d(6, 1, kernel_size=1)
    def slice(self, tensor, slice_shape):
        # Crop the last two (spatial) dims to slice_shape, keeping the top-left.
        height, width = slice_shape
        return tensor[..., :height, :width]
    def forward(self, x):
        assert len(x.shape) == 4, x.shape
        # Block 1: stride-2 double conv -> 64 channels at 1/2 resolution
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)
        # Block 2: 128 channels at 1/2; its pooled copy (1/4) is merged with
        # block 1's side projection
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)
        # Block 3: dense block fed block_2_add, with pre_dense_3 as skip input
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)
        # Block 4: skip input also carries a downsampled copy of block 2
        block_4_pre_dense_256 = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(block_4_pre_dense_256 + block_3_down)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)
        # Block 5: no pooling afterwards, so blocks 5 and 6 share a resolution
        block_5_pre_dense_512 = self.pre_dense_5_0(block_4_pre_dense_256)
        block_5_pre_dense = self.pre_dense_5(block_5_pre_dense_512 + block_4_down )
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side
        # block_5_side = self.side_5(block_5_add)
        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
        # block_5_pre_dense_256 = self.pre_dense_6(block_5_add) # if error uncomment
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
        # upsampling blocks: each block -> 1 channel, cropped to the input H x W
        height, width = x.shape[-2:]
        slice_shape = (height, width)
        out_1 = self.slice(self.up_block_1(block_1), slice_shape)
        out_2 = self.slice(self.up_block_2(block_2), slice_shape)
        out_3 = self.slice(self.up_block_3(block_3), slice_shape)
        out_4 = self.slice(self.up_block_4(block_4), slice_shape)
        out_5 = self.slice(self.up_block_5(block_5), slice_shape)
        out_6 = self.slice(self.up_block_6(block_6), slice_shape)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]
        # print(out_1.shape)
        # concatenate multiscale outputs
        block_cat = torch.cat(results, dim=1) # Bx6xHxW
        block_cat = self.block_cat(block_cat) # Bx1xHxW
        # The fused map is appended last: results[-1] is the final prediction.
        results.append(block_cat)
        return results
main.py
from torch import nn
from torch.utils.data import DataLoader
from dataset import BIPEDDataset
from losses import *
from config import Config
from cyclicLR import CyclicCosAnnealingLR
from model import DexiNet
import torchgeometry as tgm
import numpy as np
import time
import os
import cv2 as cv
import tqdm
def weight_init(m):
    """Initialize a conv / transposed-conv module in place; use as ``model.apply(weight_init)``.

    - Conv2d: weights ~ N(0, 0.01). If the weight has a single input channel
      (per group), weights ~ N(0, 1) instead. The 6 -> 1 fusion conv (weight
      shape [1, 6, 1, 1]) is set to the constant 0.2.
    - ConvTranspose2d: weights ~ N(0, 0.01). If there is a single output
      channel (its weight is laid out [in, out, kH, kW]), weights ~ N(0, 0.1).
    - Biases of both kinds are zeroed. Other module types are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        # shape[1] is an int; the old comparison with torch.Size([1]) was
        # always False, so this branch never ran.
        if m.weight.shape[1] == 1:
            torch.nn.init.normal_(m.weight, mean=0.0)
        # fusion layer (DexiNet.block_cat)
        if m.weight.shape == torch.Size([1, 6, 1, 1]):
            torch.nn.init.constant_(m.weight, 0.2)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.weight.shape[1] == 1:
            torch.nn.init.normal_(m.weight, std=0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
def mkdir(path):
    """Create directory ``path``, including missing parents, if it does not exist."""
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(path, exist_ok=True)
class Trainer(object):
    """Train DexiNet on BIPED, validating after every epoch.

    The whole model is saved to ``cfg.model_output`` whenever the mean
    validation loss improves on the best value seen so far.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = DexiNet().to(self.device).apply(weight_init)
        self.criterion = weighted_cross_entropy_loss
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=0.003, weight_decay=0.0001)
        milestones = [5 + x * 30 for x in range(5)]
        self.scheduler = CyclicCosAnnealingLR(self.optimizer, milestones=milestones,
                                              eta_min=5e-5)
        mkdir(cfg.model_output)

    def build_loader(self):
        """Return ``(train_loader, valid_loader)``; only the training loader shuffles."""
        train_dataset = BIPEDDataset(self.cfg.train_root, config=self.cfg)
        valid_dataset = BIPEDDataset(self.cfg.valid_root, config=self.cfg)
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=False)
        return train_loader, valid_loader

    def _batch_loss(self, sample_batched):
        """Run the model on one batch; return ``(loss, preds_list)``.

        The loss is the criterion summed over every output of the model.
        """
        images = sample_batched['images'].to(self.device)  # BxCxHxW
        labels = sample_batched['labels'].to(self.device)  # Bx1xHxW
        preds_list = self.model(images)
        loss = sum(self.criterion(preds, labels) for preds in preds_list)
        # NOTE(review): the criterion already divides by the batch size, so
        # this scales the loss by 1/B**2 overall. Kept to preserve the existing
        # training dynamics.
        loss /= images.shape[0]
        return loss, preds_list

    def train_one_epoch(self, epoch, dataloader, progress=True):
        """Run one optimisation pass over ``dataloader``.

        Args:
            progress: wrap the loop in a tqdm progress bar.
        """
        self.model.train()
        batches = enumerate(dataloader)
        if progress:
            batches = tqdm.tqdm(batches, total=len(dataloader))
        for batch_id, sample_batched in batches:
            loss, _ = self._batch_loss(sample_batched)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # '\r' rewrites the status line in place (was the literal 'r').
            print(time.ctime(), 'training, Epoch: {0} Sample {1}/{2} Loss: {3}'
                  .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')
        print()

    def validation(self, epoch, dataloader):
        """Evaluate on ``dataloader`` and save the fused predictions to disk.

        Returns:
            0-dim tensor holding the mean batch loss (inf if ``dataloader`` is
            empty). Previously only the last batch's loss was returned.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        # No autograd graph is needed for evaluation; this saves memory.
        with torch.no_grad():
            for batch_id, sample_batched in enumerate(dataloader):
                loss, preds_list = self._batch_loss(sample_batched)
                total_loss += loss.item()
                num_batches += 1
                print(time.ctime(), 'validation, Epoch: {0} Sample {1}/{2} Loss: {3}'
                      .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')
                self.save_image_bacth_to_disk(preds_list[-1], sample_batched['file_name'])
        print()
        if num_batches == 0:
            return torch.tensor(float('inf'))
        return torch.tensor(total_loss / num_batches)

    def save_image_bacth_to_disk(self, tensor, file_names):
        """Write ``sigmoid(tensor)`` for each batch item to ``cfg.valid_output_dir/<file_name>.png``.

        The map is inverted (255 * (1 - p)), so high-probability pixels are dark.
        """
        output_dir = self.cfg.valid_output_dir
        mkdir(output_dir)
        assert len(tensor.shape) == 4, tensor.shape
        for tensor_image, file_name in zip(tensor, file_names):
            image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor_image))[..., 0]
            image_vis = (255.0 * (1.0 - image_vis)).astype(np.uint8)
            output_file_name = os.path.join(output_dir, f"{file_name}.png")
            cv.imwrite(output_file_name, image_vis)

    def train(self):
        """Train for ``cfg.num_epochs`` epochs, checkpointing on validation improvement."""
        train_loader, valid_loader = self.build_loader()
        best_loss = float('inf')
        for epoch in range(self.cfg.num_epochs):
            self.scheduler.step(epoch)
            self.train_one_epoch(epoch, train_loader, progress=False)
            valid_loss = self.validation(epoch, valid_loader)
            if valid_loss < best_loss:
                torch.save(self.model, os.path.join(self.cfg.model_output,
                                                    f'epoch{epoch}_model.pth'))
                print(f'find optimal model, loss {best_loss}==>{float(valid_loss)}')
                best_loss = float(valid_loss)
if __name__ == '__main__':
    # Script entry point: build the default configuration and run training.
    Trainer(Config()).train()
predict.py
#!/usr/bin/env python
# coding: utf-8
from PIL import Image
import cv2
from path import Path
from eval.dataset import SasDataset
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
import os
import tqdm
import numpy as np
import argparse
import shutil
import glob
import gdal
import tqdm
from ndvi import *
from skimage import morphology
from scipy import ndimage as ndi
from config_eval import Config
from pytorch_toolbelt.inference import tta
from skimage.morphology import dilation, erosion, closing, square
import torchgeometry as tgm
# Scratch directory for the tiles cut from each input raster.
# NOTE(review): the functions in this file hard-code 'temp_pic' instead of
# reading this value; keep the two in sync.
path = './temp_pic/'
# Device that input tiles are moved to during prediction.
device = 'cuda:0'
# Black RGB triple; not referenced by the code visible in this file.
dark = [0] * 3
cfg = Config()
def flip_horizontal_tensor(batch):
    """Return ``batch`` mirrored along its last (width) axis (horizontal flip).

    torch.flip keeps the result on the input's device. The previous version
    built its index tensor with a hard-coded ``.cuda()``, which failed for CPU
    tensors and for tensors on other GPUs.
    """
    return torch.flip(batch, dims=[-1])
def flip_vertical_tensor(batch):
    """Return ``batch`` mirrored along its second-to-last (height) axis (vertical flip).

    torch.flip keeps the result on the input's device. The previous version
    built its index tensor with a hard-coded ``.cuda()``.
    """
    return torch.flip(batch, dims=[-2])
def input_and_output(pic_path, model, loader=None, generate_data=True):
    """Tile a raster for prediction, or stitch per-tile predictions back together.

    It is meant to be called twice per image.

    1. ``generate_data=True``: read ``pic_path`` with GDAL and reflect-pad it
       to whole ``cfg.crop_size`` tiles plus a ``cfg.padding_size`` border.
       Each overlapping tile is written to ``temp_pic/{i}_{j}.tif``.
    2. ``generate_data=False``: run every network in ``model`` over the tiles
       yielded by ``loader`` and average their edge maps. Each tile's border
       is stripped and the tile is pasted into the full-size mask.

    Args:
        pic_path: path of the raster to predict.
        model: dict of loaded networks; predictions are averaged over its values.
        loader: iterable of batches with an ``'img'`` tensor (required in step 2).
        generate_data: selects step 1 (True) or step 2 (False).

    Returns:
        uint8 mask of shape (raw_h, raw_w); all zeros in step 1.
    """
    image_size = cfg.crop_size
    b = cfg.padding_size
    dataset = gdal.Open(pic_path)
    bands = [dataset.GetRasterBand(idx).ReadAsArray()
             for idx in range(1, dataset.RasterCount + 1)]
    data = np.dstack(bands)
    raw_h, raw_w = data.shape[:2]
    # Always rounds up to at least one extra tile row/column.
    row = raw_h // image_size + 1
    col = raw_w // image_size + 1
    image = cv2.copyMakeBorder(data, 0, row * image_size - raw_h,
                               0, col * image_size - raw_w, cv2.BORDER_REFLECT)
    # b-pixel border on every side, so each tile can be cut with context.
    padding_img = cv2.copyMakeBorder(image, b, b, b, b, cv2.BORDER_REFLECT)
    mask_whole = np.zeros((row * image_size, col * image_size), dtype=np.uint8)

    if generate_data:
        for i in tqdm.tqdm(range(row)):
            for j in range(col):
                crop_img = redundancy_crop(padding_img, i, j, image_size)
                cv2.imwrite(f'temp_pic/{i}_{j}.tif', crop_img)
        return mask_whole[:raw_h, :raw_w]

    result = _predict_tiles(model, loader)
    # Look tiles up through a dict built once, instead of list.index per tile.
    # NOTE(review): assumes `loader` yields tiles in the same order as
    # Path('temp_pic').files() lists them; confirm against SasDataset.
    tile_names = [str(p.name) for p in Path('temp_pic').files()]
    tile_index = {name: k for k, name in enumerate(tile_names)}
    for i in tqdm.tqdm(range(row)):
        for j in range(col):
            tile = redundancy_crop2(result[tile_index[f'{i}_{j}.tif']])
            mask_whole[i * image_size:(i + 1) * image_size,
                       j * image_size:(j + 1) * image_size] = tile
    return mask_whole[:raw_h, :raw_w]


def _predict_tiles(model, loader):
    """Return one model-averaged, inverted edge map (float64, H x W) per batch of ``loader``."""
    print('starting prediction')
    nets = list(model.values())
    for net in nets:
        net.eval()  # BatchNorm must use running statistics at inference time
    result = []
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            images = batch['img'].to(device, dtype=torch.float)
            # Accumulate in float64. Summing the uint8 maps of several models
            # directly would wrap around at 255.
            total = 0.0
            for net in nets:
                outputs = tta.fliplr_image2label(net, images) if cfg.TTA else net(images)
                tensor = outputs[-1][0, ...]
                image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor))[..., 0]
                image_vis = (255.0 * (1.0 - image_vis)).astype(np.uint8)
                total = total + image_vis.astype(np.float64)
            result.append(total / len(nets))
    return result
def redundancy_crop(img, i, j, targetSize, padding_size=None):
    """Cut tile (i, j) of ``targetSize`` pixels plus its border from a padded image.

    ``img`` is expected to carry a ``padding_size`` border on every side. The
    window is therefore ``targetSize + 2 * padding_size`` pixels square, or
    smaller if it runs past the array edge. Only the first two axes are
    sliced, so 2-D and channel-last 3-D arrays both work.

    Args:
        img: padded image array.
        i, j: tile row and column.
        targetSize: tile size without the border.
        padding_size: border width; defaults to ``cfg.padding_size``.
    """
    if padding_size is None:
        padding_size = cfg.padding_size
    span = targetSize + 2 * padding_size
    top, left = i * targetSize, j * targetSize
    return img[top:top + span, left:left + span]
def redundancy_crop2(img, padding_size=None):
    """Strip a ``padding_size`` border from all four sides of ``img``.

    This undoes the border that ``redundancy_crop`` includes in each tile.
    ``padding_size`` defaults to ``cfg.padding_size``.
    """
    if padding_size is None:
        padding_size = cfg.padding_size
    h, w = img.shape[:2]
    return img[padding_size:h - padding_size, padding_size:w - padding_size]
def get_dataset_loaders():
    """Return a DataLoader (batch size 1, no worker processes) over the tiles in ./temp_pic."""
    return DataLoader(
        SasDataset("./temp_pic", mode='test'),
        batch_size=1,
        num_workers=0,
    )
def get_labels():
    """Return the two-class colour palette used by ``decode_segmap``.

    Returns:
        np.ndarray of shape (2, 3): row 0 is black (class 0) and row 1 is
        white (class 1).
    """
    black = [0, 0, 0]
    white = [255, 255, 255]
    return np.asarray([black, white])
def decode_segmap(label_mask, n_classes, label_colours=None):
    """Decode segmentation class labels into a colour image.

    Args:
        label_mask (np.ndarray): (M, N) integer array of class labels.
        n_classes (int): labels 0 .. n_classes-1 are recoloured. Any other
            value is copied unchanged into all three channels.
        label_colours (np.ndarray, optional): (K, 3) palette with
            K >= n_classes; defaults to ``get_labels()``.

    Returns:
        np.ndarray: (M, N, 3) float64 image.
    """
    if label_colours is None:
        label_colours = get_labels()
    channels = [label_mask.copy() for _ in range(3)]
    for ll in range(n_classes):
        hit = label_mask == ll
        for c in range(3):
            channels[c][hit] = label_colours[ll, c]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    for c in range(3):
        rgb[:, :, c] = channels[c]
    return rgb
def mkdir(path):
    """Create directory ``path``, including missing parents, if it does not exist."""
    # makedirs handles nested save paths (os.mkdir did not); exist_ok avoids
    # the check-then-create race.
    os.makedirs(path, exist_ok=True)
def predict():
    """Predict edge maps for every ``.tif`` in ``cfg.data_path``.

    All ``.pth`` models in ``cfg.model_path`` are used and their outputs
    averaged. Each raster is tiled into ``temp_pic``, predicted tile by tile,
    stitched back together and written to
    ``cfg.save_path/<name>_predict.tif``.

    Raises:
        FileNotFoundError: if ``cfg.model_path`` contains no ``.pth`` files.
    """
    model_groups = glob.glob(cfg.model_path + '/*.pth')
    imgList = glob.glob(cfg.data_path + '/*.tif')
    if not model_groups:
        raise FileNotFoundError(f'no .pth models found in {cfg.model_path}')
    print('loading models')
    # Whole pickled models, as written by torch.save(self.model, ...) in training.
    models = {item: torch.load(item, map_location=device) for item in model_groups}
    mkdir(cfg.save_path)
    for img_path in tqdm.tqdm(imgList):
        name = os.path.splitext(os.path.basename(img_path))[0]
        # Start from an empty scratch dir, so tiles left by an earlier
        # (e.g. crashed) run cannot be mixed into this image's tiles.
        shutil.rmtree('temp_pic', ignore_errors=True)
        os.makedirs('temp_pic')
        try:
            input_and_output(img_path, models, generate_data=True)
            test_loader = get_dataset_loaders()
            mask_result = input_and_output(img_path, models, loader=test_loader,
                                           generate_data=False)
        finally:
            # Remove the tiles even if prediction fails.
            shutil.rmtree('temp_pic', ignore_errors=True)
        mask_result = mask_result.astype(np.uint8)
        cv2.imwrite(os.path.join(cfg.save_path, name + '_predict.tif'), mask_result)
if __name__ == '__main__':
    # Script entry point: batch-predict every raster in cfg.data_path.
    predict()
遥感影像切片
预测结果
原文链接:https://blog.csdn.net/weixin_42990464/article/details/107243620
最后
以上就是悲凉小蜜蜂为你收集整理的复现 Dense Extreme Inception Network(PyTorch)的全部内容,希望文章能够帮你解决复现 Dense Extreme Inception Network(PyTorch)时所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复