概述
num_classes = 10
class AnchorPositivePairs(keras.utils.Sequence):
def __init__(self, num_batchs):
self.num_batchs = num_batchs
def __len__(self):
return self.num_batchs
def __getitem__(self, _idx):
x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
for class_idx in range(num_classes):
examples_for_class = class_idx_to_train_idxs[class_idx]
anchor_idx = random.choice(examples_for_class)
positive_idx = random.choice(examples_for_class)
while positive_idx == anchor_idx:
positive_idx = random.choice(examples_for_class)
x[0, class_idx] = x_train[anchor_idx]
x[1, class_idx] = x_train[positive_idx]
return x
Sequence类必须重载三个私有方法__init__、__len__和__getitem__,主要是__getitem__。__init__是构造方法,用于初始化数据的,只要能让样本数据集通过形参顺利传进来就行了。__len__基本上不用改写,用于计算样本数据长度。__getitem__用于生成批量数据,喂给神经网络模型训练用,其输出格式是元组。元组里面有两个元素,每个元素各是一个列表,第一个元素是batch data构成的列表,第二个元素是label构成的列表。在第一个列表中每个元素是一个batch,每个batch里面才是图像张量,所有的batch串成一个列表。第二个列表比较普通,每个元素是个实数,表示标签。这点跟生成器不一样,生成器是在执行时通过yield关键字把每个batch数据喂给模型,一次喂一个batch。__getitem__相当于生成器的作用,如同ImageDataGenerator,但注意编写方法时返回数据不要用yield,而要用return,像一个普通函数一样。至于后台怎么迭代调用这个生成器,这是keras后台程序处理好的事,我们不用担心。我这里程序假定样本数据已经转成pickle的形式,是字典构成的列表,字典的键值分别是feature和label,其中feature是图像张量。
最后
以上就是丰富裙子为你收集整理的keras.utils.Sequence类的全部内容,希望文章能够帮你解决keras.utils.Sequence类所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复