我是靠谱客的博主 傲娇缘分,最近开发中收集的这篇文章主要介绍keras.callback fit_generator1.fit_generator2.fit_generator 训练逻辑过程,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

1.fit_generator

fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose=1, callbacks=[], validation_data=None, nb_val_samples=None, class_weight=None, max_q_size=10)

函数的参数是:

generator:生成器函数,生成器的输出应该为:
一个形如(inputs,targets)的tuple

一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到samples_per_epoch时,记一个epoch结束


def next_train(self):
while 1:
ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
self.cur_train_index += self.minibatch_size
if self.cur_train_index >= self.val_split:
self.cur_train_index = self.cur_train_index % 32
(self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
[self.X_text, self.Y_data, self.Y_len], self.val_split)
yield ret

这里写成了一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。

samples_per_epoch:整数,当模型处理的样本达到此数目时计一个epoch结束,执行下一个epoch

verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

validation_data:具有以下三种形式之一

生成验证集的生成器

一个形如(inputs,targets)的tuple

一个形如(inputs,targets,sample_weights)的tuple

nb_val_samples:仅当validation_data是生成器时使用,用以限制在每个epoch结束时用来验证模型的验证集样本数,功能类似于samples_per_epoch

max_q_size:生成器队列的最大容量

函数返回一个History对象

2.fit_generator 训练逻辑过程

model.fit_generator 训练入口函数(参考上面的函数原型定义)


callbacks.on_train_begin()
while epoch < epochs:
callbacks.on_epoch_begin(epoch)
while steps_done < steps_per_epoch:
#generator_output是一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。
generator_output = next(output_generator)
#生成器next函数取输入数据进行训练,每次取一个batch大小的量
callbacks.on_batch_begin(batch_index, batch_logs)
outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
callbacks.on_batch_end(batch_index, batch_logs)
end of while steps_done < steps_per_epoch
self.evaluate_generator(...)
#当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估	
callbacks.on_epoch_end(epoch, epoch_logs)
#在这个执行过程中实现模型数据的保存操作
end of while epoch < epochs
callbacks.on_train_end()
``
# 回调函数
通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。eras的回调函数是一个类
```python
keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

3.类属性

params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)

model:keras.models.Model对象,为正在训练的模型的引用回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy’]。

在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

最后

以上就是傲娇缘分为你收集整理的keras.callback fit_generator1.fit_generator2.fit_generator 训练逻辑过程的全部内容,希望文章能够帮你解决keras.callback fit_generator1.fit_generator2.fit_generator 训练逻辑过程所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(49)

评论列表共有 0 条评论

立即
投稿
返回
顶部