我是靠谱客的博主 还单身黑米,最近在开发中收集了这篇文章,主要介绍如何用 TensorFlow 2.2 的 ResNet50 实现花的识别(内容包括:ResNet50 介绍、代码演示、主函数、预测图片),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

Resnet50介绍

ResNet50 的整体结构与之前介绍的 ResNet34 几乎一样,唯一的区别是:
残差块由两层卷积变成了三层卷积(1×1 → 3×3 → 1×1 的瓶颈结构,最后一层的输出通道数为 filters×4),网络更深,如下:

# Residual building block
def block(x, filters, strides=1, conv_short=True):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        x: Input tensor.
        filters: Channel count of the first two convolutions; the last
            convolution and the projection shortcut output ``filters * 4``.
        strides: Stride of the first 1x1 convolution and of the projection
            shortcut.
        conv_short: If True, the shortcut is a 1x1 convolution followed by
            batch normalization; otherwise the input is added unchanged.

    Returns:
        The tensor produced by adding both branches and applying ReLU.
    """
    bn_eps = 1.001e-5
    wide = filters * 4

    # Shortcut branch is built first so layers are created in the same order
    # as before.
    short_cut = x
    if conv_short:
        short_cut = Conv2D(filters=wide, kernel_size=1, strides=strides, padding='valid')(short_cut)
        short_cut = BatchNormalization(epsilon=bn_eps)(short_cut)

    # Main branch: three convolutions, each followed by batch normalization.
    main = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
    main = BatchNormalization(epsilon=bn_eps)(main)
    main = Activation('relu')(main)

    main = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(main)
    main = BatchNormalization(epsilon=bn_eps)(main)
    main = Activation('relu')(main)

    main = Conv2D(filters=wide, kernel_size=1, strides=1, padding='valid')(main)
    main = BatchNormalization(epsilon=bn_eps)(main)

    merged = Add()([main, short_cut])
    return Activation('relu')(merged)

直接进入代码演示

1. 代码演示

新建train.py

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, ZeroPadding2D, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

# Residual building block
def block(x, filters, strides=1, conv_short=True):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        x: Input tensor.
        filters: Channel count of the first two convolutions; the last
            convolution and the projection shortcut output ``filters * 4``.
        strides: Stride of the first 1x1 convolution and of the projection
            shortcut.
        conv_short: If True, the shortcut is a 1x1 convolution followed by
            batch normalization; otherwise the input is added unchanged.

    Returns:
        The tensor produced by adding both branches and applying ReLU.
    """
    bn_eps = 1.001e-5
    wide = filters * 4

    # Shortcut branch is built first so layers are created in the same order
    # as before (NOTE(review): auto-generated layer names presumably depend on
    # creation order, which matters for load_weights(by_name=True) - confirm).
    short_cut = x
    if conv_short:
        short_cut = Conv2D(filters=wide, kernel_size=1, strides=strides, padding='valid')(short_cut)
        short_cut = BatchNormalization(epsilon=bn_eps)(short_cut)

    # Main branch: three convolutions, each followed by batch normalization.
    main = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
    main = BatchNormalization(epsilon=bn_eps)(main)
    main = Activation('relu')(main)

    main = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(main)
    main = BatchNormalization(epsilon=bn_eps)(main)
    main = Activation('relu')(main)

    main = Conv2D(filters=wide, kernel_size=1, strides=1, padding='valid')(main)
    main = BatchNormalization(epsilon=bn_eps)(main)

    merged = Add()([main, short_cut])
    return Activation('relu')(merged)

    
def Resnet50(inputs, classes):
    """Build the ResNet50 graph on ``inputs`` and return its output tensor.

    Args:
        inputs: Input tensor of the network.
        classes: Number of units of the final softmax ``Dense`` layer.

    Returns:
        The softmax output tensor.
    """
    # Stem: padded 7x7 stride-2 convolution, then padded 3x3 stride-2 max-pool.
    stem = ZeroPadding2D((3, 3))(inputs)
    stem = Conv2D(filters=64, kernel_size=7, strides=2, padding='valid')(stem)
    stem = BatchNormalization(epsilon=1.001e-5)(stem)
    stem = Activation('relu')(stem)
    stem = ZeroPadding2D((1, 1))(stem)
    stem = MaxPool2D(pool_size=3, strides=2, padding='valid')(stem)

    # One entry per stage: (filters, stride of the first block, block count).
    # The first block of a stage uses a projection shortcut, the rest identity.
    stage_specs = ((64, 1, 3), (128, 2, 4), (256, 2, 6), (512, 2, 3))

    features = stem
    for width, first_stride, depth in stage_specs:
        features = block(features, filters=width, strides=first_stride, conv_short=True)
        for _ in range(depth - 1):
            features = block(features, filters=width, conv_short=False)

    # Classification head.
    pooled = GlobalAvgPool2D()(features)
    return Dense(classes, activation='softmax')(pooled)
