我是靠谱客的博主 腼腆鞋垫,这篇文章主要介绍MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.pyMAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py,现在分享给大家,希望可以做个参考。
MAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py
文章目录
- MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.py
- 基本介绍
- 源码链接
- 文件路径
- `import` 包
- `make_env()` 函数
- `BatchSampler()` 类
基本介绍
在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。
因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。
MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。
源码链接
https://github.com/dragen1860/MAML-Pytorch-RL
文件路径
./maml_rl/sampler.py
import
包
复制代码
1
2
3
4
5
6
7
8
9
10
11import gym import torch import multiprocessing as mp from maml_rl.envs.subproc_vec_env import SubprocVecEnv from maml_rl.episode import BatchEpisodes
make_env()
函数
复制代码
1
2
3
4
5
6
7
8
9
10
11#### 对gym库的.make(env_name)做了一个简单的包装。 def make_env(env_name): """ return a function :param env_name: :return: """ def _make_env(): return gym.make(env_name) return _make_env
BatchSampler()
类
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66class BatchSampler: def __init__(self, env_name, batch_size, num_workers=mp.cpu_count()): """ :param env_name: :param batch_size: fast batch size :param num_workers: """ #### 将环境名字self.env_name、一批任务的数量self.batch_size和能参与工作的cpu数量self.num_workers初始化,初始化多线程队列。 self.env_name = env_name self.batch_size = batch_size self.num_workers = num_workers self.queue = mp.Queue() #### 对于self.num_workers数量做迭代,也就是为每个线程开辟一个环境。最后通过列表的形式存储到env_factorys变量中。 # [lambda function] env_factorys = [make_env(env_name) for _ in range(num_workers)] #### 创建父进程,用于管理self.num_workers数量的线程。最后再创建一个环境,应该是用于内环更新结束后的用于测试的环境。 # this is the main process manager, and it will be in charge of num_workers sub-processes interacting with environment. self.envs = SubprocVecEnv(env_factorys, queue_=self.queue) self._env = gym.make(env_name) def sample(self, policy, params=None, gamma=0.95, device='cpu'): """ :param policy: :param params: :param gamma: :param device: :return: """ #### 创建一个批处理实例。现在队列中加入批任务大小的数字,然后再加入self.num_workers数量的None这样应该是做一个标志。 episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device) for i in range(self.batch_size): self.queue.put(i) for _ in range(self.num_workers): self.queue.put(None) #### 对所有环境做初始化命令。得到每个子线程的观测和任务号。标记done"是否完成"为否。 observations, batch_ids = self.envs.reset() dones = [False] #### 如果所有队列都没有完成"not all(dones)"且队列没有空,就说明还有队列。 while (not all(dones)) or (not self.queue.empty()): # if all done and queue is empty # for reinforcement learning, the forward process requires no-gradient #### 接下来做的是强化学习执行任务过程。因为这本身是输出结果,是前馈过程,那么就不需要导数。 with torch.no_grad(): # convert observation to cuda # compute policy on cuda # convert action to cpu #### 经典强化学习过程。先得到观测向量,然后获得动作张量,再转成动作array。 observations_tensor = torch.from_numpy(observations).to(device=device) # forward via policy network # policy network will return Categorical(logits=logits) actions_tensor = policy(observations_tensor, params=params).sample() actions = actions_tensor.cpu().numpy() #### 最后执行step()函数,得到新的观测、奖励、是否完成信息以及新的批任务号。最后将这些加入episodes的经验池子中。最后做一个更新。 new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions) # here is observations NOT new_observations, batch_ids NOT new_batch_ids episodes.append(observations, actions, rewards, batch_ids) observations, batch_ids = new_observations, new_batch_ids return episodes #### 重置任务进行新的回合。 def reset_task(self, task): tasks = [task for _ in range(self.num_workers)] reset = self.envs.reset_task(tasks) return all(reset) #### 通过各种分布获得一批任务。 def sample_tasks(self, num_tasks): tasks = self._env.unwrapped.sample_tasks(num_tasks) return tasks
最后
以上就是腼腆鞋垫最近收集整理的关于MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.pyMAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py的全部内容,更多相关MAML-RL内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复