我是靠谱客的博主 体贴美女,最近开发中收集的这篇文章主要介绍使用DQN解决cartpole问题(深度强化学习入门)使用DQN解决cartpole问题(深度强化学习入门),觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
使用DQN解决cartpole问题(深度强化学习入门)
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 11:16:50 2021
@author: wss
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # 调用relu啥的
import collections
import random
import torch.optim as optim
#放一些参数
Lr = 0.1
#学习率
Buffer_size = 10000 #经验回放的buffer的大小
Eps = 0.1
# eps 贪心算法的随机选择比列
GAMMA = 0.99
# reward的衰减
#用队列存 transition
并定义了采样函数
Transition = collections.namedtuple('Transition',
('state', 'action', 'next_state', 'reward'))
# 用一个类来实现经验回放 去除state的相关性和利用经验
class ReplayMemory(object):
def __init__(self, capacity):
self.memory = collections.deque([],maxlen=capacity)
def push(self, *args):
"""Save a transition"""
self.memory.append(Transition(*args))
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
#定义DQN 的神经网络部分
class Net(nn.Module):
def __init__(self,n_in,n_hidden,n_out):
super(Net,self).__init__()
self.fc1 = nn.Linear(n_in, n_hidden)
self.fc2 = nn.Linear(n_hidden, n_hidden)
self.fc3 = nn.Linear(n_hidden, n_out)
def forward(self,x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
out = self.fc3(x)
return out
# net属性是神经网络的对象
class DQN(object):
def __init__(self,n_in,n_hidden,n_out):
#
super(DQN,self).__init__()
self.net = Net(n_in,n_hidden,n_out)
self.target_net = Net(n_in,n_hidden,n_out)
self.optimer = optim.Adam(self.net.parameters(),lr = Lr)
self.loss_func = nn.MSELoss()
self.target_net.load_state_dict(self.net.state_dict())
#
self.target_net.eval()
# 解决高估问题
不用训练直接加载policy_net的参数
self.buffer = ReplayMemory(Buffer_size)
#根据state选择 action
def select_action(self,state): #返回的action是个数字(不是张量)
threshold = random.random()
Q_actions = self.net(torch.Tensor(state)) #返回不同action对应的Q值
if
threshold<Eps
: #随机选择动作
return
np.random.randint(0,Q_actions.shape[0])
else:
return torch.argmax(Q_actions).numpy()
def update_param(self,batch_size):
if self.buffer.__len__() < batch_size:
return
transitions = self.buffer.sample(batch_size)
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
# detailed explanation). This converts batch-array of Transitions
# to Transition of batch-arrays.
batch = Transition(*zip(*transitions))
# 将tuple转化为numpy
tmp = np.vstack(batch.action)
# 转化成Tensor
state_batch = torch.Tensor(batch.state)
action_batch = torch.LongTensor(tmp.astype(int))
reward_batch = torch.Tensor(batch.reward)
next_state_batch = torch.Tensor(batch.next_state)
q_pred_s1 = torch.max(self.target_net(next_state_batch).detach(), dim=1,
keepdim=True)[0]
#max的值有两个返回值,分别是最大值 和 最大值所在的位置
##detach()是用来阻断梯度的
q_pred_s0 = self.net(state_batch).gather(1, action_batch)
##需要的q_pred_s0预测值是相对应的action的Q(s,a)
net()返回的是所有action对应的Q(s,a)
## 看看这个q_pred_s0的shape
#q_td_tar = reward_batch.unsqueeze(1) +
GAMMA * q_pred_s1
q_td_tar = reward_batch.unsqueeze(1) +
GAMMA * q_pred_s1
#
reward_batch 的shape为[100]
q_pred_s1的shape[100, 1]
#
需要把[100]
改成[100,1]
unsqueeze(1) 可以增加一个维度
loss = self.loss_func(q_pred_s0, q_td_tar)
# print(loss)
# Optimize the model
self.optimer.zero_grad()
loss.backward()
self.optimer.step()
if __name__ == '__main__':
import gym
num_episode =10000
batch_size = 32
target_update = 20 # target network 的更新频率(循环20次更新一次)
env = gym.make('CartPole-v0').unwrapped
#unwrapped解除次数限制
Agent = DQN(env.observation_space.shape[0], 256, env.action_space.n)
average_time =0 #记录一百轮的总时间
for i_episode in range(num_episode):
state = env.reset()
#每一轮循环就是一次游戏结束,就要重新开始
total_time =0 #记录一轮的时间
while True:
env.render()
action = Agent.select_action(state)# action 不是张量
next_state,reward,done,_=env.step(action)
total_time+=1
if done:
average_time +=total_time
#print('Episode ', i_episode, 'total_time: ', total_time)
break
Agent.buffer.push(state,action,next_state,reward)
state = next_state
##应该在buffer放满后就更新参数
Agent.update_param(batch_size)
if i_episode % target_update == 0:
Agent.target_net.load_state_dict(Agent.net.state_dict())
if (i_episode+1) % 100 == 0:
print("一百轮的平均时间",average_time/100)
average_time =0
print('Complete')
env.render()
env.close()
刚刚接触深度学习以及强化学习,不知道为什么这个DQN并没有随着训练越来越来越好?
最后
以上就是体贴美女为你收集整理的使用DQN解决cartpole问题(深度强化学习入门)使用DQN解决cartpole问题(深度强化学习入门)的全部内容,希望文章能够帮你解决使用DQN解决cartpole问题(深度强化学习入门)使用DQN解决cartpole问题(深度强化学习入门)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复