我是靠谱客的博主 文静睫毛膏,最近开发中收集的这篇文章主要介绍批量分割mask转json,觉得挺不错的,现在分享给大家,希望可以做个参考。

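概述

本文的 Python 脚本把语义分割得到的掩码图(mask)批量转换为 labelme 格式的 JSON 标注文件:按 label_names.yaml 中"类别名 → 像素值"的对应关系拆分掩码,提取每个区域的轮廓并近似为多边形,再连同原图的 base64 编码一起写入 JSON。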

代码

import os
import cv2
import sys
import PIL
import copy
import json
import yaml
import base64
import numpy as np
import skimage.io as io
from glob import glob
try:
    from labelme import __version__ as labelme_version
except ImportError:
    # labelme is optional: fall back to a fixed version string for the
    # JSON "version" field.
    labelme_version = '4.2.9'
# NOTE(review): presumably lets modules in the parent directory be imported;
# nothing visible in this file relies on it.
sys.path.append('..')
currentCV_version = cv2.__version__
def rm(filepath):
    """Patch a JSON file in place so a string ``"null"`` group id becomes JSON ``null``.

    Every occurrence of ``"group_id": "null",`` is rewritten to
    ``"group_id": null,``; the rest of the file is left untouched.

    :param filepath: path of the text file to rewrite in place.
    """
    with open(filepath, 'r+', encoding='utf-8') as fh:
        content = fh.read()
        fh.seek(0)
        # truncate() is required because the replacement shortens the text.
        fh.truncate()
        fh.write(content.replace('"group_id": "null",', '"group_id": null,'))
def imgEncode(img_or_path):
    """Base64-encode an image for labelme's ``imageData`` field.

    :param img_or_path: one of

        * ``np.ndarray`` -- pixel array; it is encoded as PNG first
          (logic adapted from labelme's image.py);
        * ``str`` -- path of an image file whose raw bytes are encoded;
        * a binary file object (e.g. from ``open(path, 'rb')`` or
          ``io.BytesIO``) -- its remaining bytes are read and encoded. The
          caller keeps ownership; the object is not closed.
    :return: the base64 text as ``str``.
    :raises TypeError: for any other input type.
    """
    # The module-level name ``io`` refers to ``skimage.io``; BytesIO and
    # BufferedIOBase live in the standard-library module.
    import io as _stdio

    if isinstance(img_or_path, np.ndarray):
        # ``import PIL`` alone does not load the Image submodule.
        from PIL import Image
        buf = _stdio.BytesIO()
        Image.fromarray(img_or_path).save(buf, format='PNG')
        raw = buf.getvalue()
    elif isinstance(img_or_path, str):
        with open(img_or_path, 'rb') as fh:
            raw = fh.read()
    elif isinstance(img_or_path, _stdio.BufferedIOBase):
        raw = img_or_path.read()
    else:
        raise TypeError('Input type error!')
    return base64.b64encode(raw).decode()
def rs(st: str):
    """Clean one line read from a class-list file.

    Removes newline characters and surrounding whitespace so that only the
    class name remains. Letters inside the name are never touched.
    """
    return st.replace('\n', '').strip()
def readYmal(filepath, labeledImg=None):
    """Load the (class name, mask pixel value) pairs used to split a label mask.

    Supported sources:

    * ``*.yaml`` -- the ``label_names`` mapping (class name -> pixel value).
    * ``*.txt`` -- one class name per line; line ``i`` (1-based) gets value ``i``.
    * ``filepath == ""`` together with ``labeledImg`` -- there is no class
      file. The image is binarised, and one placeholder class ``class<k>`` is
      generated per connected component, with values ``1..N``. Make sure the
      label image is correct, and rename the generated classes by hand.

    :param filepath: path of the ``.yaml``/``.txt`` class file, or ``""``.
    :param labeledImg: label image; only used when ``filepath`` is ``""``.
    :return: list of ``(name, value)`` tuples sorted by name for yaml/txt
        input; a ``zip`` iterator of ``(name, value)`` for the image case.
    :raises ValueError: if ``filepath`` exists but is neither ``.yaml`` nor ``.txt``.
    :raises FileExistsError: if ``filepath`` does not exist and the image
        fallback does not apply. ``FileNotFoundError`` would be the accurate
        type; the historical one is kept so existing ``except`` clauses work.
    """
    if os.path.exists(filepath):
        if filepath.endswith('.yaml'):
            with open(filepath) as f:
                # NOTE(review): FullLoader kept for compatibility; yaml.safe_load
                # would suffice for a plain name -> value mapping.
                y = yaml.load(f, Loader=yaml.FullLoader)
            return sorted(y['label_names'].items())
        if filepath.endswith('.txt'):
            with open(filepath, 'r', encoding='utf-8') as f:
                # strip() drops the trailing newline and surrounding whitespace.
                names = [line.strip() for line in f]
            return sorted(zip(names, range(1, len(names) + 1)))
        # Reject other extensions explicitly instead of returning None.
        raise ValueError('unsupported label file type: {}'.format(filepath))
    if labeledImg is not None and filepath == "":
        # np.array copies, so the caller's image is not modified.
        binary = np.array(labeledImg, dtype=np.uint8)
        binary[binary > 0] = 255
        _, components, _, _ = cv2.connectedComponentsWithStats(binary)
        # Component ids run 0..count-1; id 0 is the background and is skipped.
        count = int(np.max(components)) + 1
        values = list(range(1, count))
        names = ['class{}'.format(k) for k in range(len(values))]
        return zip(names, values)
    raise FileExistsError('file not found')
def get_approx(img, contour, length_p=0.005):
    """Approximate a contour with a simpler closed polygon.

    :param img: unused; kept so existing callers keep working.
    :param contour: contour point array as returned by ``cv2.findContours``.
    :param length_p: approximation tolerance as a fraction of the contour's
        closed perimeter; larger values give polygons with fewer vertices.
    :return: the polygon returned by ``cv2.approxPolyDP``.
    """
    # The tolerance scales with the contour's closed perimeter.
    epsilon = length_p * cv2.arcLength(contour, True)
    return cv2.approxPolyDP(contour, epsilon, True)
def getBinary(img_or_path, minConnectedArea=1):
    """Threshold an image at 127 and drop tiny connected components.

    :param img_or_path: image path (read with ``cv2.imread``) or an
        already-loaded array. 3-dimensional arrays are converted from BGR to
        grayscale first.
    :param minConnectedArea: 4-connected components with fewer pixels than
        this are removed. Components whose pixel count is below 0.01% of
        their bounding-box area are removed as well.
    :return: ``(image, binary)`` -- the input image as loaded, and a uint8
        mask (0/255) containing only the kept components.
    :raises TypeError: for unsupported input types.
    :raises FileNotFoundError: if the path cannot be read as an image.
    """
    if isinstance(img_or_path, str):
        img = cv2.imread(img_or_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image: {}'.format(img_or_path))
    elif isinstance(img_or_path, np.ndarray):
        img = img_or_path
    else:
        raise TypeError('Input type error')

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

    # labels: per-pixel component id (0 = background).
    # stats rows: x, y, width, height, area of each component.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=4)
    areas = stats[:, 4]
    box_areas = stats[:, 2] * stats[:, 3]
    keep = ~((areas < minConnectedArea) | (areas < 0.0001 * box_areas))
    keep[0] = False  # never keep the background component
    # One vectorised lookup replaces a full-image pass per component.
    mask = keep[labels]
    return img, np.where(mask, binary, 0).astype(np.uint8)
def getMultiRegion(img, img_bin, length_p=0.0001):
    """Extract a simplified polygon for every contour of a binary mask.

    Intended for multiple objects of the same class: each contour found in
    ``img_bin`` becomes its own polygon.

    :param img: passed through to ``get_approx``.
    :param img_bin: single-channel uint8 mask.
    :param length_p: approximation tolerance forwarded to ``get_approx``.
    :return: list of polygons with more than 3 vertices (possibly empty).
    """
    # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
    # (contours, hierarchy) on 2.x/4.x; contours is always second-to-last.
    contours = cv2.findContours(img_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # NOTE(review): RETR_LIST also returns hole outlines as separate polygons.
    regions = []
    # Every contour is used; the first one is not special (skipping index 0
    # used to drop the only polygon of single-object masks).
    for contour in contours:
        region = get_approx(img, contour, length_p)
        if region.shape[0] > 3:
            regions.append(region)
    return regions
def process(oriImg):
    """Turn a mask (path or array) into polygon regions.

    The mask is binarised by ``getBinary``, and the resulting
    ``(image, binary)`` pair is handed straight to ``getMultiRegion``.
    """
    return getMultiRegion(*getBinary(oriImg))
def _make_shape(label, region):
    """Build one labelme polygon shape dict from a polygon point array."""
    return {
        'label': label,
        # Each element of region holds one point at index 0.
        'points': [pt[0].tolist() for pt in region],
        'group_id': None,  # json.dumps writes this as null
        'shape_type': 'polygon',
        'flags': {},
    }


def getMultiShapes(oriImgPath, labelPath, savePath='', labelYamlPath='', flag=False):
    """Convert one label mask into a labelme-style JSON annotation.

    :param oriImgPath: path of the original image. Its file name becomes
        ``imagePath``, and its bytes are base64-encoded into ``imageData``.
    :param labelPath: label mask path, or an array-like mask. Examples are
        the output of an FCN/U-Net, or masks converted from labelme JSON
        files. A mask whose maximum exceeds 127 is treated as binary (0/255)
        and normalised to class value 1. An array argument is copied, never
        modified.
    :param savePath: directory the ``<image stem>.json`` file is written to.
    :param labelYamlPath: ``.yaml``/``.txt`` class file read by ``readYmal``.
        If it is empty, placeholder classes are derived from the mask and
        must be renamed by hand (the original note says labelme 4.2.9 errors
        when changing a label).
    :param flag: if true, return the JSON string instead of writing a file.
    :return: the JSON string when ``flag`` is true, otherwise ``None``.
    :raises FileNotFoundError: if ``labelPath`` is a path that does not exist.
    """
    if isinstance(labelPath, str):
        if not os.path.exists(labelPath):
            raise FileNotFoundError('mask/labeled image not found')
        label_img = io.imread(labelPath)
    else:
        # Copy so the thresholding below never mutates the caller's array.
        label_img = np.array(labelPath)
    if np.max(label_img) > 127:
        # Treat as a 0/255 binary mask and normalise it to 0/1.
        label_img[label_img > 127] = 255
        label_img[label_img != 255] = 0
        label_img = label_img / 255
    labels = readYmal(labelYamlPath, label_img)

    shapes = []
    for name, value in labels:
        if value > 0:
            mask = np.where(label_img.astype(np.uint8) == value, 255, 0).astype(np.uint8)
            region = process(mask)
            # A single array is one polygon; a list holds one polygon per object.
            regions = [region] if isinstance(region, np.ndarray) else list(region)
            shapes.extend(_make_shape(name, r) for r in regions)

    obj = {
        'version': labelme_version,
        'flags': {},
        'shapes': shapes,
        'imagePath': os.path.basename(oriImgPath),
        'imageData': str(imgEncode(oriImgPath)),
        'imageHeight': label_img.shape[0],
        'imageWidth': label_img.shape[1],
    }
    j = json.dumps(obj, sort_keys=True, indent=4)
    if flag:
        return j
    # splitext handles extensions of any length (e.g. ".jpeg").
    stem = os.path.splitext(obj['imagePath'])[0]
    saveJsonPath = os.path.join(savePath, stem + '.json')
    with open(saveJsonPath, 'w') as f:
        f.write(j)
if __name__ == "__main__":
    # Root folder of the dataset -- set this before running. It must contain
    # image/ (original *.png images), mask/ (*.png masks) and
    # label_names.yaml. JSON files are written to json/ (created if missing).
    path = ''
    init_path = os.path.join(path, 'image')
    mask_path = os.path.join(path, 'mask')
    yaml_file = os.path.join(path, 'label_names.yaml')
    save_json = os.path.join(path, 'json')
    os.makedirs(save_json, exist_ok=True)
    # glob() order is arbitrary, so pair each mask with the image that has
    # the same file name instead of zipping two independently ordered lists.
    images_by_name = {
        os.path.basename(p): p for p in glob(os.path.join(init_path, "*.png"))
    }
    for mask_image in sorted(glob(os.path.join(mask_path, "*.png"))):
        print(mask_image)
        init_image = images_by_name.get(os.path.basename(mask_image))
        if init_image is None:
            print('no matching image for {}, skipped'.format(mask_image))
            continue
        getMultiShapes(init_image, mask_image, save_json, yaml_file)

label_names.yaml 格式(脚本读取其中的 label_names 映射:键为类别名,值为该类别在掩码中的像素值)

label_names:
  Tag1: 1
  # 类别名: 掩码中该类别对应的像素值
  # ...

参考

https://github.com/guchengxi1994/mask2json

最后

以上就是文静睫毛膏为你收集整理的批量分割mask转json的全部内容,希望文章能够帮你解决批量分割mask转json所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(51)

评论列表共有 0 条评论

立即
投稿
返回
顶部