torch.utils.data: analyzing the whole data processing pipeline
2022-06-21 15:36:00 [OpenMMLab official account]
0 Preface
The source code discussed in this article is based on PyTorch 1.7.
Iterators
Understanding Python iterators is the key to understanding PyTorch's torch.utils.data module.

The Dataset, Sampler and DataLoader classes all rely on Python magic methods, including __len__(self), __getitem__(self) and __iter__(self):

- __len__(self): defines the behavior of a len() call, returning the number of elements in the container
- __getitem__(self): defines the behavior of fetching a specified element from the container, equivalent to self[key], i.e. it allows objects of the class to be indexed
- __iter__(self): defines the behavior when iterating over the elements of the container
Iteration is something like a loop: each repetition of the process is called one iteration, and the result of each iteration is used as the starting point of the next. A container that provides an iteration method is called an iterator; the ones we usually encounter are sequences (lists, tuples and strings) and dictionaries, all of which support iteration.

Two magic methods are involved in implementing an iterator: __iter__(self) and __next__(self).

If a container is an iterator, it must implement the __iter__(self) magic method, which returns an iterator (usually the container itself). It must also implement the __next__(self) magic method, because that is what determines the iteration rule.
```python
class Fibs:
    def __init__(self, n=20):
        self.a = 0
        self.b = 1
        self.n = n

    def __iter__(self):
        return self

    def __next__(self):
        self.a, self.b = self.b, self.a + self.b
        if self.a > self.n:
            raise StopIteration
        return self.a

fibs = Fibs()
for each in fibs:
    print(each)

# Output:
# 1 1 2 3 5 8 13
```

Generally speaking, an iterator has the following characteristics (a minimal illustration follows the list):
- an iterator is an object
- an iterator can be passed to next(), which returns one value
- an iterator can be passed to iter(), which returns an iterator (possibly itself)
- successive calls to next() return a series of values
- when the iteration is exhausted, a StopIteration exception is raised
- an iterator may also be endless: as long as next() is called, it keeps returning values
- in Python, the built-in next() function calls the object's __next__() method
- in Python, the built-in iter() function calls the object's __iter__() method
- any object that implements the iterator protocol can be traversed by a for loop until it is exhausted
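A minimal illustration of these points using the built-in iter() and next() on a plain list:

```python
nums = [1, 2, 3]
it = iter(nums)     # iter() calls nums.__iter__() and returns a list_iterator

print(next(it))     # 1 -- next() calls it.__next__()
print(next(it))     # 2
print(next(it))     # 3
# one more next(it) would raise StopIteration

for x in nums:      # a for loop calls iter(nums) once, then next() until StopIteration
    print(x)
```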
Having learned what an iterator is, we can start reading the torch.utils.data module.

In torch.utils.data, the key pieces are the Dataset, Sampler and DataLoader classes, supplemented by components such as collate, fetch and pin_memory that support specific functionality.
1 Dataset
The Dataset is responsible for wrapping the raw data source into a data structure Python can work with, and it must provide an interface for retrieving individual samples.

Datasets come in two flavors: map-style datasets and iterable-style datasets.
1.1 Map-style dataset
torch.utils.data.Dataset
It is a Dataset that implements __getitem__() and __len__(), representing a map from (possibly non-integer) indices/keys to data samples. When accessed, dataset[idx] returns the sample corresponding to index idx.

Most of the datasets we use are of the map-style type. Its interface is defined as follows:
```python
class Dataset(Generic[T_co]):
    # Generic is an abstract base class for generic types.

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
```

All Dataset classes defined in PyTorch are subclasses of this class.
For typical computer vision tasks, we usually also perform operations such as resize, crop and flip inside the dataset.

It is worth mentioning that the PyTorch source code provides no default implementation of __len__(), because defaults such as `return NotImplemented` or `raise NotImplementedError()` each have their own problems; the comments in the source code point this out.
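As a concrete illustration (a minimal sketch, not from the original article), a custom map-style dataset might look like the following; the file list, labels and transform are hypothetical placeholders, and PIL is assumed to be available:

```python
import os
from PIL import Image
from torch.utils.data import Dataset

class ImageFolderDataset(Dataset):
    """A minimal map-style dataset: index -> (image, label). Hypothetical example."""

    def __init__(self, root, samples, transform=None):
        # `samples` is assumed to be a list of (relative_path, label) pairs
        self.root = root
        self.samples = samples
        self.transform = transform

    def __getitem__(self, index):
        path, label = self.samples[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)   # e.g. resize / crop / flip
        return img, label

    def __len__(self):
        return len(self.samples)
```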
1.2 Iterable-style dataset
torch.utils.data.IterableDataset
It is a Dataset that implements __iter__() to produce data. This type of dataset is particularly suitable when random reads are expensive or even impossible, and when the batch size depends on the data being read. Its interface is defined as follows:
```python
class IterableDataset(Dataset[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])
```

In particular, when the DataLoader's num_workers > 0, each worker holds its own copy of the dataset object, so each replica must be configured independently to prevent the workers from producing duplicate data. At the same time, the data loading order is entirely controlled by the user-defined iteration, which makes it easier to implement chunked reading and dynamic batch sizes (for example, by yielding a whole batch of samples at a time).
1.3 Other Dataset types
Besides map-style and iterable-style datasets, PyTorch provides several other Dataset subclasses built on top of them:

- torch.utils.data.ConcatDataset: concatenates multiple datasets into one; it is what Dataset's __add__() method returns
- torch.utils.data.ChainDataset: chains multiple IterableDatasets; it is what IterableDataset's __add__() method returns
- torch.utils.data.Subset: the subset of a dataset at a specified sequence of indices
```python
class Subset(Dataset[T_co]):
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
```

torch.utils.data.TensorDataset: wraps one or more tensors into a dataset; each sample is obtained by indexing the tensors along the first dimension.
```python
class TensorDataset(Dataset):
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
```

2 Sampler
torch.utils.data.Sampler is responsible for providing a way to iterate over the indices of all elements in a dataset. Samplers can be user-defined or one of those PyTorch provides. The base class interface is defined as follows:
```python
class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
        :class:`~torch.utils.data.DataLoader`, but is expected in any
        calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
```

In particular, the __len__() method is not strictly required, but it must be defined whenever the DataLoader needs to compute len(); the comments in the source code point this out as well.
Similarly, PyTorch provides several Sampler subclasses built on this base class (a short sketch of what these samplers yield follows the list):

- torch.utils.data.SequentialSampler: samples elements sequentially, always in the same order
- torch.utils.data.RandomSampler: samples elements at random, with or without replacement (configurable)
- torch.utils.data.SubsetRandomSampler: samples elements randomly, without replacement, from a given list of indices
- torch.utils.data.WeightedRandomSampler: samples elements from [0, ..., len(weights)-1] with the given probabilities (weights)
- torch.utils.data.BatchSampler: wraps another sampler and yields batch-sized lists of indices
- torch.utils.data.DistributedSampler: restricts data loading to a subset of the dataset; used together with torch.nn.parallel.DistributedDataParallel, where each process passes a DistributedSampler instance to its DataLoader as the sampler
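As a quick sketch (not from the original article), iterating a sampler directly shows that it only yields indices, never data; the toy data source below is a placeholder whose length is all that matters:

```python
from torch.utils.data import (SequentialSampler, RandomSampler,
                              BatchSampler, WeightedRandomSampler)

data_source = list(range(6))   # most samplers only need len(data_source)

print(list(SequentialSampler(data_source)))   # [0, 1, 2, 3, 4, 5]
print(list(RandomSampler(data_source)))       # e.g. [3, 0, 5, 1, 4, 2]

# BatchSampler wraps another sampler and yields lists of indices
print(list(BatchSampler(SequentialSampler(data_source),
                        batch_size=4, drop_last=False)))
# [[0, 1, 2, 3], [4, 5]]

# WeightedRandomSampler draws indices in [0, len(weights)-1] with the given weights
weights = [0.1, 0.1, 0.1, 0.1, 0.8, 0.8]
print(list(WeightedRandomSampler(weights, num_samples=4, replacement=True)))
# e.g. [4, 5, 4, 2]
```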
3 DataLoader
torch.utils.data.DataLoader is the core of PyTorch data loading. It is responsible for loading the data, supports both map-style and iterable-style Datasets, supports single-process and multi-process loading, and lets you configure the loading order, batch size, memory pinning and more. Its interface is defined as follows:
```python
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
```

The meaning of each parameter is summarized in the table below:
| Parameter | Meaning | Default | Type |
|---|---|---|---|
| dataset | the dataset from which to load the data | – | Dataset |
| batch_size | how many samples to load per batch | 1 | int |
| shuffle | if True, a RandomSampler is used to draw indices randomly | False | bool |
| sampler | defines the strategy for drawing samples from the dataset; if specified, shuffle must be False (otherwise it would conflict with the implicit RandomSampler) | None | Sampler, Iterable |
| batch_sampler | like sampler, but yields one batch of indices at a time (typically a BatchSampler); mutually exclusive with batch_size, shuffle, sampler and drop_last | None | Sampler, Iterable |
| num_workers | number of subprocesses to use for data loading; 0 means data is loaded in the main process | 0 | int |
| collate_fn | merges a list of samples into a batch; used when batching samples fetched from a map-style dataset | None | callable |
| pin_memory | if True, the DataLoader copies tensors into CUDA pinned memory before returning them | False | bool |
| drop_last | if True, drops the last incomplete batch when the dataset size is not divisible by the batch size; if False, the last batch is simply smaller | False | bool |
| timeout | if positive, the timeout for collecting a batch from a worker; must be non-negative, and an error is raised if no data arrives within this time | 0 | numeric |
| worker_init_fn | if not None, called in each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input | None | callable |
| prefetch_factor | number of samples loaded in advance by each worker | 2 | int |
| persistent_workers | if True, the DataLoader does not shut down the worker processes after the dataset has been iterated once | False | bool |
From these parameter definitions, we can see that the DataLoader mainly supports the following features (a construction sketch follows this list):

- loading both map-style and iterable-style datasets; the main parameter involved is dataset
- customizing the data loading order; the main parameters involved are shuffle, sampler, batch_sampler and collate_fn
- automatically collating the data into batches; the main parameters involved are batch_size, batch_sampler, collate_fn and drop_last
- single-process and multi-process data loading; the main parameters involved are num_workers and worker_init_fn
- automatic memory pinning; the main parameter involved is pin_memory
- data prefetching; the main parameter involved is prefetch_factor
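A minimal construction sketch (not from the original article; the tensor dataset and parameter values are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# a small map-style dataset wrapping tensors (100 samples of dim 5, plus labels)
inputs = torch.randn(100, 5)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(inputs, labels)

if __name__ == "__main__":   # guard needed for spawn-based multi-process loading
    loader = DataLoader(dataset,
                        batch_size=16,                           # automatic batching
                        shuffle=True,                            # builds a RandomSampler internally
                        num_workers=2,                           # 2 loader subprocesses
                        pin_memory=torch.cuda.is_available(),    # pin only when CUDA exists
                        drop_last=False)                         # keep the final, smaller batch

    for x, y in loader:
        print(x.shape, y.shape)   # torch.Size([16, 5]) torch.Size([16]); last batch has 4
        break
```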
3.1 The relationship between the three (Dataset, Sampler, DataLoader)

From the responsibilities described above, their relationship is not hard to infer:

- Set up the Dataset: wrap the raw data source as a Dataset class and expose an interface for retrieving samples.
- Set up the Sampler: decide the sampling strategy. Being able to fetch elements from the Dataset is not enough; the Sampler tells the program the strategy for drawing them from the Dataset.
- Pass the configured Dataset and Sampler into the DataLoader, optionally setting parameters such as shuffle and batch_size. The resulting DataLoader object then makes it easy to iterate over the dataset.

In conclusion, the DataLoader is responsible for the overall scheduling: it directs the Sampler to define how indices are traversed, and then uses those indices to fetch elements from the Dataset; this is how a given dataset gets traversed. A minimal wiring sketch follows.
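A sketch of the three pieces wired together (illustrative only; SubsetRandomSampler stands in for any custom sampling strategy):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))   # 10 samples

# Sampler: only draw the even indices, in random order
sampler = SubsetRandomSampler(indices=[0, 2, 4, 6, 8])

# DataLoader: schedules the Sampler, fetches from the Dataset, and batches the result
loader = DataLoader(dataset, batch_size=2, sampler=sampler)   # shuffle must stay False

for (batch,) in loader:
    print(batch.squeeze(1))   # e.g. tensor([4., 0.]), tensor([8., 2.]), tensor([6.])
```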
3.2 Batching

3.2.1 Automatic batching (the default)

The DataLoader supports automatically collating the fetched samples into batches via the batch_size, drop_last and batch_sampler arguments.

batch_size and drop_last specify how the DataLoader obtains batches of dataset keys. For map-style datasets in particular, users can instead specify batch_sampler, which yields a list of keys at a time.

When fetching the sampled data using the indices produced by the sampler, the DataLoader uses the collate_fn argument to collate the list of samples into a batch. Abstracting this process, it is roughly equivalent to:
```python
# For map-style datasets
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

# For iterable-style datasets
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
```

3.2.2 Disabling automatic batching
When users want to handle batching manually in the dataset code, or only want to load individual samples, they can set both batch_size and batch_sampler to None, which disables automatic batching. In that case, each sample produced by the Dataset is passed to collate_fn directly. Abstracting this process, it is roughly equivalent to:
```python
# For map-style datasets
for index in sampler:
    yield collate_fn(dataset[index])

# For iterable-style datasets
for data in iter(dataset):
    yield collate_fn(data)
```

3.2.3 collate_fn
When automatic batching is disabled, collate_fn acts on a single data sample; the default implementation merely converts NumPy arrays into PyTorch tensors.

When automatic batching is enabled, collate_fn acts on a list of data samples and collates them into a batch. The default implementation roughly does three things:

- prepends a new batch dimension (usually the first dimension)
- automatically converts NumPy arrays and Python numeric values into PyTorch tensors
- preserves the data structure: for example, if each sample is a dict, the output is a dict with the same keys whose values are batched tensors (or lists, when conversion is not possible); lists, tuples and namedtuples are handled in the same way

A custom collate_fn can be used to customize collation, for example padding variable-length sequences to the maximum length in the batch, or adding support for custom data types; a padding sketch follows.
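A minimal sketch of such a padding collate_fn (illustrative only, assuming each sample is a (sequence_tensor, label) pair):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch is a list of (sequence, label) pairs with varying sequence lengths
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(seqs, batch_first=True, padding_value=0)  # (B, max_len)
    return padded, torch.tensor(labels), lengths

# toy variable-length dataset: a plain list works as a map-style dataset
data = [(torch.tensor([1, 2, 3]), 0),
        (torch.tensor([4, 5]), 1),
        (torch.tensor([6]), 0)]

loader = DataLoader(data, batch_size=3, collate_fn=pad_collate)
padded, labels, lengths = next(iter(loader))
print(padded.shape)   # torch.Size([3, 3])
print(lengths)        # tensor([3, 2, 1])
```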
3.3 Multi-process loading (multi-process)

To avoid blocking the computation code while data is being loaded, PyTorch provides a simple switch: setting num_workers to a positive integer enables multi-process data loading, while setting it to 0 performs data loading in the main process.
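For example (a sketch, not from the original article), enabling worker processes and seeding each worker's NumPy RNG via worker_init_fn — a common pattern when transforms rely on NumPy randomness:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # called once inside each worker subprocess, with the worker id as input
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)

dataset = TensorDataset(torch.randn(64, 3), torch.zeros(64))

if __name__ == "__main__":          # required for spawn-based multiprocessing
    loader = DataLoader(dataset,
                        batch_size=8,
                        num_workers=4,                # 4 loader subprocesses
                        worker_init_fn=seed_worker)
    for x, y in loader:
        pass
```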
4. Single process
In single-process mode, data fetching is done in the same process in which the DataLoader is initialized, so data loading may block computation.

However, this mode may be preferred when the resources used for sharing data between processes (e.g. shared memory, file descriptors) are limited, or when the whole dataset is small and can be loaded entirely into memory.

In addition, single-process loading usually produces more readable error traces, which is useful for debugging.
5. Multi process
In multi-process mode, every time an iterator of the DataLoader is created (for example when enumerate(dataloader) is called), num_workers worker processes are created. dataset, collate_fn and worker_init_fn are passed to every worker, and each worker runs in a separate process.

For map-style datasets, the main process uses the Sampler to produce indices and sends them to the workers, so shuffling is done in the main process.

For iterable-style datasets, every worker holds its own copy of the dataset object and iterates it in its own process, so the copies must behave differently to keep the workers from producing duplicate data; torch.utils.data.get_worker_info() is commonly used to help with this.

Here, torch.utils.data.get_worker_info() returns information about the worker process (id, dataset, num_workers, seed), and returns None when called in the main process.
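A sketch of this sharding pattern (adapted from the standard get_worker_info() idiom, not from the original article):

```python
import math
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class RangeIterableDataset(IterableDataset):
    """Streams the integers in [start, end); each worker gets a disjoint shard."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:                      # single-process loading
            lo, hi = self.start, self.end
        else:                                 # split the range across workers
            per_worker = int(math.ceil((self.end - self.start) / info.num_workers))
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))

if __name__ == "__main__":
    ds = RangeIterableDataset(0, 10)
    print(list(DataLoader(ds, num_workers=2, batch_size=None)))
    # all of 0..9 exactly once, interleaved across the two workers
```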
Note that it is generally not recommended to return CUDA tensors from multi-process loading, because there are many subtleties in using CUDA and sharing CUDA tensors across processes (the documentation notes that as long as the receiving process keeps a reference to the tensor, the sending process must keep the original tensor alive). Instead, it is recommended to use pin_memory=True to transfer data quickly to a CUDA-enabled GPU. In short, do not return CUDA tensors from worker processes.
6 Pinned memory (memory pinning)

First, let us explain the concept of page-locked (pinned) memory.

Host memory comes in two forms: page-locked (pinned) and pageable. The contents of page-locked memory are never swapped out to the host's virtual memory (i.e. to disk); when page-locked memory is insufficient, data is stored in virtual memory instead. Host-to-GPU copies are much faster when they originate from pinned (page-locked) memory. CPU tensors and storages expose a pin_memory() method, which returns a copy of the object with its data placed in a pinned region.

By contrast, all GPU memory is effectively page-locked. When the host has enough RAM, you can set pin_memory=True: the generated tensors are then placed in page-locked host memory from the start, which makes transferring them to GPU memory faster. Also, because pin_memory's job is to copy tensors into CUDA pinned memory before they are returned, it is only useful in a CUDA-enabled environment.
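A typical usage sketch (assumptions: a CUDA device is available; the dataset is a placeholder). Pinned batches allow the host-to-GPU copy to be issued with non_blocking=True:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,)))
loader = DataLoader(dataset, batch_size=32, pin_memory=True)

device = torch.device("cuda")
for x, y in loader:
    # batches already live in pinned memory, so the copy can overlap with host work
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    ...
```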
PyTorch's built-in pin_memory function is shown below; it handles most common Python data types:

```python
def pin_memory(data):
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    elif isinstance(data, string_classes):
        return data
    elif isinstance(data, container_abcs.Mapping):
        return {k: pin_memory(sample) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory(sample) for sample in data))
    elif isinstance(data, container_abcs.Sequence):
        return [pin_memory(sample) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data
```

By default, if the pinning logic sees a batch of a custom type (which happens when collate_fn returns a custom batch type), or if each element of the batch is of a custom type, it will not recognize them and will return that batch (or those elements) without pinning the memory. To enable memory pinning for a custom batch or data type, define a pin_memory() method on the custom type, as follows:
```python
class SimpleCustomBatch:
    # A custom type that PyTorch's built-in pin_memory function does not handle
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())  # True
    print(sample.tgt.is_pinned())  # True
```

7 Prefetching (prefetch)
The DataLoader prefetches data according to the prefetch_factor argument (default 2).
```python
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        ...
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        ...
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()
```

As the source code shows, prefetching only applies to multi-process loading (the multi-process DataLoader code is analyzed later).
8 Code details
Let's now walk through the concrete call chain:
```python
for data, label in train_loader:
    ......
```

The for loop calls the DataLoader's __iter__(self) method, which yields the iterator used to traverse the dataset.
```python
class DataLoader(Generic[T_co]):
    ...
    def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()
```

Inside __iter__(self), the DataLoader calls self._get_iterator(), which chooses between a single-process and a multi-process iterator based on num_workers.
```python
class DataLoader(Generic[T_co]):
    ...
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
```

For clarity, we first consider only the single-process code. Below are the key snippets of _SingleProcessDataLoaderIter and its parent class _BaseDataLoaderIter:
```python
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        # Initialize attributes from the DataLoader's parameters
        # and validate the user input
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._index_sampler = loader._index_sampler
        ...

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()  # the key line: the data is fetched here
            self._num_yielded += 1
            ...
            return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler)  # len(_BaseDataLoaderIter) == len(self._index_sampler)

    def __getstate__(self):
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
```

_BaseDataLoaderIter is the parent class of all DataLoaderIter classes. Once the DataLoader has obtained the iterator, the for loop calls __next__() to get the next object, i.e. to perform the traversal; __next__ in turn calls _next_data() to fetch the data.
```python
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
```

As _SingleProcessDataLoaderIter shows, on top of what _BaseDataLoaderIter defines, it creates a _dataset_fetcher, passing in _dataset, _auto_collation, _collate_fn and other parameters that determine how data is fetched; its implementation is explained below.
When _next_data() is invoked, it first calls _next_index() to obtain an index, and then passes that index to _dataset_fetcher to get the corresponding samples.
```python
class DataLoader(Generic[T_co]):
    ...
    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

class _BaseDataLoaderIter(object):
    ...
    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        ...

    def _next_index(self):
        # _sampler_iter comes from _index_sampler
        return next(self._sampler_iter)  # may raise StopIteration
```

As this shows, the DataLoader provides the sampler (which may be the batch_sampler or some other Sampler subclass), and _SingleProcessDataLoaderIter iterates over that sampler to obtain the indices.
Now let's look at the fetcher. A fetcher takes an index and returns the corresponding elements; it supports both map-style datasets (via _MapDatasetFetcher) and iterable-style datasets (via _IterableDatasetFetcher), so that the DataLoader can fetch through a single, uniform interface and the code stays concise.

- For map-style datasets: the index is used directly as the map's key to get the corresponding sample (i.e. the value):
```python
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # A batch_sampler is present, so _auto_collation is True:
            # the fetcher receives the indices of a whole batch at once
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
```

- For iterable-style datasets: __init__ sets up an initial iterator over the dataset; in fetch, the index itself is largely unused:
```python
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # With a batch_sampler (i.e. auto_collation == True), keep advancing
            # the iterator and extract len(possibly_batched_index) samples
            # (i.e. the samples of one batch)
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            # With a plain sampler, advance the iterator and extract a single sample
            data = next(self.dataset_iter)
        return self.collate_fn(data)
```

Finally, the index is passed into the fetcher's fetch(), which returns the samples we want.
To sum up, the whole call chain is as follows:

loader.__iter__ --> self._get_iterator() --> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter --> __next__() --> self._next_data() --> self._next_index() --> next(self._sampler_iter), i.e. next(iter(self._index_sampler)) --> get index --> self._dataset_fetcher.fetch(index) --> get data
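The same chain can be observed by driving the iterator by hand (a small sketch; the tensor dataset is a placeholder):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=3)

it = iter(loader)     # loader.__iter__ -> _SingleProcessDataLoaderIter
batch = next(it)      # __next__ -> _next_data -> _next_index -> fetcher.fetch(index)
print(batch)          # [tensor([0., 1., 2.])]

print(next(it))       # [tensor([3., 4., 5.])]
print(next(it))       # [tensor([6., 7.])]
# one more next(it) would raise StopIteration, which is what ends a for loop
```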
For the multi-process case, borrowing the comments from the PyTorch source code, the workflow is as follows:
```python
# Our data model looks like this (queues are indicated with curly brackets):
#
#                main process                              ||
#                     |                                    ||
#               {index_queue}                              ||
#                     |                                    ||
#              worker processes                            ||     DATA
#                     |                                    ||
#            {worker_result_queue}                         ||     FLOW
#                     |                                    ||
#      pin_memory_thread of main process                   ||   DIRECTION
#                     |                                    ||
#               {data_queue}                               ||
#                     |                                    ||
#                data output                               \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
#   `pin_memory=False`.
```

First, the DataLoader uses multiprocessing to spawn several worker processes. The input and output of each subprocess flow through queues (instances of multiprocessing.Queue), the main ones being:

- index_queue: each worker's queue of task indices waiting to be processed
- _worker_result_queue: holds the processed tasks, tagged with their indices, returned by the workers
- data_queue: the output queue of data after the pin_memory processing
In addition, several important bookkeeping variables coordinate the work among the workers:

- _send_idx: the sending index, recording the idx of the batch most recently put into an index_queue
- _rcvd_idx: the receiving index, recording the idx of the batch most recently taken out of data_queue
- _task_info: a dict storing information about the data being produced; its key is the task idx (an integer index starting from 0) and its value is (worker_id,) or (worker_id, data), corresponding to data not yet fetched and data already fetched, respectively
- _tasks_outstanding: an integer, the number of tasks/batches that are ready (some of them may still be in preparation)

Each worker produces one batch at a time; before returning the batch, it takes the indices of the next batch to be processed from its queue. The constructor that initializes the worker subprocesses looks like this:
```python
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
        ...
        # the workers put their results into this queue; used for inter-process communication
        self._worker_result_queue = multiprocessing_context.Queue()
        ...
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # index queue: each subprocess gets its own queue of indices to process
            index_queue = multiprocessing_context.Queue()
            index_queue.cancel_join_thread()
            # _worker_loop's job: take indices from index_queue, process the data
            # through collate_fn, then put the resulting batch into the result queue
            # (the idx sent along with it is self._send_idx)
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,  # the loop run by each worker subprocess;
                                                    # it mainly puts (idx, data) into _worker_result_queue
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            # queue holding the results of the pin_memory operation
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue
        ...
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        # idx of the next task to be sent to workers: the sending index,
        # recording the idx of the batch put into index_queue
        self._send_idx = 0
        # idx of the next task to be returned in __next__: the receiving index,
        # recording the idx of the batch taken out of data_queue
        self._rcvd_idx = 0
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)       if data isn't fetched (outstanding)
        #                  \ (worker_id, data)  if data is already fetched (out-of-order)
        self._task_info = {}
        # _tasks_outstanding is the number of tasks/batches that are ready
        # (some of them may still be in preparation).
        # It starts at 0, is incremented in self._try_put_index()
        # and decremented in self._next_data()
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # this indicates status that a worker still has work to do *for this epoch*.
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _utils.worker._ResumeIteration):
                    resume_iteration_cnt -= 1
        ...
        # At initialization, prefetch by putting 2 * num_workers
        # (batch_idx, sampler_indices) pairs into the index queues:
        # each worker's index_queue gets two batches' worth of indices by default
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

    def _try_put_index(self):
        # self._prefetch_factor defaults to 2
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        # put in the task index together with the data indices
        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        # one more batch is being prepared
        self._tasks_outstanding += 1
        # _send_idx records how many times indices have been sent
        # from the sampler iterator to an index_queue
        self._send_idx += 1
```

Reading the data happens in _next_data(), with _process_data() used to return it:
```python
def _next_data(self):
    while True:
        # Keep the index of the next batch to return (_rcvd_idx) behind the
        # index of the last batch dispatched to the workers (_send_idx)
        while self._rcvd_idx < self._send_idx:
            info = self._task_info[self._rcvd_idx]
            worker_id = info[0]
            if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                break
            del self._task_info[self._rcvd_idx]
            self._rcvd_idx += 1
        else:
            # no valid `self._rcvd_idx` is found (i.e., didn't break)
            if not self._persistent_workers:
                self._shutdown_workers()
            raise StopIteration

        # Now `self._rcvd_idx` is the batch index we want to fetch

        # Check if the next sample has already been generated
        if len(self._task_info[self._rcvd_idx]) == 2:
            data = self._task_info.pop(self._rcvd_idx)[1]
            return self._process_data(data)

        assert not self._shutdown and self._tasks_outstanding > 0
        # call self._try_get_data() to take (idx, data) out of self._data_queue
        idx, data = self._get_data()
        # one fewer prepared batch
        self._tasks_outstanding -= 1

        if self._dataset_kind == _DatasetKind.Iterable:
            # Check for _IterableDatasetStopIteration
            if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                if self._persistent_workers:
                    self._workers_status[data.worker_id] = False
                else:
                    self._mark_worker_as_unavailable(data.worker_id)
                self._try_put_index()
                continue

        if idx != self._rcvd_idx:
            # store out-of-order samples
            self._task_info[idx] += (data,)
        else:
            del self._task_info[idx]
            return self._process_data(data)  # return the data

def _process_data(self, data):
    self._rcvd_idx += 1
    # as above: put the next index into the queue and update the bookkeeping variables
    self._try_put_index()
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data
```

In this way, the multi-process DataLoader completes data loading cooperatively through multiple workers.