Fish Individual Identification Based on Fish Skin (2): Building the Dataset
Overview
This article covers how to process the data, load it, and build the train/validation/test datasets.
Importing the tools
from __future__ import absolute_import, division, print_function, unicode_literals
import pathlib
import glob
import tensorflow as tf
import os
import random
import numpy as np
from scipy.special import binom
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, models
from tensorflow.keras import regularizers
from PIL import Image
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
Class LoadFishDataUtil
SPLIT_WEIGHTS=(0.7, 0.15, 0.15) sets the proportions of the train, validation, and test sets.
get_label() returns the label of a sample (a boolean one-hot vector; see the sketch below).
decode_img(self, img) decodes an image and preprocesses it.
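To illustrate the label encoding: comparing a file's parent-directory name against CLASS_NAMES yields a boolean one-hot vector. A minimal sketch, assuming three hypothetical classes named after the folders shown below:

import numpy as np
import tensorflow as tf

# Hypothetical class list mirroring the folder names below.
CLASS_NAMES = np.array(['001', '002', '003'])

parts = tf.strings.split('SESSION1/002/fish_2_002_01.png', '/')
label = parts[-2] == CLASS_NAMES  # elementwise string comparison
print(label)  # tf.Tensor([False  True False], shape=(3,), dtype=bool)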
Data folder structure:
(base) xingbo@xingbo-pc:~/Desktop/fish_identification/data/SESSION_AQUARIUM/SESSION1$ tree |less
.
├── 001
│ ├── fish_2_001_01.png
│ ├── fish_2_001_02.png
│ ├── fish_2_001_03.png
│ ├── fish_2_001_04.png
├── 002
│ ├── fish_2_002_01.png
│ ├── fish_2_002_02.png
......
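Given this layout, the class names can be read directly from the first-level directory names, and the image count from a glob; this is exactly what the constructor below does. A minimal standalone sketch (the path is a placeholder):

import pathlib
import numpy as np

data_dir = pathlib.Path('data/SESSION_AQUARIUM/SESSION1')  # placeholder path
image_count = len(list(data_dir.glob('*/*.png')))          # every PNG one level down
CLASS_NAMES = np.array([d.name for d in data_dir.glob('*') if d.is_dir()])
print(image_count, CLASS_NAMES)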
class LoadFishDataUtil():
    def __init__(self, directory_str, BATCH_SIZE, IMG_SIZE, CLASS_NAMES=None, SPLIT_WEIGHTS=(0.7, 0.15, 0.15)):
        self.directory_str = directory_str
        self.SPLIT_WEIGHTS = SPLIT_WEIGHTS
        self.BATCH_SIZE = BATCH_SIZE
        self.IMG_SIZE = IMG_SIZE
        self.data_dir = pathlib.Path(directory_str)
        self.image_count = len(list(self.data_dir.glob('*/*.png')))
        if CLASS_NAMES is None:
            # Infer class names from the first-level directory names.
            self.CLASS_NAMES = np.array([item.name for item in self.data_dir.glob('*') if item.name != "LICENSE.txt"])
        else:
            self.CLASS_NAMES = CLASS_NAMES
        self.class_num = len(self.CLASS_NAMES)
        self.IMG_HEIGHT = IMG_SIZE
        self.IMG_WIDTH = IMG_SIZE
        self.STEPS_PER_EPOCH = np.ceil(self.image_count / BATCH_SIZE)
    def get_label(self, file_path):
        # Convert the path to a list of path components.
        parts = tf.strings.split(file_path, '/')
        # The second-to-last component is the class directory; comparing it
        # against CLASS_NAMES yields a boolean one-hot label vector.
        return parts[-2] == self.CLASS_NAMES
    def decode_img(self, img):
        # Convert the compressed string to a 3D uint8 tensor.
        # The dataset is PNG, so decode_png is used (decode_jpeg would be wrong here).
        img = tf.image.decode_png(img, channels=3)
        # Use `convert_image_dtype` to convert to floats in the [0, 1] range.
        img = tf.image.convert_image_dtype(img, tf.float32)
        # img = (img/127.5) - 1  # alternative scaling to [-1, 1] for raw uint8 input
        # Resize the image to the desired size.
        return tf.image.resize(img, [self.IMG_SIZE, self.IMG_SIZE])
    def process_path(self, file_path):
        label = self.get_label(file_path)
        # Load the raw data from the file as a string.
        img = tf.io.read_file(file_path)
        img = self.decode_img(img)
        return img, label
    def prepare_for_training(self, ds, cache=True, shuffle_buffer_size=1000):
        # This is a small dataset; load it once and keep it in memory.
        # Use `.cache(filename)` to cache preprocessing work for datasets
        # that don't fit in memory.
        if cache:
            if isinstance(cache, str):
                ds = ds.cache(cache)
            else:
                ds = ds.cache()
        ds = ds.shuffle(buffer_size=shuffle_buffer_size)
        # Repeat forever.
        ds = ds.repeat()
        ds = ds.batch(self.BATCH_SIZE)
        # `prefetch` lets the dataset fetch batches in the background
        # while the model is training.
        ds = ds.prefetch(buffer_size=self.AUTOTUNE)
        return ds
    def loadFishData(self):
        list_ds = tf.data.Dataset.list_files(str(self.data_dir/'*/*'))
        # Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
        self.AUTOTUNE = tf.data.experimental.AUTOTUNE
        self.labeled_ds = list_ds.map(self.process_path, num_parallel_calls=self.AUTOTUNE)
        train_size = int(self.SPLIT_WEIGHTS[0] * self.image_count)
        val_size = int(self.SPLIT_WEIGHTS[1] * self.image_count)
        test_size = int(self.SPLIT_WEIGHTS[2] * self.image_count)
        # Shuffle once, without reshuffling between iterations, so that the
        # take/skip split below is deterministic and the three sets don't overlap.
        full_dataset = self.labeled_ds.shuffle(buffer_size=1000, reshuffle_each_iteration=False)
        # Split *before* batching/repeating; only the training set goes through
        # prepare_for_training (cache/shuffle/repeat/batch/prefetch).
        train_dataset = self.prepare_for_training(full_dataset.take(train_size))
        remaining = full_dataset.skip(train_size)
        val_dataset = remaining.take(val_size).batch(self.BATCH_SIZE)
        test_dataset = remaining.skip(val_size).take(test_size).batch(self.BATCH_SIZE)
        return train_dataset, val_dataset, test_dataset, self.STEPS_PER_EPOCH, self.CLASS_NAMES, self.class_num
    def loadTestFishData(self):
        list_ds = tf.data.Dataset.list_files(str(self.data_dir/'*/*'))
        # Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
        self.AUTOTUNE = tf.data.experimental.AUTOTUNE
        self.labeled_ds = list_ds.map(self.process_path, num_parallel_calls=self.AUTOTUNE)
        # Note: prepare_for_training repeats indefinitely, so consume this
        # dataset with an explicit step count.
        test_ds = self.prepare_for_training(self.labeled_ds)
        return test_ds, self.class_num
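Finally, a minimal usage sketch. The session paths, batch size, and image size are placeholders, and SESSION2 stands in for a hypothetical second recording session:

# Hypothetical paths and hyperparameters, for illustration only.
loader = LoadFishDataUtil('data/SESSION_AQUARIUM/SESSION1', BATCH_SIZE=32, IMG_SIZE=224)
train_ds, val_ds, test_ds, steps_per_epoch, class_names, class_num = loader.loadFishData()

# Inspect one training batch: images are float32 in [0, 1],
# labels are boolean one-hot vectors of length class_num.
for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)  # (32, 224, 224, 3) (32, class_num)

# For a held-out session, pass in the training CLASS_NAMES so that
# label indices stay consistent across sessions.
test_loader = LoadFishDataUtil('data/SESSION_AQUARIUM/SESSION2', 32, 224, CLASS_NAMES=class_names)
session_test_ds, _ = test_loader.loadTestFishData()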