我是靠谱客的博主 灵巧香氛,最近开发中收集的这篇文章主要介绍mnist_cnn训练保存模型然后去识别手写数字,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

mnist是很多人入门机器/深度学习的入门数据集,但是只是用来测试模型和入门学习,而忽略了mnist是一个非常好的数字识别的库。

那么我使用一个非常简单,大概5-6层卷积+池化再加几层全连接的结构来训练一下mnist,然后保存下模型,当我想识别一个字符的时候就可以直接读取这个模型,然后识别这个字符了。

首先是网络模型

net =slim.repeat(net,1,slim.conv2d, 32, [3, 3], scope = 'conv1')
net = slim.max_pool2d(net,[3,3],scope ='pool1',stride = 2)
'''
14*14*32
'''
net = slim.repeat(net, 1, slim.conv2d, 64, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [3, 3], scope='pool2',stride = 2)
'''
7*7*64
'''
net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [3, 3], scope='pool3',stride = 2,padding="VALID")
'''
4*4*128
'''
net = slim.repeat(net, 1, slim.conv2d, 256, [3, 3], scope='conv4')
'''
4*4*256
'''
net = slim.flatten(net, scope='flatten')
net = slim.dropout(net, keep_prob=0.8,
is_training=self._is_training)
net = slim.fully_connected(net, 1024, scope='fc1')
net = slim.fully_connected(net, 64, scope='fc2')
net = slim.fully_connected(net, self.num_classes,
activation_fn=None, scope='fc3')

然后定义输入的张量的shape是[None,784],标签是[None],然后将这个输入的tensor转化一下shape,转化成可以进行卷积操作的shape

inputs = tf.placeholder(tf.float32, shape=[None, 784], name='inputs')
labels = tf.placeholder(tf.int32, shape=[None], name='labels')
cls_model = model_mnist.Model(is_training=True, num_classes=10)
image = tf.reshape(inputs,[-1,28,28,1])

然后识别的时候将图片转化为[1,784]的格式,一次识别一张的话。

import numpy as np
import tensorflow as tf
import cv2
import os
import time
model_ckpt_path = "D:/all_model/mnist_model/model.ckpt"
def main(_):
with tf.Session() as sess:
ckpt_path = model_ckpt_path
saver = tf.train.import_meta_graph(ckpt_path + '.meta')
saver.restore(sess, ckpt_path)
inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
classes = tf.get_default_graph().get_tensor_by_name('classes:0')
image = cv2.imread("D:/5.jpg", cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (28, 28))
image_np = np.resize(image,[1,784])
predicted_label = sess.run(classes, feed_dict={inputs: image_np})
print(predicted_label)
if __name__ == '__main__':
tf.app.run()

 

最后

以上就是灵巧香氛为你收集整理的mnist_cnn训练保存模型然后去识别手写数字的全部内容,希望文章能够帮你解决mnist_cnn训练保存模型然后去识别手写数字所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(55)

评论列表共有 0 条评论

立即
投稿
返回
顶部