概述
大家先看底下对这个gan小项目来源和介绍,然后再看代码,copy下来可直接运行。中文注释是我自己理解的过程中加的,对好不了解gan但是想直接上手的同志比较友好,如果注释有不合理的地方欢迎指出~
这个小项目有利于不了解gan原理的同志迅速上手和了解gan原理
#-*-coding= utf-8 -*-
import tensorflow as tf
import numpy as np
#import matplotlib.pyplot as plt
from matplotlib import
pyplot as plt
tf.set_random_seed(1)
#设定图级别的随机数:为了使所有op产生的随机序列在会话之间是可重复的
# 一般这两句话同时使用,图级和操作seed都被设置:两个seed联合使用以确定随机序列。
np.random.seed(1)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001
# learning rate for generator
LR_D = 0.0001
# learning rate for discriminator
N_IDEAS = 5
# think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15
# it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
print(PAINT_POINTS[0]) #list,从-1开始到1的15个均匀分布的点组成的list,共Batch_size个这样的向量组成的矩阵
# show our beautiful painting range
# 画曲线,通过15个点做完平方操作得到的点,用曲线连接这些点
#plt.plot[x,y,color] #这个y可以传入一个关于x的函数
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.legend(loc='upper right')#图例放在右上角,location,此时显示的第一个图图例显示的是两个label
plt.show()
def artist_works():
# painting from the famous artist (real target)
a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
#从1,2之间的均匀分布的随机采样,左闭右开,64个点,后面的np.newaxis是加了一维度,每个数就变成一个list,此时的a就是一个矩阵了
print (a)
paintings = a * np.power(PAINT_POINTS, 2) + (a-1)
return paintings
with tf.variable_scope('Generator'):
G_in = tf.placeholder(tf.float32, [None, N_IDEAS])
# random ideas (could from normal distribution)
G_l1 = tf.layers.dense(G_in, 128, tf.nn.relu)
G_out = tf.layers.dense(G_l1, ART_COMPONENTS)
# making a painting from these random ideas
print(G_in.shape)#(?,5)
print(G_l1.shape)#(?,128)
print(G_out.shape)#(?,15) 开始绘画,15个点
with tf.variable_scope('Discriminator'):
real_art = tf.placeholder(tf.float32, [None, ART_COMPONENTS], name='real_in')
# receive art work from the famous artist
D_l0 = tf.layers.dense(real_art, 128, tf.nn.relu, name='l')
# 全连接
prob_artist0 = tf.layers.dense(D_l0, 1, tf.nn.sigmoid, name='out')
# probability that the art work is made by artist
# reuse layers for generator
D_l1 = tf.layers.dense(G_out, 128, tf.nn.relu, name='l', reuse=True)
# receive art work from a newbie like G
prob_artist1 = tf.layers.dense(D_l1, 1, tf.nn.sigmoid, name='out', reuse=True)
# probability that the art work is made by artist
#这两个变量在上面定义过了,所以直接reuse来调用这个变量,其实求真实图片概率和假图片概率过的流程是一致的
D_loss = -tf.reduce_mean(tf.log(prob_artist0) + tf.log(1-prob_artist1))
#让真实样本的prob_artist0越大越好,生成数据越小越好,总体加个负号,即为符合loss函数越小越好的特性了
G_loss = tf.reduce_mean(tf.log(1-prob_artist1))
#生成后被判别成假的的概率越小越好
train_D = tf.train.AdamOptimizer(LR_D).minimize(
D_loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator'))
train_G = tf.train.AdamOptimizer(LR_G).minimize(
G_loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator'))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
plt.ion()
# something about continuous plotting,图片连续出现
for step in range(5000):
artist_paintings = artist_works()
# real painting from artist
G_ideas = np.random.randn(BATCH_SIZE, N_IDEAS) #初始化随机值,这个ideas是很形象的,标准正态分布的随机点
G_paintings, pa0, Dl = sess.run([G_out, prob_artist0, D_loss, train_D, train_G],
# train and get results
{G_in: G_ideas, real_art: artist_paintings})[:3]
if step % 50 == 0:
# plotting
plt.cla() #清除上一个图
#plt.plot[x,y,color] #这个y可以传入一个关于x的函数
#第一条即画出自己生成的线
plt.plot(PAINT_POINTS[0], G_paintings[0], c='#4AD631', lw=3, label='Generated painting',)
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % pa0.mean(), fontdict={'size': 15})
plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -Dl, fontdict={'size': 15})
plt.ylim((0, 3)); plt.legend(loc='upper right', fontsize=12); plt.draw(); plt.pause(0.01)
plt.ioff() #关闭图片连续出现
plt.show()
gan出来好几年了,进化版的gan也出了十来个了,最近才开始了解gan,有点out了。
代码来源:https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/406_GAN.py
这里要推荐一下简单易懂的莫烦python的博客,我在理解一个新的模型的时候往往先去看看莫烦python网站上看看有没有好理解的小项目。这个项目就来自莫烦python,https://morvanzhou.github.io/tutorials/machine-learning/ML-intro/2-6-GAN/
简单介绍一下这个项目:
真实的正确数据:满足a*X^2+(a-1)且1<a<2的曲线
生成器的任务是生成曲线,且不被判别器发现是伪造的假数据
损失函数
D_loss = -tf.reduce_mean(tf.log(prob_artist0) + tf.log(1-prob_artist1))
#让真实样本被判别器判断为真的概率prob_artist0越大越好,让生成器伪造的数据被判别器判断为真的概率越小越好,用1-prob_artist即越大越好,总体加个负号,即为符合loss函数越小越好的特性了
G_loss = tf.reduce_mean(tf.log(1-prob_artist1))
#生成的曲线被判别成假的的概率越小越好
初始状态:
可以看到当a取到2时得蓝色曲线,当a取到1时得橘色曲线,生成器的目标就是生成一个满足a*x^2+(a-1)且1<a<2的曲线,即最终训练好的生成器生成出来的曲线会在蓝线橘线之间且符合函数y=a*x^2+(a-1)
以下是生成器和判别器交替训练时的输出
有了拟合的趋势
拟合,项目结束
中文注释是我自己理解的过程中加的,对好不了解gan但是想直接上手的同志比较友好,如果注释有不合理的地方欢迎指出~
最后
以上就是陶醉绿茶为你收集整理的可直接运行的gan小项目(详细注释)结合代码通俗理解gan的原理的全部内容,希望文章能够帮你解决可直接运行的gan小项目(详细注释)结合代码通俗理解gan的原理所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复