我是靠谱客的博主 直率板栗,最近开发中收集的这篇文章主要介绍【深度强化学习】8. DDPG算法及部分代码解析,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

【DataWhale打卡】DDPG算法 Deep Deterministric Policy Gradient

视频参考自:https://www.bilibili.com/video/BV1yv411i7xd?p=19

1、思维导图

2. 详解

DDPG是解决连续性控制问题的一个算法,但是和PPO不同,PPO输出是一个策略,是一个概率分布。而DDPG输出的是一个动作。

DDPG是采用的也是Actor-Critic架构,是基于DQN进行改进的。DQN中的action space必须是离散的,所以不能处理连续的动作空间的问题。DDPG在其基础上进行改动,引入了一个Actor Network,让一个网络来的输出来得到连续的动作空间。

对比ACDDPG
Actor输出的是概率分布输出是动作
Critic预估V值预估Q值
更新带权重梯度更新梯度上升

优化Q网络的时候,如果Q-target也在不停的变动,那就会造成更新困难。类似DQN,DDPG也采取了固定网络结构的方法,先冻结target网络,更新参数以后,再把参数赋值到target网络。所以需要的是四个网络:

  • actor
  • critic
  • target actor
  • target critic

通过上图可以看出,DDPG(也是一种Actor-Critic方法),其实也是一种时序差分的方法,结合了基于Value-based和Policy-Based方法。其中Policy是Actor,用于给出动作;价值函数是Critic,评价Actor给出的Action的好坏,产生时序差分信号用于指导价值函数和策略函数的更新。

3. 代码

代码主要看DDPG算法主要几个模块:

3.1 背景

DDPG这里要解决的问题是一个钟摆问题,Pendulum-v0。这个版本的问题中,钟摆以随机位置开始,目标是将其向上摆动,使其保持直立。这是一个连续控制的问题。

状态表示:

动作空间:

奖励评估:
− ( θ 2 + 0.1 ∗ θ d t 2 + 0.001 ∗ a c t i o n 2 ) -(theta^2 + 0.1*theta_{dt}^2 + 0.001*action^2) (θ2+0.1θdt2+0.001action2)
可以看出,目标就是保持零角度,也就是垂直,同时要求旋转速度最小,力度最小。

3.2 Actor

Actor作用是接收状态描述,输出一个action,由于DDPG中的动作空间要求是连续的,所以使用了一个tanh

class Actor(nn.Module):
    def __init__(self, n_obs, n_actions, hidden_size, init_w=3e-3):
        super(Actor, self).__init__()  
        self.linear1 = nn.Linear(n_obs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, n_actions)
        
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)
        
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.tanh(self.linear3(x))
        return x

实现方面,就是用了几个全连接层来设计的网络,输出的结果是一个连续的值。

3.3 Critic

Critic批评者,在DDPG中,接受来自Actor的一个Action值和当前的状态,输出的是当前状态下,采用Action动作以后得到的关于Q的期望。

class Critic(nn.Module):
    def __init__(self, n_obs, n_actions, hidden_size, init_w=3e-3):
        super(Critic, self).__init__()
        
        self.linear1 = nn.Linear(n_obs + n_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        # 随机初始化为较小的值
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)
        
    def forward(self, state, action):
        # 按维数1拼接
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

3.4 Replay Buffer

Replay Buffer就是用来存储一系列等待学习的SARS片段。

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
    
    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = map(np.stack, zip(*batch))
        return state_batch, action_batch, reward_batch, next_state_batch, done_batch
    
    def __len__(self):
        return len(self.buffer)

可以设置Replay Buffer的容量,push函数是向buffer中添加一个SARS片段;sample代表从buffer中采样batch size个片段。

3.5 DDPG

DDPG用到了以上的所有对象,包括Critic、Target Critic、Actor、Target Actor、memory。

init函数如下:

def __init__(self, n_states, n_actions, hidden_dim=30, device="cpu", critic_lr=1e-3,
                actor_lr=1e-4, gamma=0.99, soft_tau=1e-2, memory_capacity=100000, batch_size=128):
    self.device = device
    
    self.critic = Critic(n_states, n_actions, hidden_dim).to(device)
    self.actor = Actor(n_states, n_actions, hidden_dim).to(device)

    self.target_critic = Critic(n_states, n_actions, hidden_dim).to(device)
    self.target_actor = Actor(n_states, n_actions, hidden_dim).to(device)

    for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
        target_param.data.copy_(param.data)
    for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):
        target_param.data.copy_(param.data)

    self.critic_optimizer = optim.Adam(
        self.critic.parameters(),  lr=critic_lr)
        
    self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
    
    self.memory = ReplayBuffer(memory_capacity)

    self.batch_size = batch_size
    self.soft_tau = soft_tau
    self.gamma = gamma

其中核心的函数就是update函数:

def update(self):
    if len(self.memory) < self.batch_size:
        return
    state, action, reward, next_state, done = self.memory.sample(
        self.batch_size)
    # 将所有变量转为张量
    state = torch.FloatTensor(state).to(self.device)
    next_state = torch.FloatTensor(next_state).to(self.device)
    action = torch.FloatTensor(action).to(self.device)
    reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device)
    done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)
    # 注意critic将(s_t,a)作为输入
    policy_loss = self.critic(state, self.actor(state))
    
    policy_loss = -policy_loss.mean()

    next_action = self.target_actor(next_state)
    target_value = self.target_critic(next_state, next_action.detach())
    expected_value = reward + (1.0 - done) * self.gamma * target_value
    expected_value = torch.clamp(expected_value, -np.inf, np.inf)

    value = self.critic(state, action)
    value_loss = nn.MSELoss()(value, expected_value.detach())
    
    self.actor_optimizer.zero_grad()
    policy_loss.backward()
    self.actor_optimizer.step()

    self.critic_optimizer.zero_grad()
    value_loss.backward()
    self.critic_optimizer.step()
    for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - self.soft_tau) +
            param.data * self.soft_tau
        )
    for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - self.soft_tau) +
            param.data * self.soft_tau
        )

整体流程如下:

  • 从memory中采样一个batch的数据。
  • policy_loss = self.critic(state, self.actor(state))
    • 将state放到actor对象得到action
    • 将state,action放到critic对象得到policy loss
next_action = self.target_actor(next_state)
target_value = self.target_critic(next_state, next_action.detach())
  • 然后target actor和target critic也按照以上过程得到target value
  • 根据target value 计算expected value:

r + γ Q r+gamma Q r+γQ

实现如下:

expected_value = reward + (1.0 - done) * self.gamma * target_value
expected_value = torch.clamp(expected_value, -np.inf, np.inf)

如果done为1,代表已经结束了,也就不需要这个系数了。第二行对expected value进行了数值上的限制。

  • 接下来计算根据数据集中action得到的value值。
value = self.critic(state, action)
  • 计算优化Q网络的loss, 采用的是MSEloss
value_loss = nn.MSELoss()(value, expected_value.detach())

对比下图:

  • 对policy loss和value loss进行梯度回传,更新训练参数。

训练结果如下:

4. 参考文献

代码部分全部来自于johnjim的实现,感谢。

https://www.jianshu.com/p/af3a7853268f

https://datawhalechina.github.io/leedeeprl-notes/#/chapter12/project3

https://www.bilibili.com/video/BV1yv411i7xd?p=19

最后

以上就是直率板栗为你收集整理的【深度强化学习】8. DDPG算法及部分代码解析的全部内容,希望文章能够帮你解决【深度强化学习】8. DDPG算法及部分代码解析所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(59)

评论列表共有 0 条评论

立即
投稿
返回
顶部