我是靠谱客的博主 震动鸡翅,最近开发中收集的这篇文章主要介绍keras加载数据集较大怎么处理model.fit_generator(),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

1.model.fit_generator()

官方代码中的解释如下:

Fits the model on data generated batch-by-batch by a Python generator.
The generator is run in parallel to the model, for efficiency.
For instance, this allows you to do real-time data augmentation
on images on CPU in parallel to training your model on
GPU

大概的意思就
允许你用CPU生成器批量生成数据,再送到GPU中去训练。并且,CPU和GPU是并行的,提高效率!!
因为你数据太大,一次性全部送到GPU去,会非常占显卡内存,而且可能内存泄漏啥的,就是不建议
一次全送进去,就是不要使用model.fit()去训练

2.使用方法

2.1 生成器的建立(下面的程序好像有问题…)

注意steps_per_epoch的值和batch_size有关

官方解释:

steps_per_epoch:
Total number of steps (batches of samples) to yield from `generator` before declaring one epoch finished and starting the next epoch. It should typically be equal to the number of samples of your dataset divided by the batch size.

batch_size = 数据集大小/steps_per_epoch的

batch_size = 32
(x_train, y_train), (x_test, y_test) = train_test_split(X,y,.....)
"""
下面的batch的值就是从0到len()之间,每次产生一个batch_size批量的数据
"""
def generator():
while 1:
batch = np.random.randint(0,len(x_train),size=batch_size)
num = np.zeros(shape=(batch_size,64,64,8))
y = np.zeros((batch_size,))
x = x_train[batch]
y = y_train[batch]
yield x,y
# generator()
history = model.fit_generator(generator(),epochs=1,steps_per_epoch=len(x_train)//(batch_size))
print(model.evaluate(x_test,y_test))
y = model.predict_classes(x_test)
print(accuracy_score(y_test,y))

3.总结

没事多看看源码,官方解释很清楚的,比那些什么博客上写的好很多,因为博客可能还是错误的…

最后

以上就是震动鸡翅为你收集整理的keras加载数据集较大怎么处理model.fit_generator()的全部内容,希望文章能够帮你解决keras加载数据集较大怎么处理model.fit_generator()所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(35)

评论列表共有 0 条评论

立即
投稿
返回
顶部