我是靠谱客的博主 繁荣砖头,最近开发中收集的这篇文章主要介绍Tensorflow简单项目实战——cats_vs_dogs一、数据集准备二、训练三、全部代码,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

目录

 

一、数据集准备

二、训练

三、全部代码


一、数据集准备

 采用Tensorflow官方数据集cats_vs_dogs

相比于mnist,图像变为RGB格式,并且数据集中的图片尺寸不统一,因此需要在训练前修改尺寸大小。另外它没有测试集,所以采用训练集进行不同的划分,分别用于训练和测试

载入数据集依然采用Tensorflow-datasets进行

import tensorflow_datasets as tfds

#载入数据集,猫0狗1 23,262
def load_datasets(visualization = False):
    #加载数据集,利用as_supervised以二元组形式返回,
    (cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info = tfds.load(name='cats_vs_dogs', split=['train[:80%]', 'train[80%:]'], as_supervised=True, with_info=True)
    if visualization:
        tfds.show_examples(cats_vs_dogs_train,cats_vs_dogs_info)

    return (cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info

函数的visualization参数是为了是否显示图片,方便查看数据集图片的样子。0~80%的train用于训练,80%~100%的train用于测试.

载入数据集

# 载入数据集
(datasets_train, datasets_test), datasets_info = load_datasets(visualization=False)

由于数据集图片的尺寸不一,因此需要统一修改尺寸,利用tf.image.resize函数

tf.image.resize(
    images, size, method=ResizeMethod.BILINEAR, preserve_aspect_ratio=False,
    antialias=False, name=None
)

利用tf.data.map来实现数据集操作

def re_shape(image, label):
    image = tf.image.resize(image,[64,64])
    return image,label

# 规定图片大小
datasets_train = datasets_train.map(map_func=re_shape)
datasets_test = datasets_test.map(map_func=re_shape)

并进行归一化

def uint8tofloat32(image, label):
    return tf.cast(image, tf.float32) / 255, 0, label

# 进行格式转化,cats_vs_dogs格式为uint8,要将其转化为float32
datasets_train = datasets_train.map(map_func=uint8tofloat32)
datasets_test = datasets_test.map(map_func=uint8tofloat32)

其实这一步可以不用tf.cast实现格式转换,因为在上一步修改尺寸的时候,如果method只要不是NEAREST,都可以将数据转化为float32型,不过多写一遍也没事

然后就是缓存->打乱(仅训练集需要)->batch->prefetch

# 缓存到内存中
datasets_train = datasets_train.cache()
datasets_test = datasets_test.cache()

# 打乱数据集,仅训练集
datasets_train = datasets_train.shuffle(buffer_size=datasets_info.splits['train'].num_examples, seed=1)

# batch
datasets_train = datasets_train.batch(batch_size=128)
datasets_test = datasets_test.batch(batch_size=128)

# prefetch
datasets_train = datasets_train.prefetch(tf.data.experimental.AUTOTUNE)
datasets_test = datasets_test.prefetch(tf.data.experimental.AUTOTUNE)

二、训练

定义网络模型,采用继承tf.keras.Model方式构建

class Mymodel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=16, activation='relu')
        self.dense2 = tf.keras.layers.Dense(units=2, activation='sigmoid')

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense1(x)
        output = self.dense2(x)
        return output

实例化后,利用compile配置,并使用fit进行训练,最后打印网络信息

# 实例化网络
model = Mymodel()

model.compile(optimizer='Adam',loss='mse',metrics=['accuracy'])
model.fit(datasets_train, epochs=10, validation_data=datasets_test)
model.summary()

 

三、全部代码

datasets.py 文件

import tensorflow_datasets as tfds

#载入数据集,猫0狗1 23,262
def load_datasets(visualization = False):
    #加载数据集,利用as_supervised以二元组形式返回,
    (cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info = tfds.load(name='cats_vs_dogs', split=['train[:80%]', 'train[80%:]'], as_supervised=True, with_info=True)
    if visualization:
        tfds.show_examples(cats_vs_dogs_train,cats_vs_dogs_info)

    return (cats_vs_dogs_train, cats_vs_dogs_test), cats_vs_dogs_info


Model.py

import tensorflow as tf


class Mymodel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=16, activation='relu')
        self.dense2 = tf.keras.layers.Dense(units=2, activation='sigmoid')

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense1(x)
        output = self.dense2(x)
        return output

main.py

import tensorflow as tf
from datasets import load_datasets
import numpy as np
import matplotlib.pyplot as plt
from Model import Mymodel

# 指定随机
tf.random.set_seed(1)
np.random.seed(1)


def uint8tofloat32(image, label):
    return tf.cast(image, tf.float32) / 255, 0, label

def re_shape(image, label):
    image = tf.image.resize(image,[64,64])
    return image,label
# GPU使用
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# 载入数据集
(datasets_train, datasets_test), datasets_info = load_datasets(visualization=False)

# 规定图片大小
datasets_train = datasets_train.map(map_func=re_shape)
datasets_test = datasets_test.map(map_func=re_shape)

# 进行格式转化,cats_vs_dogs格式为uint8,要将其转化为float32
datasets_train = datasets_train.map(map_func=uint8tofloat32)
datasets_test = datasets_test.map(map_func=uint8tofloat32)

# 缓存到内存中
datasets_train = datasets_train.cache()
datasets_test = datasets_test.cache()

# 打乱数据集,仅训练集
datasets_train = datasets_train.shuffle(buffer_size=datasets_info.splits['train'].num_examples, seed=1)

# batch
datasets_train = datasets_train.batch(batch_size=128)
datasets_test = datasets_test.batch(batch_size=128)

# prefetch
datasets_train = datasets_train.prefetch(tf.data.experimental.AUTOTUNE)
datasets_test = datasets_test.prefetch(tf.data.experimental.AUTOTUNE)

# 实例化网络
model = Mymodel()

model.compile(optimizer='Adam',loss='mse',metrics=['accuracy'])
model.fit(datasets_train, epochs=10, validation_data=datasets_test)
model.summary()

 

最后

以上就是繁荣砖头为你收集整理的Tensorflow简单项目实战——cats_vs_dogs一、数据集准备二、训练三、全部代码的全部内容,希望文章能够帮你解决Tensorflow简单项目实战——cats_vs_dogs一、数据集准备二、训练三、全部代码所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(39)

评论列表共有 0 条评论

立即
投稿
返回
顶部