我是靠谱客的博主 激昂画笔,最近开发中收集的这篇文章主要介绍梅飞飞飞的假期学习日记DAY8,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

构建Keras模型的不同方法

序贯模型:最简单的列表,本质上是Python列表,通过层的简单堆叠构成

函数式API:专注于类似图的模型结构,灵活性可用性都很强,是构建模型最常用的API

模型子类化:从最底层开始自己编写

1 构建序贯模型

from tensorflow import keras
from tensorflow import layers
model = keras.Sequential([
layers.Dense(64, activation="relu"),
layers.Dense(10, activation="softmax")
])
// 以上为第一种方法构建,直接在创建模型时完成所有工作
model = keras.Sequential()
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))
// 以上为第二种方法,逐渐添加层

但在此时,模型尚未被赋予权重,我们只有在第一次调用了这个层结构以后才会出现对应参数的权重,也就是说先让层接收到“输入数据的形状”

model.build(input_shape=(None, 3))

通过该行代码,我们输入了任意批量大小,样本形状为(3,)的数据,从此刻开始,模型就会带有参数

如有需要我们可以在构建每一层添加一个参数为每一层进行命名

通过在模型建立前提前声明输入形状,我们就可以使得模型在添加层的过程中,实时拥有自己的参数及对应权重

序贯模型的适用范围非常有限,只能表示具有单一输入和单一输出的模型,但我们在实际作业中常常遇到的都是多输入(例如图像)和多输出的模型,这种情况下,我们使用函数式API来建立模型就会取得更好的效果

2-1 构建函数API模型---单个输入单个输出

inputs = keras.Input(shape=(3,), name="my_input")
features = layers.Dense(64, activation="relu")(inputs)
outputs = layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs=inputs, outputs=outputs)

1)首先,我们什么了一个input,包含了输入数据的形状和数据类型信息

我们将该变量称为符号张量,其不包含实际数据

2)接着我们创建一个层,并在输入上调用该层的内容,keras创建的每一个层都可以调用与上方代码类似的符号张量,在调用符号张量时,层返回的内容是新的符号张量(其包含更新后的形状和数据类型信息)

当然也可以直接调用数据张量

3)得到输出后,我们在Model构造函数中指定输入输出,将模型实例化

但是在大部分深度学习的任务中,模型的结构更加贴近图而非列表,也就是说模型可能有多个输入或多个输出,在这种模型上,函数API才真正表现出色

下面引入一个书上的案例:

按优先级对客户支持工单进行排序,并将工单转给相应的部门,假设我们有三个输入:

① 工单标题(文本输入)

② 工单的文本正文(文本输入)

③ 用户添加的标签(分类输入,假定为one-hot编码)

我们将文本输入编码为0和1组成的数组,数组大小为vocabulary_size

模型有两个输出:

① 工单的优先级分数,它是介于0,1之间的标量(sigmoid输出)

② 应处理工单的部门(softmax)

2-2-1 构建函数API模型---多个输入多个输出(模型构建)

vocabulary_size = 10000
num_tags = 100
num_departments = 4
// 定义模型输入
title = keras.Input(shape=(vocabulary_size,), name="title")
text_body = keras.Input(shape=(vocabulary_size,) name="text_body")
tags = keras.Input(shape=(num_tags,), name="tags")
// 通过拼接将输入特征组合成张量features
features = layers.Concatenate()([title, text_body, tags])
features = layers.Dense(64, activation="relu")(features)
// 定义模型输出
priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(
num_departments, activation="softmax", name="priority")(features)
// 通过指定输入和输出来创建模型
model = keras.Model(inputs=[title, text_body, tags],
outputs=[priority, department])

2-2-2 构建函数API模型---多个输入多个输出(通过给定输入和目标组成的列表来训练模型)

import numpy as np
num_samples = 1280
// 虚构的输入数据
title_data = np.random.randint(0, 2, size=(num_samples, vobulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))
// 虚构的目标数据
priority_data = np.random.random(size=(num_samples, 1))
department_data = np.random.randint(0, 2, size=(num_samples, num_departments))
model.compile(optimizer="rmsprop",
loss=["mean_squared_error", "categorical_crossentropy"],
metrics=[["mean_absolute_error"],["accuracy"]])
model.fit([title_data, text_body_data, tags_data],
[priority_data, department_data],
epochs=1)
model.evaluate([title_data, text_body_data, tags_data],
[priority_data, department_data])
priority_preds, department_preds = model.predict(
[title_data, text_body_data, tags_data])

2-2-3 构建函数API模型---多个输入多个输出(通过给定输入和目标组成的字典来训练模型)

model.compile(optimizer="rmsprop",
loss={"priority": "mean_squared_error", "department":
"categorical_crossentropy"},
metrics={"priority": ["mean_absolute_error"], "department":
["accuracy"]})
model.fit({"title": title_data, "text_data": text_body_data,
"tags": tags_data},
{"priority": priority_data, "department": department_data},
epochs=1)
model.evaluate({"title": title_data, "text_body": text_body_data,
"tags": tags_data},
{"priority": priority_data, "department": department_data})
priority_preds, department_preds = model.predict(
{"title": title_data, "text_body": text_body_data, "tags":tags_data})

如果想要可视化最终的模型结构,我们可以在keras中调用相应的函数进行

//单纯查看模型整体结构
keras.util.plot_model(model, "ticket_classifier.png)
//在查看模型结构的基础上显示模型各个部分的形状信息
keras.util.plot_model(
model, "ticket_classifier_with_shape_info.png", show_shapes=True)

最后

以上就是激昂画笔为你收集整理的梅飞飞飞的假期学习日记DAY8的全部内容,希望文章能够帮你解决梅飞飞飞的假期学习日记DAY8所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(72)

评论列表共有 0 条评论

立即
投稿
返回
顶部