概述
本文实现了一个自定义的 Generator, 从文件夹中读取图片, 然后进行45 * randn(8) 角度的旋转
def custom_generator(shuffle):
import skimage.io
import skimage.transform
import numpy as np
from pandas import DataFrame
import os
from tensorflow.python.keras.utils.data_utils import Sequence
class MySequence(Sequence):
def __init__(
self,
directory,
batch_size=32,
shuffle=True
):
self.batch_size = batch_size
categories = list(os.listdir(directory)) # 查看目录下有那些文件夹, 作为分类名
indexes = [] # 保存索引(文件路径)
labels = [] # 保存标签
for category in categories:
files = [os.sep.join([directory, category, f]) for f in os.listdir(directory + os.sep + category)]
indexes += files
labels += [category] * len(files)
self.df = DataFrame({"label": labels}, index=indexes)
self.indices = self.df.index.tolist() # 将 DataFrame 中的文件路径索引导出, 后面要用
self.shuffle = shuffle
self.num_classes = len(categories)
self.on_epoch_end() # 每个 epoch 结束都需要重新打乱数据
self.category_to_id = {c: i for i, c in enumerate(categories)} # 保存类别和类别ID的映射关系
def on_epoch_end(self):
self.index = np.arange(len(self.indices))
if self.shuffle:
np.random.shuffle(self.index)
def __len__(self):
return len(self.indices) // self.batch_size
def __getitem__(self, index):
"""生成一批数据"""
# 选择对应的一批数据的索引(数字)
index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
# 找到上述索引对应的文件路径
batch = [self.indices[k] for k in index] # 取出文件名
# 按照文件路径生成数据
X, y = self.__get_data(batch)
return X, y
def __get_data(self, batch):
# 按照文件路径生成数据
X = batch # 文件路径
y = [self.df.at[b, 'label'] for b in batch] # 对应的标签
# 从文件读取图像并进行变换
for i, file_name in enumerate(X):
im = skimage.io.imread(file_name, as_gray=True) # 读取图像, 如果是彩色图像要把 as_gray 去掉
# 通过四周填充黑色, 调整图片大小
im = skimage.transform.resize(im, output_shape=(200, 200), mode='constant', cval=0)
# 旋转 45 度的倍数
im = skimage.transform.rotate(im, np.random.randint(8) * 45)
X[i] = im # 将图像保存到 X 中, scikit-image 对图像的操作结果都是 numpy 矩阵
X = np.array(X)
y = [self.category_to_id[t] for t in y] # 将标签转换成序号
# 下面选择是用序号还是转换成 one-hot 形式
y = np.array(y) # 直接用序号表示分类, loss 要用 SparseCategoricalCrossentropy
# y = np.eye(self.num_classes)[y] # 转换成 one-hot 形式, loss 要用 CategoricalCrossentropy
return X, y
return MySequence(directory=r'F:CompetionprojectPETdataimagestrain_crop', batch_size=32, shuffle=shuffle)
train_crop 文件夹下面有两个文件夹, CN 和 AD, 代表图像的分类, 文件夹下面是具体的图片.
下面是用一个只有全连接层的模型进行测试. 模型非常简单, 只有一个全连接层, 主要目的是检查自定义的 Generator 能不能使用.
def simple_model():
from tensorflow import keras
from tensorflow.keras import models
from tensorflow.keras import layers
model = models.Sequential()
model.add(layers.Flatten(input_shape=[200,200]))
model.add(layers.Dense(2))
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.SGD(learning_rate=1e-3),
metrics=['accuracy'])
model.summary()
history = model.fit(custom_generator(shuffle=True), epochs=10, validation_data=custom_generator(shuffle=False))
最后
以上就是如意蜜蜂为你收集整理的TensorFlow 2 自定义生成图像的 Generator的全部内容,希望文章能够帮你解决TensorFlow 2 自定义生成图像的 Generator所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复