我是靠谱客的博主 繁荣砖头,最近开发中收集的这篇文章主要介绍Tensorflow简单项目实战——cats_vs_dogs一、数据集准备二、训练三、全部代码,觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
目录
一、数据集准备
二、训练
三、全部代码
一、数据集准备
采用Tensorflow官方数据集cats_vs_dogs
相比于mnist,图像变为RGB格式,并且数据集中的图片尺寸不统一,因此需要在训练前修改尺寸大小。另外它没有测试集,所以采用训练集进行不同的划分,分别用于训练和测试
载入数据集依然采用Tensorflow-datasets进行
import tensorflow_datasets as tfds
#载入数据集,猫0狗1 23,262
def load_datasets(visualization = False):
#加载数据集,利用as_supervised以二元组形式返回,
(cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info = tfds.load(name='cats_vs_dogs', split=['train[:80%]', 'train[80%:]'], as_supervised=True, with_info=True)
if visualization:
tfds.show_examples(cats_vs_dogs_train,cats_vs_dogs_info)
return (cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info
函数的visualization参数是为了是否显示图片,方便查看数据集图片的样子。0~80%的train用于训练,80%~100%的train用于测试.
载入数据集
# 载入数据集
(datasets_train, datasets_test), datasets_info = load_datasets(visualization=False)
由于数据集图片的尺寸不一,因此需要统一修改尺寸,利用tf.image.resize函数
tf.image.resize(
images, size, method=ResizeMethod.BILINEAR, preserve_aspect_ratio=False,
antialias=False, name=None
)
利用tf.data.map来实现数据集操作
def re_shape(image, label):
image = tf.image.resize(image,[64,64])
return image,label
# 规定图片大小
datasets_train = datasets_train.map(map_func=re_shape)
datasets_test = datasets_test.map(map_func=re_shape)
并进行归一化
def uint8tofloat32(image, label):
return tf.cast(image, tf.float32) / 255, 0, label
# 进行格式转化,cats_vs_dogs格式为uint8,要将其转化为float32
datasets_train = datasets_train.map(map_func=uint8tofloat32)
datasets_test = datasets_test.map(map_func=uint8tofloat32)
其实这一步可以不用tf.cast实现格式转换,因为在上一步修改尺寸的时候,如果method只要不是NEAREST,都可以将数据转化为float32型,不过多写一遍也没事
然后就是缓存->打乱(仅训练集需要)->batch->prefetch
# 缓存到内存中
datasets_train = datasets_train.cache()
datasets_test = datasets_test.cache()
# 打乱数据集,仅训练集
datasets_train = datasets_train.shuffle(buffer_size=datasets_info.splits['train'].num_examples, seed=1)
# batch
datasets_train = datasets_train.batch(batch_size=128)
datasets_test = datasets_test.batch(batch_size=128)
# prefetch
datasets_train = datasets_train.prefetch(tf.data.experimental.AUTOTUNE)
datasets_test = datasets_test.prefetch(tf.data.experimental.AUTOTUNE)
二、训练
定义网络模型,采用继承tf.keras.Model方式构建
class Mymodel(tf.keras.Model):
def __init__(self):
super().__init__()
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(units=16, activation='relu')
self.dense2 = tf.keras.layers.Dense(units=2, activation='sigmoid')
def call(self, inputs):
x = self.flatten(inputs)
x = self.dense1(x)
output = self.dense2(x)
return output
实例化后,利用compile配置,并使用fit进行训练,最后打印网络信息
# 实例化网络
model = Mymodel()
model.compile(optimizer='Adam',loss='mse',metrics=['accuracy'])
model.fit(datasets_train, epochs=10, validation_data=datasets_test)
model.summary()
三、全部代码
datasets.py 文件
import tensorflow_datasets as tfds
#载入数据集,猫0狗1 23,262
def load_datasets(visualization = False):
#加载数据集,利用as_supervised以二元组形式返回,
(cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info = tfds.load(name='cats_vs_dogs', split=['train[:80%]', 'train[80%:]'], as_supervised=True, with_info=True)
if visualization:
tfds.show_examples(cats_vs_dogs_train,cats_vs_dogs_info)
return (cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info
Model.py
import tensorflow as tf
class Mymodel(tf.keras.Model):
def __init__(self):
super().__init__()
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(units=16, activation='relu')
self.dense2 = tf.keras.layers.Dense(units=2, activation='sigmoid')
def call(self, inputs):
x = self.flatten(inputs)
x = self.dense1(x)
output = self.dense2(x)
return output
main.py
import tensorflow as tf
from datasets import load_datasets
import numpy as np
import matplotlib.pyplot as plt
from Model import Mymodel
# 指定随机
tf.random.set_seed(1)
np.random.seed(1)
def uint8tofloat32(image, label):
return tf.cast(image, tf.float32) / 255, 0, label
def re_shape(image, label):
image = tf.image.resize(image,[64,64])
return image,label
# GPU使用
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
# 载入数据集
(datasets_train, datasets_test), datasets_info = load_datasets(visualization=False)
# 规定图片大小
datasets_train = datasets_train.map(map_func=re_shape)
datasets_test = datasets_test.map(map_func=re_shape)
# 进行格式转化,cats_vs_dogs格式为uint8,要将其转化为float32
datasets_train = datasets_train.map(map_func=uint8tofloat32)
datasets_test = datasets_test.map(map_func=uint8tofloat32)
# 缓存到内存中
datasets_train = datasets_train.cache()
datasets_test = datasets_test.cache()
# 打乱数据集,仅训练集
datasets_train = datasets_train.shuffle(buffer_size=datasets_info.splits['train'].num_examples, seed=1)
# batch
datasets_train = datasets_train.batch(batch_size=128)
datasets_test = datasets_test.batch(batch_size=128)
# prefetch
datasets_train = datasets_train.prefetch(tf.data.experimental.AUTOTUNE)
datasets_test = datasets_test.prefetch(tf.data.experimental.AUTOTUNE)
# 实例化网络
model = Mymodel()
model.compile(optimizer='Adam',loss='mse',metrics=['accuracy'])
model.fit(datasets_train, epochs=10, validation_data=datasets_test)
model.summary()
最后
以上就是繁荣砖头为你收集整理的Tensorflow简单项目实战——cats_vs_dogs一、数据集准备二、训练三、全部代码的全部内容,希望文章能够帮你解决Tensorflow简单项目实战——cats_vs_dogs一、数据集准备二、训练三、全部代码所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复