概述
1、引用Python库
import gym
import tensorflow as tf
import numpy as np
from ou_noise import OUNoise
from critic_network import CriticNetwork
from actor_network_bn import ActorNetwork
from replay_buffer import ReplayBuffer
2、定义参数
# Hyper Parameters:
REPLAY_BUFFER_SIZE = 1000000
REPLAY_START_SIZE = 10000
BATCH_SIZE = 64
GAMMA = 0.99
3、定义类
class DDPG:
"""docstring for DDPG"""
def __init__(self, env):
self.name = 'DDPG' # name for uploading results
self.environment = env
# Randomly initialize actor network and critic network
# with both their target networks
self.state_dim = env.observation_space.shape[0]
(以下函数均在类DDPG中定义)
3.1 初始化函数
def __init__(self, env):
self.name = 'DDPG' # name for uploading results
self.environment = env
# Randomly initialize actor network and critic network
# with both their target networks
self.state_dim = env.observation_space.shape[0]
self.action_dim = env.action_space.shape[0]
self.sess = tf.InteractiveSession()
self.actor_network = ActorNetwork(self.sess,self.state_dim,self.action_dim)
self.critic_network = CriticNetwork(self.sess,self.state_dim,self.action_dim)
# initialize replay buffer
self.replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE)
# Initialize a random process the Ornstein-Uhlenbeck process for action exploration
self.exploration_noise = OUNoise(self.action_dim)
初始化了状态、动作的维度,actor_network,critic_network,经验池和ou-noise.
tf.InteractiveSession()
参考这里,在运行图的时候插入一些计算图,便于交互环境处理。
3.2 train()函数
def train(self):
#print "train step",self.time_step
# Sample a random minibatch of N transitions from replay buffer
minibatch = self.replay_buffer.get_batch(BATCH_SIZE)
state_batch = np.asarray([data[0] for data in minibatch])
action_batch = np.asarray([data[1] for data in minibatch])
reward_batch = np.asarray([data[2] for data in minibatch])
next_state_batch = np.asarray([data[3] for data in minibatch])
done_batch = np.asarray([data[4] for data in minibatch]) #从经验池中采样得到经验序列
# for action_dim = 1
action_batch = np.resize(action_batch,[BATCH_SIZE,self.action_dim])
# Calculate y_batch
next_action_batch = self.actor_network.target_actions(next_state_batch)
q_value_batch = self.critic_network.target_q(next_state_batch,next_action_batch)#q值通过target_critic网络计算(确定性策略梯度))
y_batch = []
for i in range(len(minibatch)):
if done_batch[i]:
y_batch.append(reward_batch[i])
else :
y_batch.append(reward_batch[i] + GAMMA * q_value_batch[i]) #通过经验池数据计算y值
y_batch = np.resize(y_batch,[BATCH_SIZE,1])
# Update critic by minimizing the loss L
self.critic_network.train(y_batch,state_batch,action_batch) #通过最小化二次方误差调整critic网络
# Update the actor policy using the sampled gradient:
action_batch_for_gradients = self.actor_network.actions(state_batch) #actor网络通过经验池中的state产生动作
q_gradient_batch = self.critic_network.gradients(state_batch,action_batch_for_gradients) #critic网络通过上述状态-动作对计算Q对于a的梯度
self.actor_network.train(q_gradient_batch,state_batch) #通过梯度和state调整actor网络
# Update the target networks
self.actor_network.update_target()
self.critic_network.update_target() #更新target网络
整个actor-critic的一次训练过程。(对照伪代码)
np.asarray
参考这里,将数据结构转化为ndarray.
np.resize
参考这里,对原始数组的维度进行修改并保留。
3.3 关于action
def noise_action(self,state):
# Select action a_t according to the current policy and exploration noise
action = self.actor_network.action(state)
return action+self.exploration_noise.noise()
返回一个带噪声(探索)的动作。随机性。(exploration_noise在前面定义了就是ou-noise)
def action(self,state):
action = self.actor_network.action(state)
return action
返回一个不带噪声的动作。确定性。
3.4 perceive()函数
def perceive(self,state,action,reward,next_state,done):
# Store transition (s_t,a_t,r_t,s_{t+1}) in replay buffer
self.replay_buffer.add(state,action,reward,next_state,done)
# Store transitions to replay start size then start training
if self.replay_buffer.count() > REPLAY_START_SIZE:
self.train()
#if self.time_step % 10000 == 0:
#self.actor_network.save_network(self.time_step)
#self.critic_network.save_network(self.time_step)
# Re-iniitialize the random process when an episode ends
if done:
self.exploration_noise.reset()
向经验池中存储数据,存满时开始训练。
最后
以上就是满意白猫为你收集整理的DDPG(6)_ddpg的全部内容,希望文章能够帮你解决DDPG(6)_ddpg所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复