我是靠谱客的博主 风趣冰淇淋,这篇文章主要介绍tensorflow使用猫狗大战数据库生成tfrecords数据,现在分享给大家,希望可以做个参考。

程序使用kaggle中的猫狗大战train data set,用来生成tfrecords数据。

  • 原数据库共有25000张照片
  • 使用原来的train数据库,然后分为train.tfrecords和test.tfrecords两个数据集。
  • train.tfrecords包含23000张照片,test.tfrecords包含另外2000张照片

代码块

代码块语法遵循标准markdown代码,例如:

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os import tensorflow as tf from PIL import Image cwd = os.getcwd() #返回当前进程的工作目录。 classes = ["cat", "dog"] def create_record(): writer_train = tf.python_io.TFRecordWriter("train_227.tfrecords") writer_test = tf.python_io.TFRecordWriter("test_227.tfrecords") class_path = cwd + "/train/" i = 0 img_names = os.listdir(class_path) print(len(img_names)) for i in range(20): print(img_names[i]) for img_name in img_names: i += 1 animal = img_name.split(".")[0] if animal == "cat": index = 0 else: index = 1 img_path = class_path + img_name img = Image.open(img_path) img = img.resize((227,227)) img_raw = img.tobytes() #print(index,img_name) if i<=23000: example_train = tf.train.Example( features=tf.train.Features(feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}) ) writer_train.write(example_train.SerializeToString()) else: example_test = tf.train.Example( features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) }) ) #print("start to write test dataset....") writer_test.write(example_test.SerializeToString()) writer_train.close() writer_test.close() exit() data = create_record()

最后

以上就是风趣冰淇淋最近收集整理的关于tensorflow使用猫狗大战数据库生成tfrecords数据的全部内容,更多相关tensorflow使用猫狗大战数据库生成tfrecords数据内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(40)

评论列表共有 0 条评论

立即
投稿
返回
顶部