mirror of https://github.com/hpcaitech/ColossalAI
33 lines
1.3 KiB
Python
import torch.distributed as dist

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


class barrier_context():
    """
    This context manager allows one process to execute while blocking all
    other processes in the same process group. It is often useful for downloading,
    where only one process should download the files in order to prevent file corruption.

    Args:
        executor_rank (int): the rank of the process which executes without blocking; all other processes will be blocked
        parallel_mode (ParallelMode): the parallel mode corresponding to a process group

    Usage:
        with barrier_context():
            dataset = CIFAR10(root='./data', download=True)
    """

    def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
        # the class name is lowercase by convention
        current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
        self.should_block = current_rank != executor_rank
        self.group = gpc.get_group(parallel_mode=parallel_mode)

    def __enter__(self):
        # non-executor ranks wait here until the executor reaches the barrier in __exit__
        if self.should_block:
            dist.barrier(group=self.group)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # the executor calls the matching barrier on exit, releasing the waiting ranks
        if not self.should_block:
            dist.barrier(group=self.group)
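

# Illustrative usage sketch (not part of the ColossalAI source): it assumes the
# distributed context has already been initialized, e.g. via
# colossalai.launch_from_torch, and that torchvision is installed; the CIFAR10
# dataset and the './data' path are placeholders.
if __name__ == '__main__':
    from torchvision.datasets import CIFAR10

    # rank 0 enters the body immediately and downloads the data; every other
    # rank blocks in __enter__ until rank 0 leaves the context and reaches the
    # barrier in __exit__, after which the remaining ranks run the body, find
    # the files already on disk, and skip the download
    with barrier_context(executor_rank=0):
        dataset = CIFAR10(root='./data', download=True)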