我是靠谱客的博主 腼腆鞋垫,最近开发中收集的这篇文章主要介绍MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.pyMAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py,觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
MAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py
文章目录
- MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.py
- 基本介绍
- 源码链接
- 文件路径
- `import` 包
- `make_env()` 函数
- `BatchSampler()` 类
基本介绍
在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。
因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。
MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。
源码链接
https://github.com/dragen1860/MAML-Pytorch-RL
文件路径
./maml_rl/sampler.py
import
包
import
gym
import
torch
import
multiprocessing as mp
from
maml_rl.envs.subproc_vec_env import SubprocVecEnv
from
maml_rl.episode import BatchEpisodes
make_env()
函数
#### 对gym库的.make(env_name)做了一个简单的包装。
def make_env(env_name):
"""
return a function
:param env_name:
:return:
"""
def _make_env():
return gym.make(env_name)
return _make_env
BatchSampler()
类
class BatchSampler:
def __init__(self, env_name, batch_size, num_workers=mp.cpu_count()):
"""
:param env_name:
:param batch_size: fast batch size
:param num_workers:
"""
#### 将环境名字self.env_name、一批任务的数量self.batch_size和能参与工作的cpu数量self.num_workers初始化,初始化多线程队列。
self.env_name = env_name
self.batch_size = batch_size
self.num_workers = num_workers
self.queue = mp.Queue()
#### 对于self.num_workers数量做迭代,也就是为每个线程开辟一个环境。最后通过列表的形式存储到env_factorys变量中。
# [lambda function]
env_factorys = [make_env(env_name) for _ in range(num_workers)]
#### 创建父进程,用于管理self.num_workers数量的线程。最后再创建一个环境,应该是用于内环更新结束后的用于测试的环境。
# this is the main process manager, and it will be in charge of num_workers sub-processes interacting with environment.
self.envs = SubprocVecEnv(env_factorys, queue_=self.queue)
self._env = gym.make(env_name)
def sample(self, policy, params=None, gamma=0.95, device='cpu'):
"""
:param policy:
:param params:
:param gamma:
:param device:
:return:
"""
#### 创建一个批处理实例。现在队列中加入批任务大小的数字,然后再加入self.num_workers数量的None这样应该是做一个标志。
episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device)
for i in range(self.batch_size):
self.queue.put(i)
for _ in range(self.num_workers):
self.queue.put(None)
#### 对所有环境做初始化命令。得到每个子线程的观测和任务号。标记done"是否完成"为否。
observations, batch_ids = self.envs.reset()
dones = [False]
#### 如果所有队列都没有完成"not all(dones)"且队列没有空,就说明还有队列。
while (not all(dones)) or (not self.queue.empty()): # if all done and queue is empty
# for reinforcement learning, the forward process requires no-gradient
#### 接下来做的是强化学习执行任务过程。因为这本身是输出结果,是前馈过程,那么就不需要导数。
with torch.no_grad():
# convert observation to cuda
# compute policy on cuda
# convert action to cpu
#### 经典强化学习过程。先得到观测向量,然后获得动作张量,再转成动作array。
observations_tensor = torch.from_numpy(observations).to(device=device)
# forward via policy network
# policy network will return Categorical(logits=logits)
actions_tensor = policy(observations_tensor, params=params).sample()
actions = actions_tensor.cpu().numpy()
#### 最后执行step()函数,得到新的观测、奖励、是否完成信息以及新的批任务号。最后将这些加入episodes的经验池子中。最后做一个更新。
new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions)
# here is observations NOT new_observations, batch_ids NOT new_batch_ids
episodes.append(observations, actions, rewards, batch_ids)
observations, batch_ids = new_observations, new_batch_ids
return episodes
#### 重置任务进行新的回合。
def reset_task(self, task):
tasks = [task for _ in range(self.num_workers)]
reset = self.envs.reset_task(tasks)
return all(reset)
#### 通过各种分布获得一批任务。
def sample_tasks(self, num_tasks):
tasks = self._env.unwrapped.sample_tasks(num_tasks)
return tasks
最后
以上就是腼腆鞋垫为你收集整理的MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.pyMAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py的全部内容,希望文章能够帮你解决MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.pyMAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复