def data_process_func(datasets, img_size=224, batch_size=8):
    """Create the training and test image generators.

    Args:
        datasets: Dataset root directory; images are read from
            ``{datasets}/train`` and ``{datasets}/test``.
        img_size: Side length images are resized to (square target size).
        batch_size: Number of images per generated batch.

    Returns:
        A ``(train_generator, test_generator)`` tuple.
    """
    # ---------------------------------- #
    #   Augmentation applied to the training set
    #   1. rotation_range -> random rotation angle
    #   2. width_shift_range -> random horizontal shift
    #   3. height_shift_range -> random vertical shift
    #   4. rescale -> normalization to [0, 1]
    #   5. shear_range -> random shear
    #   6. zoom_range -> random zoom
    #   7. horizontal_flip -> horizontal flip
    #   8. brightness_range -> brightness change
    #   9. fill_mode -> how newly created pixels are filled
    # ---------------------------------- #
    train_data = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        rescale=1 / 255.0,
        shear_range=10,
        zoom_range=0.1,
        horizontal_flip=True,
        brightness_range=(0.7, 1.3),
        fill_mode='nearest',
    )
    # ---------------------------------- #
    #   The test set is only normalized, never augmented
    # ---------------------------------- #
    test_data = ImageDataGenerator(rescale=1 / 255.0)

    # ---------------------------------- #
    #   Training generator and test generator
    # ---------------------------------- #
    target_size = (img_size, img_size)
    train_generator = train_data.flow_from_directory(
        f'{datasets}/train',
        target_size=target_size,
        batch_size=batch_size,
    )
    test_generator = test_data.flow_from_directory(
        f'{datasets}/test',
        target_size=target_size,
        batch_size=batch_size,
    )
    # The caller unpacks two generators, so they must be returned.
    return train_generator, test_generator
# Learning-rate schedule
def adjust_lr(epoch, lr=1e-3):
    """Return the learning rate to use for ``epoch``.

    Prints the incoming ``lr``, then returns it unchanged while ``epoch`` is
    below 6 and multiplied by 0.93 from epoch 6 onwards.

    Args:
        epoch: Epoch index.
        lr: Learning rate the result is derived from.

    Returns:
        ``lr`` or ``lr * 0.93``.
    """
    print("Seting to %s" % lr)
    return lr if epoch < 6 else lr * 0.93

2. 主函数

  • 设置数据集路径datasets
    链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg
    提取码:bhjx
  • 设置预训练权重路径weight
    链接:https://pan.baidu.com/s/1AhsAA8ww5GurK-pWNQ4aHg
    提取码:y1c4
  • 注意:训练过程中的权重文件会保存到 logs 文件夹,训练前请确认当前目录下存在该文件夹
