我是靠谱客的博主 漂亮绿茶,最近开发中收集的这篇文章主要介绍Tensorflow2.x的记录一、keras与eager执行模式二、Eager模式三、tensorflow高级API,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

一、keras与eager执行模式

eager执行模式可以仅仅通过组装层就可以定义、训练并评估模型

二、Eager模式

在tf2.x中,默认执行的是eager模式
1.eager模式返回的是tf.Tensorf对象,而静态图模式执行回话后,返回的是Numpy数组对象
2.在eager模式中,print返回的tf.Tensor对象与静态图的区别:
静态图返回例子的是:

1. tf.GradientTape()自动微分运算

2. 自定义训练循环train()

定制train_step函数用在训练循环中:

def train(num_classes, batch_size):
    """
    定制 训练循环
    :param num_classes:
    :return:
    """
    # 加载模型
    net = AlexNet(num_classes)
    net.build((4,224,224,3))
    # 构建数据集
    dataset = make_dataset(10, 5, 4)
    # 损失函数
    # loss = tf.losses.CategoricalCrossentropy(from_logits=False)     #标签是one-hot形式
    loss = tf.losses.SparseCategoricalCrossentropy(from_logits=False) #标签不是one-hot形式
    # 优化器
    optimizer = tf.optimizers.Adam()

    # 配置
    step = tf.Variable(1, dtype=tf.int32, name="global_step")
    accuracy = tf.metrics.Accuracy()

    mean_loss = tf.metrics.Mean(name='loss')
    @tf.function
    def train_step(inputs, labels):
        """
        用在训练循环中
        :return:
        """
        with tf.GradientTape() as tape:  # 记录节点运算的磁带
            logits = net(inputs)
            loss_val = loss(labels, logits)
        gradients = tape.gradient(loss_val, net.trainable_variables)

        optimizer.apply_gradients(zip(gradients, net.trainable_variables))

        step.assign_add(1)
        accuracy_val = accuracy(labels, tf.argmax(logits, -1))

        return loss_val,accuracy_val
    #
    @tf.function
    def loop():
        for train_x, label_y in dataset:
            loss_val, accuracy_val = train_step(train_x, label_y)
            if tf.equal(tf.math.mod(step,10), 0):
                tf.print(step, ": ", loss_val, " - accuracy: ",accuracy_val)

    loop()
if __name__ == '__main__':
    train(5)

在这里插入图片描述

三、tensorflow高级API

1.tf.data模块高性能数据输入流水线-ETL结构

dataset的作用
1.Create a source dataset from your input data.
2.Apply dataset transformations to preprocess the data.
3.Iterate over the dataset and process the elements.
(a)基本的流程如下:
在这里插入图片描述

生产数据:tf.data顺序数据输入流水线将原始数据转换成为有用的数据格式,这些操作是在CPU上进行的;
消费数据:在目标设备(GPU、TPU)上进行对CPU生产的数据进行计算处理。

其中,cache缓存元素和prefetch预取是性能优化,目的是为了解决目标设备出现利用率为0%的情况,移除目标设备处理数据远比生产数据快导致目标设备GPU等待CPU产生数据的瓶颈。

(b)构建自己的数据集dataset

import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist

def train_dataset(epoch_num, buffer_size, batch_size):
    (train_x, train_y), (test_x, test_y) = fashion_mnist.load_data()
    
	AUTOTUNE = tf.data.experimental.AUTOTUNE

    def process_fn(image, label):
        image = (tf.image.convert_image_dtype(image, tf.float32) - 0.5) * 2.
        return (image, label)
        
    dataset = tf.data.Dataset.from_tensors_slices((tf.expand_dims(train_x, -1), tf.expand_dims(train_y, -1)))
    #-------
    # map会把数据集的每一个元素element处理,生成新的、应用变换后的数据集。
    #-------
    dataset = dataset.map(process_fn, num_parallel_calls=AUTOTUNE).cache()  # 经过计算密集变换,复杂耗时的处理运算后,把数据缓存在内存中,从而加速接下来的运算。
    dataset = dataset.repeat(epoch_num).shuffle(buffer_size).batch(batch_size)
    dataset = dataset.prefetch(1)

    return dataset

tf.image模块包提供API用于对数据集扩充

定义一个函数, 使用dataset 的map方法将函数应用在数据集上

def process_fn(img_path, label):
    label = tf.one_hot(label, depth=class_num)
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, 0)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32) # [0,1]
    
    image = tf.image.resize(image, (im_height,im_width))
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_hue(image)
    image = tf.image.random_brightness(image)
	
	image = tf.image.resize(image,(416, 416))
    
    return image, label

2.tf.estimator.Estimator估计器

估计器的作用:训练【train】、评估【evaluate】、预测【predict】
两个要点:model_fn函数、input_fn函数,它们都不是以eager模式执行的,估计器通过调用这两个函数切换到图模式
(a)定制估计器

tf.estimator.EstimatorSpec的作用是什么
定制训练和评估阶段,通过tf.estimator.EstimatorSpec对象定义input_fn函数

最后

以上就是漂亮绿茶为你收集整理的Tensorflow2.x的记录一、keras与eager执行模式二、Eager模式三、tensorflow高级API的全部内容,希望文章能够帮你解决Tensorflow2.x的记录一、keras与eager执行模式二、Eager模式三、tensorflow高级API所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(50)

评论列表共有 0 条评论

立即
投稿
返回
顶部