mirror of https://github.com/hpcaitech/ColossalAI
[example] opt does not depend on Titans (#1811)
parent 6fa71d65d3
commit 350ccc0481
@@ -0,0 +1,32 @@
+import torch.distributed as dist
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+
+
+class barrier_context():
+    """
+    This context manager is used to allow one process to execute while blocking all
+    other processes in the same process group. This is often useful when downloading is required
+    as we only want to download in one process to prevent file corruption.
+    Args:
+        executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
+        parallel_mode (ParallelMode): the parallel mode corresponding to a process group
+    Usage:
+        with barrier_context():
+            dataset = CIFAR10(root='./data', download=True)
+    """
+
+    def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
+        # the class name is lowercase by convention
+        current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
+        self.should_block = current_rank != executor_rank
+        self.group = gpc.get_group(parallel_mode=parallel_mode)
+
+    def __enter__(self):
+        if self.should_block:
+            dist.barrier(group=self.group)
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        if not self.should_block:
+            dist.barrier(group=self.group)
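For reference, a minimal sketch of how this context manager is meant to be used in a distributed run; the launcher call and the CIFAR10 download below are assumptions based on the docstring, not part of this diff:

import colossalai
from torchvision.datasets import CIFAR10

from context import barrier_context

# assumed setup: the script is launched with torchrun, and colossalai initializes
# the global context (gpc) that barrier_context reads the rank and group from
colossalai.launch_from_torch(config={})

# rank 0 downloads the dataset; every other rank blocks at the barrier in __enter__
# and is released when rank 0 reaches the barrier in __exit__
with barrier_context(executor_rank=0):
    dataset = CIFAR10(root='./data', download=True)

Because __exit__ only calls dist.barrier on the executor rank, the waiting ranks are released exactly once the guarded block has finished, so they never see a half-downloaded file.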
File diff suppressed because one or more lines are too long
@@ -3,3 +3,4 @@ torch >= 1.8.1
 datasets >= 1.8.0
 sentencepiece != 0.1.92
 protobuf
+accelerate == 0.13.2
@@ -32,9 +32,9 @@ import datasets
 import torch
 import torch.distributed as dist
 from accelerate.utils import set_seed
+from context import barrier_context
 from datasets import load_dataset
 from packaging import version
-from titans.utils import barrier_context
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from utils import colo_memory_cap
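Since the local context.py exposes the same barrier_context interface that titans.utils provided, the call sites in the example script do not need to change. A rough sketch of the pattern these imports support (the seed value and dataset name are illustrative assumptions):

from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset

set_seed(42)    # accelerate helper, now covered by the pinned accelerate == 0.13.2

# download on one rank only; the other ranks wait until the files are on disk
with barrier_context():
    raw_datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')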