我是靠谱客的博主 彪壮鞋垫,最近开发中收集的这篇文章主要介绍tf2训练训练,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

训练

这里提供一个简单的利用tensorflow训练的例子。

模型

调用keras的mobilenetv2 api,去掉其顶层,使用其特征提取部分,然后重新设定顶层部分。选择使用预训练权重可以降低我们的训练难度。
有一个比较细节的地方是,这里将归一化直接集成在网络上,这样的好处是:一方面不需要在art上进行设定,另一方面,在制作数据集时也不用进行归一化的预处理(如果忘了归一化会导致模型不收敛,预训练权重对归一化也会有要求,这样相当于减少出问题时需要排错的事项)。因为在这里已经进行了归一化,所以在art上分类函数的参数需要加上scale=1,offset=0。

def Mobilenet_v2(input_size,weights,Dropout_rate,Trainable,alpha = 0.35):
    base_model = keras.applications.MobileNetV2(
        input_shape=(input_size, input_size, 3),
        alpha=alpha,
        weights=weights,
        include_top=False
    )
    inputs = keras.Input(shape=(input_size, input_size,3))

    scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
    x = scale_layer(inputs)
    x = base_model(x, training=False)
    if Trainable:
        base_model.trainable =True
    else:
        base_model.trainable = False
        print("特征层已冻结")
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(Dropout_rate, name='Dropout')(x)
    outputs = keras.layers.Dense(15, activation='softmax')(x)
    model = keras.Model(inputs, outputs)
    return model

数据集

这里的数据集直接使用的是官方数据集,没有独立的验证集,直接从训练集中划分。
增强使用的是tensorflow自带的,这里仅进行了旋转,缩放等增强。
这样的增强显然是不够的,有关增强的方法会在后边单独讲。

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_root = "./train/"

train_generator = ImageDataGenerator(rotation_range=360,
                                     zoom_range  =0.2,
                                     horizontal_flip = True,
                                     validation_split =0.2
                                     )
train_dataset = train_generator.flow_from_directory(batch_size=batch_size,
                                                    directory=train_root,
                                                    shuffle=True,
                                                    target_size=(input_size,input_size),
                                                    subset='training')
valid_dataset = train_generator.flow_from_directory(batch_size=batch_size,
                                                   directory=train_root,
                                                    shuffle=True,
                                                    target_size=(input_size,input_size),
                                                    subset='validation')
print(train_dataset.class_indices)

需要注意,数据集的目录结构需如下所示,不要像逐飞那样搞分大类小类的二级目录:
在这里插入图片描述

回调

在这里我设置了3个回调,其中学习降低回调是比较重要的。一开始的学习率在训练到后期可能是不太合适的,会引起震荡,可以通过降低学习率来避免这种情况,通常通过降低学习率会使得模型的精度有少量提升。
你也可以使用LearningRateScheduler回调对学习率进行动态的调整。这也是一个常用的方法。

reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,patience=10,verbose=1)
early_stop =keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10,verbose=1)
save_weights = keras.callbacks.ModelCheckpoint(save_path + "/model_{epoch:02d}_{val_accuracy:.4f}.h5",
                                                   save_best_only=True, monitor='val_accuracy')

参考文档

在学习深度学习框架时,看官方文档是非常有用的,这里列举两个经常会用到的,或者说,看这两个就够了。
Keras API reference
tensorflow
其中也提供了一些入门的教程,跟着走一遍会对框架有一定的了解。

最后

以上就是彪壮鞋垫为你收集整理的tf2训练训练的全部内容,希望文章能够帮你解决tf2训练训练所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(40)

评论列表共有 0 条评论

立即
投稿
返回
顶部