if __name__ == '__main__':
    # Local import: the script's import block does not bring in os.
    import os

    datasets = './dataset/data_flower'
    weight = './model_data/test_acc0.860-val_loss0.557-resnet50-flower.h5'

    # Hyper-parameters that were used below without ever being defined.
    img_size = 224  # matches the 224x224 target_size of the data generators
    classes = 17    # same class count predict.py uses
    epochs = 50     # NOTE(review): not stated in the article - adjust as needed

    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    inputs = Input(shape=(img_size, img_size, 3))
    # Data generators
    train_generator, test_generator = data_process_func(datasets)
    model = Model(inputs=inputs, outputs=Resnet50(inputs=inputs, classes=classes))

    # ModelCheckpoint writes into logs/; make sure the directory exists so the
    # first save at the end of an epoch does not fail.
    os.makedirs('logs', exist_ok=True)
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=10, verbose=1),
        # save_freq='epoch' is the current spelling of the deprecated period=1.
        ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
                        monitor='val_loss', save_weights_only=True,
                        save_best_only=False, save_freq='epoch'),
        LearningRateScheduler(adjust_lr),
    ]

    print('---------->loding weight--------->')
    # by_name + skip_mismatch: layers whose names or shapes differ from the
    # checkpoint are skipped rather than raising.
    model.load_weights(weight, by_name=True, skip_mismatch=True)
    model.compile(optimizer=Adam(learning_rate=1e-3),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    history = model.fit(
        x=train_generator,
        validation_data=test_generator,
        workers=1,
        epochs=epochs,
        callbacks=callbacks,
    )

3. 预测图片

新建predict.py

import tensorflow as tf
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense, ZeroPadding2D, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add


# Residual building block
def block(x, filters, strides=1, conv_short=True):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        x: Input tensor.
        filters: Channel count of the first two convolutions; the last
            convolution and the projection shortcut output ``filters * 4``.
        strides: Stride of the first 1x1 convolution and of the projection
            shortcut.
        conv_short: If True, the shortcut is a 1x1 convolution followed by
            batch normalization; otherwise the input is added unchanged.

    Returns:
        The tensor produced by adding both branches and applying ReLU.
    """
    bn_eps = 1.001e-5
    wide = filters * 4

    # Shortcut branch is built first so layers are created in the same order
    # as before.
    short_cut = x
    if conv_short:
        short_cut = Conv2D(filters=wide, kernel_size=1, strides=strides, padding='valid')(short_cut)
        short_cut = BatchNormalization(epsilon=bn_eps)(short_cut)

    # Main branch: three convolutions, each followed by batch normalization.
    main = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
    main = BatchNormalization(epsilon=bn_eps)(main)
    main = Activation('relu')(main)

    main = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(main)
    main = BatchNormalization(epsilon=bn_eps)(main)
    main = Activation('relu')(main)

    main = Conv2D(filters=wide, kernel_size=1, strides=1, padding='valid')(main)
    main = BatchNormalization(epsilon=bn_eps)(main)

    merged = Add()([main, short_cut])
    return Activation('relu')(merged)

    
def Resnet50(inputs, classes):
    """Build the ResNet50 graph on ``inputs`` and return its output tensor.

    Args:
        inputs: Input tensor of the network.
        classes: Number of units of the final softmax ``Dense`` layer.

    Returns:
        The softmax output tensor.
    """
    # Stem: padded 7x7 stride-2 convolution, then padded 3x3 stride-2 max-pool.
    stem = ZeroPadding2D((3, 3))(inputs)
    stem = Conv2D(filters=64, kernel_size=7, strides=2, padding='valid')(stem)
    stem = BatchNormalization(epsilon=1.001e-5)(stem)
    stem = Activation('relu')(stem)
    stem = ZeroPadding2D((1, 1))(stem)
    stem = MaxPool2D(pool_size=3, strides=2, padding='valid')(stem)

    # One entry per stage: (filters, stride of the first block, block count).
    # The first block of a stage uses a projection shortcut, the rest identity.
    stage_specs = ((64, 1, 3), (128, 2, 4), (256, 2, 6), (512, 2, 3))

    features = stem
    for width, first_stride, depth in stage_specs:
        features = block(features, filters=width, strides=first_stride, conv_short=True)
        for _ in range(depth - 1):
            features = block(features, filters=width, conv_short=False)

    # Classification head.
    pooled = GlobalAvgPool2D()(features)
    return Dense(classes, activation='softmax')(pooled)

# Class names are the sub-directory names of the test set. They are sorted
# because os.listdir returns entries in arbitrary order, while the training
# generator assigns label indices to class folders in alphanumeric order;
# without sorting, names[argmax] could point at the wrong class.
test_dir = './dataset/data_flower/test'
names = sorted(
    entry for entry in os.listdir(test_dir)
    if os.path.isdir(os.path.join(test_dir, entry))
)
# This time the weight path points at the weights produced by training.
weight = './model_data/test_acc0.860-val_loss0.557-resnet50-flower.h5'
net = Resnet50
classes = 17
img_size = 224

def preprocess_input(x):
    """Scale pixel values by 1/255 and return the result as a new array.

    The input is left untouched (the previous in-place ``x /= 255`` modified
    the caller's array and raised for integer arrays).

    Args:
        x: Array-like of pixel values.

    Returns:
        A new array holding ``x / 255``. Floating-point inputs keep their
        dtype; other inputs are converted to ``float32`` first.
    """
    pixels = np.asarray(x)
    if not np.issubdtype(pixels.dtype, np.floating):
        pixels = pixels.astype(np.float32)
    return pixels / 255

def cvtColor(image):
    """Return ``image`` as a three-channel image.

    If ``np.shape(image)`` is three-dimensional with a last axis of size 3,
    the image is returned as is; otherwise ``image.convert('RGB')`` is
    returned.
    """
    dims = np.shape(image)
    has_three_channels = len(dims) == 3 and dims[2] == 3
    return image if has_three_channels else image.convert('RGB')

# Build the network on a fixed-size, three-channel input and load the weights.
inputs = Input(shape=(img_size, img_size, 3))
model = Model(
    inputs=inputs,
    outputs=net(inputs=inputs, classes=classes),
)
model.load_weights(weight)

# Interactive loop: ask for an image path, classify the image, show the result.
while True:
    img_path = input('input img_path:')
    try:
        img = Image.open(img_path)
        img = cvtColor(img)
        # Resize to the network input size instead of a hard-coded 224.
        img = img.resize((img_size, img_size))
        image_data = np.expand_dims(preprocess_input(np.array(img, np.float32)), 0)
    except (OSError, ValueError):
        # Only failures to open/decode the file are treated as a bad path;
        # a bare except would also hide unrelated programming errors.
        print('The path is error!')
        continue

    plt.imshow(img)
    plt.axis('off')
    p = model.predict(image_data)[0]
    # The index of the highest probability selects the class name.
    pred_name = names[np.argmax(p)]
    plt.title('%s:%.3f' % (pred_name, np.max(p)))
    plt.show()

效果如下:
预测结果:该图片属于 flower0 的概率为 0.801
在这里插入图片描述

最后

以上就是还单身黑米为你收集整理的tensorflow_2.2_Resnet50实现花的识别Resnet50介绍1. 代码演示2. 主函数3. 预测图片的全部内容,希望文章能够帮你解决tensorflow_2.2_Resnet50实现花的识别Resnet50介绍1. 代码演示2. 主函数3. 预测图片所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(48)

评论列表共有 0 条评论

立即
投稿
返回
顶部