概述
程序使用kaggle中的猫狗大战train data set,用来生成tfrecords数据。
- 原数据库共有25000张照片
- 使用原来的train数据库,然后分为train.tfrecords和test.tfrecords两个数据集。
- train.tfrecords包含23000张照片,test.tfrecords包含另外2000张照片
代码块
代码块语法遵循标准markdown代码,例如:
import os
import tensorflow as tf
from PIL import Image
cwd = os.getcwd() #返回当前进程的工作目录。
classes = ["cat", "dog"]
def create_record():
writer_train = tf.python_io.TFRecordWriter("train_227.tfrecords")
writer_test = tf.python_io.TFRecordWriter("test_227.tfrecords")
class_path = cwd + "/train/"
i = 0
img_names = os.listdir(class_path)
print(len(img_names))
for i in range(20):
print(img_names[i])
for img_name in img_names:
i += 1
animal = img_name.split(".")[0]
if animal == "cat":
index = 0
else:
index = 1
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((227,227))
img_raw = img.tobytes()
#print(index,img_name)
if i<=23000:
example_train = tf.train.Example(
features=tf.train.Features(feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
"img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))})
)
writer_train.write(example_train.SerializeToString())
else:
example_test = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
"img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})
)
#print("start to write test dataset....")
writer_test.write(example_test.SerializeToString())
writer_train.close()
writer_test.close()
exit()
data = create_record()
最后
以上就是风趣冰淇淋为你收集整理的tensorflow使用猫狗大战数据库生成tfrecords数据的全部内容,希望文章能够帮你解决tensorflow使用猫狗大战数据库生成tfrecords数据所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复