构建Keras模型的不同方法
序贯模型:最简单的列表,本质上是Python列表,通过层的简单堆叠构成
函数式API:专注于类似图的模型结构,灵活性可用性都很强,是构建模型最常用的API
模型子类化:从最底层开始自己编写
1 构建序贯模型
1
2
3
4
5
6
7
8
9
10
11from tensorflow import keras from tensorflow import layers model = keras.Sequential([ layers.Dense(64, activation="relu"), layers.Dense(10, activation="softmax") ]) // 以上为第一种方法构建,直接在创建模型时完成所有工作 model = keras.Sequential() model.add(layers.Dense(64, activation="relu")) model.add(layers.Dense(10, activation="softmax")) // 以上为第二种方法,逐渐添加层
但在此时,模型尚未被赋予权重,我们只有在第一次调用了这个层结构以后才会出现对应参数的权重,也就是说先让层接收到“输入数据的形状”
1model.build(input_shape=(None, 3))
通过该行代码,我们输入了任意批量大小,样本形状为(3,)的数据,从此刻开始,模型就会带有参数
如有需要我们可以在构建每一层添加一个参数为每一层进行命名
通过在模型建立前提前声明输入形状,我们就可以使得模型在添加层的过程中,实时拥有自己的参数及对应权重
序贯模型的适用范围非常有限,只能表示具有单一输入和单一输出的模型,但我们在实际作业中常常遇到的都是多输入(例如图像)和多输出的模型,这种情况下,我们使用函数式API来建立模型就会取得更好的效果
2-1 构建函数API模型---单个输入单个输出
1
2
3
4
5inputs = keras.Input(shape=(3,), name="my_input") features = layers.Dense(64, activation="relu")(inputs) outputs = layers.Dense(10, activation="softmax")(features) model = keras.Model(inputs=inputs, outputs=outputs)
1)首先,我们什么了一个input,包含了输入数据的形状和数据类型信息
我们将该变量称为符号张量,其不包含实际数据
2)接着我们创建一个层,并在输入上调用该层的内容,keras创建的每一个层都可以调用与上方代码类似的符号张量,在调用符号张量时,层返回的内容是新的符号张量(其包含更新后的形状和数据类型信息)
当然也可以直接调用数据张量
3)得到输出后,我们在Model构造函数中指定输入输出,将模型实例化
但是在大部分深度学习的任务中,模型的结构更加贴近图而非列表,也就是说模型可能有多个输入或多个输出,在这种模型上,函数API才真正表现出色
下面引入一个书上的案例:
按优先级对客户支持工单进行排序,并将工单转给相应的部门,假设我们有三个输入:
① 工单标题(文本输入)
② 工单的文本正文(文本输入)
③ 用户添加的标签(分类输入,假定为one-hot编码)
我们将文本输入编码为0和1组成的数组,数组大小为vocabulary_size
模型有两个输出:
① 工单的优先级分数,它是介于0,1之间的标量(sigmoid输出)
② 应处理工单的部门(softmax)
2-2-1 构建函数API模型---多个输入多个输出(模型构建)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18vocabulary_size = 10000 num_tags = 100 num_departments = 4 // 定义模型输入 title = keras.Input(shape=(vocabulary_size,), name="title") text_body = keras.Input(shape=(vocabulary_size,) name="text_body") tags = keras.Input(shape=(num_tags,), name="tags") // 通过拼接将输入特征组合成张量features features = layers.Concatenate()([title, text_body, tags]) features = layers.Dense(64, activation="relu")(features) // 定义模型输出 priority = layers.Dense(1, activation="sigmoid", name="priority")(features) department = layers.Dense( num_departments, activation="softmax", name="priority")(features) // 通过指定输入和输出来创建模型 model = keras.Model(inputs=[title, text_body, tags], outputs=[priority, department])
2-2-2 构建函数API模型---多个输入多个输出(通过给定输入和目标组成的列表来训练模型)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20import numpy as np num_samples = 1280 // 虚构的输入数据 title_data = np.random.randint(0, 2, size=(num_samples, vobulary_size)) text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size)) tags_data = np.random.randint(0, 2, size=(num_samples, num_tags)) // 虚构的目标数据 priority_data = np.random.random(size=(num_samples, 1)) department_data = np.random.randint(0, 2, size=(num_samples, num_departments)) model.compile(optimizer="rmsprop", loss=["mean_squared_error", "categorical_crossentropy"], metrics=[["mean_absolute_error"],["accuracy"]]) model.fit([title_data, text_body_data, tags_data], [priority_data, department_data], epochs=1) model.evaluate([title_data, text_body_data, tags_data], [priority_data, department_data]) priority_preds, department_preds = model.predict( [title_data, text_body_data, tags_data])
2-2-3 构建函数API模型---多个输入多个输出(通过给定输入和目标组成的字典来训练模型)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15model.compile(optimizer="rmsprop", loss={"priority": "mean_squared_error", "department": "categorical_crossentropy"}, metrics={"priority": ["mean_absolute_error"], "department": ["accuracy"]}) model.fit({"title": title_data, "text_data": text_body_data, "tags": tags_data}, {"priority": priority_data, "department": department_data}, epochs=1) model.evaluate({"title": title_data, "text_body": text_body_data, "tags": tags_data}, {"priority": priority_data, "department": department_data}) priority_preds, department_preds = model.predict( {"title": title_data, "text_body": text_body_data, "tags":tags_data})
如果想要可视化最终的模型结构,我们可以在keras中调用相应的函数进行
1
2
3
4
5//单纯查看模型整体结构 keras.util.plot_model(model, "ticket_classifier.png) //在查看模型结构的基础上显示模型各个部分的形状信息 keras.util.plot_model( model, "ticket_classifier_with_shape_info.png", show_shapes=True)
最后
以上就是激昂画笔最近收集整理的关于梅飞飞飞的假期学习日记DAY8的全部内容,更多相关梅飞飞飞内容请搜索靠谱客的其他文章。
发表评论 取消回复