概述
TensorFlow 版本:1.9.0
import tensorflow as tf
import keras
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import keras.backend as K
import os
from keras.layers import Dense,Dropout
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset from MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# Generated sample grids are saved under out/; create it when it is missing.
if not os.path.exists('out/'):
    os.makedirs('out/')
# Plotting helper
def plot(samples):
    """Draw the given images on a tight 4x4 grid and return the figure.

    Each element of ``samples`` is reshaped to 28x28 and shown in greyscale
    with the axis decorations turned off.  The grid has 16 cells, so
    ``samples`` must not contain more than 16 images.
    """
    figure = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, flat_image in enumerate(samples):
        axes = plt.subplot(grid[cell])
        # plt.axis() acts on the axes that plt.subplot() just made current.
        plt.axis('off')
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        axes.set_aspect('equal')
        image = flat_image.reshape(28, 28)
        plt.imshow(image, cmap='Greys_r')

    return figure
# Input dimensionality (one flattened 28x28 image)
input_dim = 784
# Dimensionality of the latent variable z
z_dim = 2
# Learning rate
learning_rate = 0.001
# Minibatch size
batch_size = 250

# Placeholders.
# X receives the flattened images and is fed directly by the training loop.
# The MNIST loader used in this file already returns float pixels scaled to
# [0, 1], so no `X /= 255.` is applied here: that statement rebound X to the
# division's output tensor, which is exactly what feed_dict then overwrote,
# so the scaling never took effect and only hid the real placeholder.
X = tf.placeholder(tf.float32, shape=[None, input_dim])
# NOTE: the name `z` is rebound further down to the sampled latent tensor,
# so this placeholder itself is never fed.
z = tf.placeholder(tf.float32, shape=[None, z_dim])

# Encoder: input -> 512 -> 256 -> dropout -> (mu, logvar)
encode = Dense(512, activation='relu',
               kernel_initializer=keras.initializers.TruncatedNormal())(X)
encode = Dense(256, activation='relu',
               kernel_initializer=keras.initializers.TruncatedNormal())(encode)
encode = Dropout(rate=0.3)(encode)
# Two linear heads on the shared features: mean and log-variance of q(z|x).
mu = Dense(z_dim,
           kernel_initializer=keras.initializers.TruncatedNormal())(encode)
logvar = Dense(z_dim,
               kernel_initializer=keras.initializers.TruncatedNormal())(encode)
# Sample z (reparameterisation trick)
def get_z(z_mu, z_logvar):
    """Sample ``z = z_mu + exp(z_logvar / 2) * eps`` with ``eps ~ N(0, I)``.

    Args:
        z_mu: tensor of latent means.
        z_logvar: tensor of latent log-variances, broadcastable to ``z_mu``.

    Returns:
        A tensor with the same shape as ``z_mu`` holding the sampled values.
    """
    # Standard-normal noise with the same (dynamic) shape as the mean.
    eps = tf.random_normal(shape=tf.shape(z_mu))
    # exp(logvar / 2) is the standard deviation.  The mean must come from the
    # z_mu argument; the earlier code read the module-level `mu` instead and
    # silently ignored the parameter.
    return z_mu + tf.exp(z_logvar / 2) * eps
# Sample the latent code from the distribution described by the encoder.
z = get_z(mu, logvar)

# Decoder: z -> 256 -> 512 -> dropout -> 784 sigmoid outputs
decode = Dense(
    256,
    activation='relu',
    kernel_initializer=keras.initializers.TruncatedNormal(),
)(z)
decode = Dense(
    512,
    activation='relu',
    kernel_initializer=keras.initializers.TruncatedNormal(),
)(decode)
decode = Dropout(rate=0.3)(decode)
x_samples = Dense(784, activation='sigmoid')(decode)

# Reconstruction loss: binary cross-entropy between the input and its
# reconstruction, summed over axis 1 for each example.
pixel_bce = K.binary_crossentropy(X, x_samples)
BCE = K.sum(pixel_bce, axis=1)

# KL divergence term, summed over the latent dimensions for each example.
kl_terms = tf.pow(mu, 2) + K.exp(logvar) - 1 - logvar
KLD = 0.5 * K.sum(kl_terms, 1)

# VAE loss: batch mean of KL divergence plus reconstruction loss.
vae_loss = K.mean(KLD + BCE)

optimizer = tf.train.AdamOptimizer(learning_rate)
train = optimizer.minimize(vae_loss)
# Train the VAE and periodically save a grid of generated digits.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Counter used to number the saved grids; starts at 10, i.e. out/010.png.
    i = 10

    for it in range(5001):
        batch_xs, _ = mnist.train.next_batch(batch_size)
        # The Keras Dropout layers switch on K.learning_phase().  Feed 1 so
        # dropout is really applied while training; without this feed the
        # layers never run in training mode.
        _, loss = sess.run(
            [train, vae_loss],
            feed_dict={X: batch_xs, K.learning_phase(): 1},
        )

        if it % 1000 == 0:
            print('Iter: {}'.format(it))
            print('Loss: {:.4}'. format(loss))
            print('-----------------------')

            # Generate 16 digits by feeding standard-normal noise straight
            # into the latent tensor z, so only the decoder is evaluated.
            # Learning phase 0 keeps dropout off while sampling.
            samples = sess.run(
                x_samples,
                feed_dict={z: np.random.randn(16, z_dim),
                           K.learning_phase(): 0},
            )
            fig = plot(samples)
            plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
            i += 1
            # Close the figure so figures do not accumulate across iterations.
            plt.close(fig)
最后
以上就是诚心帅哥为你收集整理的18.Keras与Tensorflow混用(再战VAE)的全部内容,希望文章能够帮你解决18.Keras与Tensorflow混用(再战VAE)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复