我是靠谱客的博主等待鸡翅。最近在开发中收集了这篇文章，主要介绍 BSD500 数据集的处理（下载、裁剪与预处理），觉得挺不错，现在分享给大家，希望可以做个参考。注意：下面代码中的 download_bsd300 函数实际下载并解压的是 BSDS300 图像包。

概述

from os.path import exists, join, basename
from os import makedirs, remove
from six.moves import urllib
import tarfile
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize

from dataset import DatasetFromFolder


def download_bsd300(dest="dataset"):
    """Download and extract the BSDS300 image archive below *dest*.

    Resolution order:

    1. If ``<dest>/BSD500/images`` already exists (for example a manually
       prepared BSD500 copy), return it without any network access.
    2. Otherwise, if ``<dest>/BSDS300/images`` exists from an earlier
       download, return that.
    3. Otherwise, download the BSDS300 archive, extract it under *dest* and
       return ``<dest>/BSDS300/images``.

    Args:
        dest: Directory that receives the archive and the extracted files.
            It is created, including parents, if it does not exist yet.

    Returns:
        Path of an image directory that exists on disk.

    Raises:
        ValueError: If an archive member would be written outside *dest*.
    """
    import shutil
    from contextlib import closing
    from os.path import abspath, sep

    output_image_dir = join(dest, "BSD500/images")
    if exists(output_image_dir):
        return output_image_dir

    # The BSDS300 archive unpacks into BSDS300/, so after downloading it the
    # BSD500 path above would still not exist; return the real location.
    extracted_image_dir = join(dest, "BSDS300/images")
    if exists(extracted_image_dir):
        return extracted_image_dir

    # exist_ok: dest may already exist (e.g. holding other data) even though
    # neither images directory does.
    makedirs(dest, exist_ok=True)
    url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
    print("downloading url ", url)

    file_path = join(dest, basename(url))
    try:
        # Stream to disk instead of buffering the whole archive in memory;
        # closing() guarantees the HTTP response is released.
        with closing(urllib.request.urlopen(url)) as data, open(file_path, 'wb') as f:
            shutil.copyfileobj(data, f)

        print("Extracting data")
        root = abspath(dest)
        # Python versions that ship tar extraction filters (3.12+ and security
        # backports) also get the "data" filter, which rejects unsafe links.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(file_path) as tar:
            for item in tar:
                # Reject path-traversal members ("../x", absolute names):
                # the archive is fetched over plain HTTP.
                target = abspath(join(root, item.name))
                if target != root and not target.startswith(root + sep):
                    raise ValueError(
                        "refusing to extract %r outside %r" % (item.name, dest))
                tar.extract(item, dest, **extract_kwargs)
    finally:
        # Remove the archive even when the download or extraction fails, so
        # no partial file is left behind.
        if exists(file_path):
            remove(file_path)

    return extracted_image_dir


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round *crop_size* down to the nearest multiple of *upscale_factor*.

    With the result, ``crop_size // upscale_factor`` (the low-resolution
    size used by ``input_transform``) divides the crop exactly.
    """
    _, leftover = divmod(crop_size, upscale_factor)
    return crop_size - leftover


def input_transform(crop_size, upscale_factor):
    """Build the low-resolution input pipeline.

    Steps: center-crop to ``crop_size``, resize to
    ``crop_size // upscale_factor``, then convert to a tensor.
    """
    low_res_size = crop_size // upscale_factor
    steps = [
        CenterCrop(crop_size),
        Resize(low_res_size),
        ToTensor(),
    ]
    return Compose(steps)


def target_transform(crop_size):
    """Build the full-resolution target pipeline.

    Steps: center-crop to ``crop_size``, then convert to a tensor.
    """
    steps = [CenterCrop(crop_size), ToTensor()]
    return Compose(steps)


def get_training_set(data_dir, upscale_factor, crop_size=256):
    """Return the super-resolution training dataset under ``<data_dir>/train``.

    Args:
        data_dir: Dataset root containing a ``train`` sub-directory. The
            original code hints that this can be the path returned by
            ``download_bsd300()``.
        upscale_factor: Super-resolution scale factor. The crop is rounded
            down to a multiple of it.
        crop_size: Requested crop size before rounding. Defaults to 256, the
            value previously hard-coded here.

    Returns:
        A ``DatasetFromFolder`` over the ``train`` directory. Inputs are
        center-cropped and downscaled by ``upscale_factor``; targets are
        center-cropped at full size.
    """
    train_dir = join(data_dir, "train")
    valid_crop_size = calculate_valid_crop_size(crop_size, upscale_factor)

    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(valid_crop_size, upscale_factor),
                             target_transform=target_transform(valid_crop_size))

def get_test_set(data_dir, upscale_factor, crop_size=256):
    """Return the super-resolution test dataset under ``<data_dir>/test``.

    Args:
        data_dir: Dataset root containing a ``test`` sub-directory. The
            original code hints that this can be the path returned by
            ``download_bsd300()``.
        upscale_factor: Super-resolution scale factor. The crop is rounded
            down to a multiple of it.
        crop_size: Requested crop size before rounding. Defaults to 256, the
            value previously hard-coded here.

    Returns:
        A ``DatasetFromFolder`` over the ``test`` directory. Inputs are
        center-cropped and downscaled by ``upscale_factor``; targets are
        center-cropped at full size.
    """
    test_dir = join(data_dir, "test")
    valid_crop_size = calculate_valid_crop_size(crop_size, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(valid_crop_size, upscale_factor),
                             target_transform=target_transform(valid_crop_size))

 

最后

以上就是等待鸡翅为你收集整理的 BSD500 数据集处理的全部内容，希望这篇文章能够帮你解决在处理 BSD500 数据集时遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错，欢迎将靠谱客网站推荐给程序员好友。

本图文内容由网友提供或来自网络收集整理，仅作为学习参考使用，版权属于原作者所有。
点赞(71)

评论列表共有 0 条评论

立即投稿
返回顶部