Fish Individual Identification Based on Fish Skin (2): Building the Dataset

Overview

This post explains how to process the data, load it, and build the train/validation/test datasets.

Importing the required tools

from __future__ import absolute_import, division, print_function, unicode_literals
import pathlib
import glob
import tensorflow as tf
import os
import random
import numpy as np
from scipy.special import binom

from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, models
from tensorflow.keras import regularizers
from PIL import Image
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical

Class LoadFishDataUtil

SPLIT_WEIGHTS=(0.7, 0.15, 0.15) sets the proportions of the train, validation, and test splits.
get_label(self, file_path) returns the label of a sample, derived from its parent directory name.
decode_img(self, img) decodes an image, converts it to floats in [0, 1], and resizes it.
The data directory is structured as follows:

(base) xingbo@xingbo-pc:~/Desktop/fish_identification/data/SESSION_AQUARIUM/SESSION1$ tree |less
.
├── 001
│   ├── fish_2_001_01.png
│   ├── fish_2_001_02.png
│   ├── fish_2_001_03.png
│   ├── fish_2_001_04.png

├── 002
│   ├── fish_2_002_01.png
│   ├── fish_2_002_02.png

......
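Given this layout, each subdirectory name (001, 002, ...) serves as a class label. The class-name discovery in __init__ below amounts to listing those subdirectories; here is a standalone sketch of that step (the path is an illustrative assumption, not a value from the original post):

import pathlib
import numpy as np

data_dir = pathlib.Path('data/SESSION_AQUARIUM/SESSION1')  # assumed path to the tree above
# Each subdirectory name becomes a class label; sort for a stable label order.
CLASS_NAMES = np.array(sorted(p.name for p in data_dir.glob('*') if p.is_dir()))
print(CLASS_NAMES)  # expected: ['001' '002' ...]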

class LoadFishDataUtil():
    def __init__(self, directory_str, BATCH_SIZE, IMG_SIZE, CLASS_NAMES=None, SPLIT_WEIGHTS=(0.7, 0.15, 0.15)):
      self.directory_str = directory_str
      self.SPLIT_WEIGHTS = SPLIT_WEIGHTS
      self.BATCH_SIZE = BATCH_SIZE
      self.IMG_SIZE = IMG_SIZE
      self.AUTOTUNE = tf.data.experimental.AUTOTUNE
      self.data_dir = pathlib.Path(directory_str)
      self.image_count = len(list(self.data_dir.glob('*/*.png')))
      if CLASS_NAMES is None:
        # Each subdirectory name is a class label; sort for a deterministic order.
        self.CLASS_NAMES = np.array(sorted(item.name for item in self.data_dir.glob('*') if item.is_dir()))
      else:
        self.CLASS_NAMES = CLASS_NAMES

      self.class_num = len(self.CLASS_NAMES)
      self.STEPS_PER_EPOCH = np.ceil(self.image_count / BATCH_SIZE)


    def get_label(self, file_path):
      # Convert the path to a list of path components.
      parts = tf.strings.split(file_path, os.sep)
      # The second-to-last component is the class directory; comparing it with
      # every class name yields a one-hot boolean label vector.
      return parts[-2] == self.CLASS_NAMES
    def decode_img(self, img):
      # Convert the compressed string to a 3D uint8 tensor. The files in this
      # dataset are PNGs, so decode_png is used rather than decode_jpeg.
      img = tf.image.decode_png(img, channels=3)
      # Use `convert_image_dtype` to convert to floats in the [0,1] range.
      img = tf.image.convert_image_dtype(img, tf.float32)
      # Resize the image to the desired size.
      return tf.image.resize(img, [self.IMG_SIZE, self.IMG_SIZE])

    def process_path(self, file_path):
      label = self.get_label(file_path)
      # Load the raw data from the file as a string.
      img = tf.io.read_file(file_path)
      img = self.decode_img(img)
      return img, label
 

    def prepare_for_training(self, ds, cache=True, shuffle_buffer_size=1000):
      # For a small dataset, load it once and keep it in memory. Pass a
      # filename via `.cache(filename)` to cache preprocessing work for
      # datasets that don't fit in memory.
      if cache:
        if isinstance(cache, str):
          ds = ds.cache(cache)
        else:
          ds = ds.cache()

      ds = ds.shuffle(buffer_size=shuffle_buffer_size)

      # Repeat forever.
      ds = ds.repeat()

      ds = ds.batch(self.BATCH_SIZE)

      # `prefetch` lets the dataset fetch batches in the background while the
      # model is training.
      ds = ds.prefetch(buffer_size=self.AUTOTUNE)

      return ds
  

    def loadFishData(self):
      list_ds = tf.data.Dataset.list_files(str(self.data_dir/'*/*'))
      # Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
      self.labeled_ds = list_ds.map(self.process_path, num_parallel_calls=self.AUTOTUNE)

      train_size = int(self.SPLIT_WEIGHTS[0] * self.image_count)
      val_size = int(self.SPLIT_WEIGHTS[1] * self.image_count)
      test_size = int(self.SPLIT_WEIGHTS[2] * self.image_count)

      # Split on individual samples *before* batching and repeating; splitting
      # after prepare_for_training would partition (repeated) batches instead,
      # letting the same image leak into several subsets.
      full_dataset = self.labeled_ds.shuffle(buffer_size=1000, reshuffle_each_iteration=False)
      train_dataset = full_dataset.take(train_size)
      remaining = full_dataset.skip(train_size)
      val_dataset = remaining.take(val_size)
      test_dataset = remaining.skip(val_size).take(test_size)

      train_dataset = self.prepare_for_training(train_dataset)
      val_dataset = val_dataset.batch(self.BATCH_SIZE)
      test_dataset = test_dataset.batch(self.BATCH_SIZE)
      # NOTE: STEPS_PER_EPOCH is computed over all images, not just the training split.
      return train_dataset, val_dataset, test_dataset, self.STEPS_PER_EPOCH, self.CLASS_NAMES, self.class_num
		
		
    def loadTestFishData(self):
      list_ds = tf.data.Dataset.list_files(str(self.data_dir/'*/*'))
      # Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
      self.labeled_ds = list_ds.map(self.process_path, num_parallel_calls=self.AUTOTUNE)

      test_ds = self.prepare_for_training(self.labeled_ds)

      return test_ds, self.class_num
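
A minimal end-to-end sketch of driving the loader (the paths, BATCH_SIZE, and IMG_SIZE below are illustrative assumptions, not values from the original post):

BATCH_SIZE = 32   # assumed
IMG_SIZE = 160    # assumed

util = LoadFishDataUtil('data/SESSION_AQUARIUM/SESSION1', BATCH_SIZE, IMG_SIZE)
train_ds, val_ds, test_ds, steps_per_epoch, class_names, class_num = util.loadFishData()

# Inspect one batch: images are float32 in [0, 1], labels are one-hot booleans.
for images, labels in train_ds.take(1):
    print(images.shape)  # (BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3)
    print(labels.shape)  # (BATCH_SIZE, class_num)

# For a held-out session, reuse the training class order so labels stay aligned
# (the SESSION2 path is an assumption).
test_util = LoadFishDataUtil('data/SESSION_AQUARIUM/SESSION2', BATCH_SIZE, IMG_SIZE, CLASS_NAMES=class_names)
session2_ds, _ = test_util.loadTestFishData()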
		
		
