mirror of https://github.com/hpcaitech/ColossalAI
[example] opt does not depend on Titans (#1811)
parent
6fa71d65d3
commit
350ccc0481
|
@ -0,0 +1,32 @@
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
|
||||||
|
|
||||||
|
class barrier_context():
    """
    Let one process run the wrapped body first while the remaining processes
    of the same process group wait at a barrier.

    Every process whose rank differs from ``executor_rank`` blocks on
    ``dist.barrier`` when entering the ``with`` block. The executor process
    skips that barrier, runs the body, and joins the barrier when leaving the
    block, which releases the waiting processes so they can run the body
    themselves. Typical use is a download step that only one process should
    perform, to prevent file corruption.

    Args:
        executor_rank (int): rank of the process that runs without waiting;
            it is compared against ``gpc.get_local_rank`` for ``parallel_mode``.
        parallel_mode (ParallelMode): the parallel mode whose process group is
            used for the barrier.

    Usage:
        with barrier_context():
            dataset = CIFAR10(root='./data', download=True)
    """

    # NOTE: the class name is lowercase by convention

    def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
        # Only the executor rank is allowed through without waiting.
        self.should_block = gpc.get_local_rank(parallel_mode=parallel_mode) != executor_rank
        self.group = gpc.get_group(parallel_mode=parallel_mode)

    def __enter__(self):
        # The executor goes straight into the body; nothing is returned, so
        # ``with barrier_context() as x`` binds ``x`` to None.
        if not self.should_block:
            return
        # Every other rank waits here until the executor reaches its own
        # barrier call in __exit__.
        dist.barrier(group=self.group)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Waiting ranks already passed their barrier in __enter__.
        if self.should_block:
            return
        # The executor releases the waiting ranks. This runs whether or not
        # the body raised; the exception itself is not suppressed.
        dist.barrier(group=self.group)
|
File diff suppressed because one or more lines are too long
|
@ -3,3 +3,4 @@ torch >= 1.8.1
|
||||||
datasets >= 1.8.0
|
datasets >= 1.8.0
|
||||||
sentencepiece != 0.1.92
|
sentencepiece != 0.1.92
|
||||||
protobuf
|
protobuf
|
||||||
|
accelerate == 0.13.2
|
||||||
|
|
|
@ -32,9 +32,9 @@ import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from context import barrier_context
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from titans.utils import barrier_context
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from utils import colo_memory_cap
|
from utils import colo_memory_cap
|
||||||
|
|
Loading…
Reference in New Issue