DataLoader: Important Parameters and Internal Mechanics

Contents

I. PyTorch data input

  1. Dataset
  2. DataLoader

II. DataLoader parameter summary

2.1 sampler: distributed training requires DistributedSampler
2.2 collate_fn: reassembling the data in a batch
2.3 pin_memory=True: speeding up CPU-to-GPU transfer

III. DataLoader parallelism

3.1 index_queue: indices of the data to process
3.2 worker_result_queue: returning results
References

I. PyTorch data input

The Dataset produces the data, while the DataLoader handles batching (batch_size), sampling (sampler), and delivering the data.
PyTorch version: 1.0.1

1. Dataset

Subclass torch.utils.data.Dataset and implement just two methods (a minimal sketch follows below):

def __len__(self): the total number of samples
def __getitem__(self, index): returns a single sample given its index
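
A minimal sketch of such a Dataset; the class and field names here are illustrative, not from the original post:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    # illustrative Dataset wrapping in-memory lists of inputs and targets
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        # total number of samples
        return len(self.inputs)

    def __getitem__(self, index):
        # return one (input, target) sample by index
        return self.inputs[index], self.targets[index]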

2. DataLoader

Pass the Dataset in as an argument and construct a torch.utils.data.DataLoader object (see the sketch below).
The remaining DataLoader parameters are covered in the next section.
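
A minimal usage sketch, continuing the illustrative MyDataset from above (inputs and targets are placeholder data):

from torch.utils.data import DataLoader

dataset = MyDataset(inputs, targets)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

for input_batch, target_batch in loader:  # each iteration yields one collated batch
    ...                                   # training step goes here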

II. DataLoader parameter summary

dataset (Dataset):
the dataset to load from

batch_size (int, optional):
how many samples per batch

shuffle (bool, optional):
reshuffle the data at the start of every epoch

sampler (Sampler, optional):
a custom strategy for drawing samples from the dataset; if this is specified, shuffle must be False

batch_sampler (Sampler, optional):
like sampler, but returns one batch of indices at a time; note that once this is specified, batch_size, shuffle, sampler, and drop_last can no longer be given (mutually exclusive)

num_workers (int, optional):
how many subprocesses to use for data loading; 0 means all data is loaded in the main process (default: 0)

collate_fn (callable, optional):
the function that merges a list of samples into a mini-batch

pin_memory (bool, optional):
if True, the data loader copies tensors into CUDA pinned memory before returning them

drop_last (bool, optional):
if True, the last incomplete batch is dropped; for example, with batch_size=64 and 100 samples per epoch, the remaining 36 samples are simply thrown away.
if False (the default), loading proceeds as normal and the final batch is just smaller than batch_size.

timeout (numeric, optional):
if positive, the time to wait for collecting a batch from a worker; if nothing is collected within that time, the batch is given up on. This value should always be non-negative. (default: 0)

worker_init_fn (callable, optional):
a per-worker initialization function. If not None, this will be called in each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

2.1 sampler: distributed training requires DistributedSampler

train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
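
When this sampler is passed to the DataLoader, shuffle must stay False, and calling set_epoch at the start of every epoch varies the shuffling of each rank's shard. A minimal sketch (the batch size, worker count, and loop variables are illustrative):

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=64, shuffle=False,  # shuffle must be False when a sampler is given
    sampler=train_sampler, num_workers=4, pin_memory=True)

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)             # reshuffle each rank's shard differently per epoch
    for input_batch, target_batch in train_loader:
        ...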

The relevant code in the DataLoader constructor:

if batch_sampler is None:
    if sampler is None:
        if shuffle:
            sampler = RandomSampler(dataset)      # shuffle: draw samples in random order
        else:
            sampler = SequentialSampler(dataset)  # otherwise sample sequentially
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

self.sampler = sampler
self.batch_sampler = batch_sampler

batch_sampler wraps a sampler and lets you customize how batches are formed. The relevant source of the default BatchSampler:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)  # walk the sampler collecting indices; yield once a full batch_size is reached
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

2.2 collate_fn: reassembling the data in a batch

For example, cirtorch uses it to split each batch into input_data and target.
Because __getitem__ in the Dataset returns two values, input_data and target, without this function each batch would look like [batch_size, 2 (input_data then target), ...]; with it, the batch becomes ([batch_size, ...], [batch_size, ...]), where the first element holds all the input_data and the second all the targets.
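
A hedged sketch of such a collate_fn (an illustrative reimplementation, not cirtorch's actual code; dataset is a placeholder): it takes the list of (input_data, target) samples produced by __getitem__ and regroups them into a tuple of two lists.

from torch.utils.data import DataLoader

def collate_pairs(batch):
    # batch is a list of (input_data, target) pairs, one per sample
    inputs = [sample[0] for sample in batch]   # first element: all input_data
    targets = [sample[1] for sample in batch]  # second element: all targets
    return inputs, targets

loader = DataLoader(dataset, batch_size=64, collate_fn=collate_pairs)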

2.3 pin_memory=True: speeding up CPU-to-GPU transfer

pin_memory allocates host (CPU) memory that cannot be swapped out. With the default pageable allocation, data may be swapped, so when the CUDA driver transfers data from host memory to GPU memory via DMA it has to copy twice (first into a temporary, invisible pinned buffer, then into device memory). Setting pin_memory=True therefore roughly doubles CPU-to-GPU transfer efficiency (when calling .cuda() or .to(device)). See "CPU and GPU memory interaction" for details.
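
Pinned memory is typically paired with a non-blocking host-to-device copy. A minimal sketch (the dataset, device, and loop variables are illustrative):

import torch
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=64, num_workers=4, pin_memory=True)
device = torch.device('cuda')

for input_batch, target_batch in loader:
    # with a pinned source buffer, non_blocking=True lets the copy overlap with compute
    input_batch = input_batch.to(device, non_blocking=True)
    target_batch = target_batch.to(device, non_blocking=True)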

[Aside] Memlock (memory locking) in Elasticsearch can similarly reserve a fixed-size, non-swappable region of memory.

III. DataLoader parallelism

# Our data model looks like this (queues are indicated with curly brackets):
#
#            main process                       ||
#                 |                             ||
#           {index_queue}                       ||
#                 |                             ||
#          worker processes                     ||     DATA
#                 |                             ||
#        {worker_result_queue}                  ||     FLOW
#                 |                             ||
#   pin_memory_thread of main process           ||   DIRECTION
#                 |                             ||
#            {data_queue}                       ||
#                 |                             ||
#             data output                       \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
#      `pin_memory=False`.

The parallelism is based on multiprocessing.
Each worker subprocess communicates with the main process through two main queues (multiprocessing.Queue()): index_queue carries the indices to process, and worker_result_queue carries the resulting batch data back.
Each worker produces one batch of data at a time.
Before returning a batch, the loader enqueues the indices for a following batch.
Worker subprocess initialization in the constructor:

self.index_queues = []
self.workers = []
for i in range(self.num_workers):
    index_queue = multiprocessing.Queue()  # 1. one queue per worker, holding the indices that worker should process
    index_queue.cancel_join_thread()
    w = multiprocessing.Process(
        target=_utils.worker._worker_loop,  # the loop each worker subprocess keeps running
        args=(self.dataset, index_queue,
              self.worker_result_queue, self.done_event,  # 2. self.worker_result_queue: a single queue shared by all workers for returning batch data
              self.collate_fn, base_seed + i,
              self.worker_init_fn, i))
    w.daemon = True
    # NB: Process.start() actually take some time as it needs to
    #     start a process and pass the arguments over via a pipe.
    #     Therefore, we only add a worker to self.workers list after
    #     it started, so that we do not call .join() if program dies
    #     before it starts, and __del__ tries to join but will get:
    #     AssertionError: can only join a started process.
    w.start()
    self.index_queues.append(index_queue)
    self.workers.append(w)

3.1 index_queue: indices of the data to process

Each worker has its own index_queue (dataloader.py#L544-L552).
Each worker takes the indices it should process from its index_queue (dataloader.py#L124).
Before the dataloader returns a batch, it first puts another set of indices into an index_queue; see _process_next_batch:

def _process_next_batch(self, batch):
    self.rcvd_idx += 1
    self._put_indices()  # first enqueue the indices for a following batch
    if isinstance(batch, ExceptionWrapper):
        raise batch.exc_type(batch.exc_msg)
    return batch  # then return the current batch

_put_indices distributes indices round-robin across the index_queues of the different workers (dataloader.py#L644-L652).
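
In PyTorch 1.0.1 that function does roughly the following (a paraphrased sketch, not a verbatim copy of the source):

def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    indices = next(self.sample_iter, None)  # next batch of indices from the batch_sampler
    if indices is None:
        return
    # round-robin: each call targets the index_queue of the next worker
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    self.batches_outstanding += 1
    self.send_idx += 1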

The full dataloader __next__ function:

def __next__(self):
    if self.num_workers == 0:  # same-process loading
        indices = next(self.sample_iter)  # may raise StopIteration
        batch = self.collate_fn([self.dataset[i] for i in indices])
        if self.pin_memory:
            batch = pin_memory_batch(batch)
        return batch

    # check if the next sample has already been generated
    if self.rcvd_idx in self.reorder_dict:
        batch = self.reorder_dict.pop(self.rcvd_idx)
        return self._process_next_batch(batch)  # 5. this index was already fetched earlier; return it directly

    if self.batches_outstanding == 0:
        self._shutdown_workers()
        raise StopIteration

    while True:  # 1. keep looping until the batch with the expected index arrives
        assert (not self.shutdown and self.batches_outstanding > 0)
        idx, batch = self._get_batch()  # 2. fetch one result from worker_result_queue
        self.batches_outstanding -= 1
        if idx != self.rcvd_idx:
            # store out-of-order samples
            self.reorder_dict[idx] = batch  # 3. wrong index: stash it for later
            continue
        return self._process_next_batch(batch)  # 4. enqueue indices for a following batch, then return this one

3.2 worker_result_queue: returning results

Each worker keeps running the loop _worker_loop, where worker_result_queue is passed in as the data_queue argument of _worker_loop (dataloader.py#L544-L552); see below:

def _worker_loop(dataset, index_queue, data_queue, done_event,
                 collate_fn, seed, init_fn, worker_id):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    try:
        global _use_shared_memory
        _use_shared_memory = True

        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)  # fetch the indices to process from index_queue
            except queue.Empty:
                continue
            if r is None:
                # Received the final signal
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Done event is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])  # 1. fetch the samples for these indices
            except Exception:
                # It is important that we don't store exc_info in a variable,
                # see NOTE [ Python Traceback Reference Cycle Problem ]
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                # 2. no exception: put the assembled batch into the result queue
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
