[example] opt does not depend on Titans (#1811)

pull/1807/head
Jiarui Fang 2022-11-08 12:02:20 +08:00 committed by GitHub
parent 6fa71d65d3
commit 350ccc0481
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 11 deletions

View File

@ -0,0 +1,32 @@
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class barrier_context():
"""
This context manager is used to allow one process to execute while blocking all
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
parallel_mode (ParallelMode): the parallel mode corresponding to a process group
Usage:
with barrier_context():
dataset = CIFAR10(root='./data', download=True)
"""
def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
# the class name is lowercase by convention
current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
self.should_block = current_rank != executor_rank
self.group = gpc.get_group(parallel_mode=parallel_mode)
def __enter__(self):
if self.should_block:
dist.barrier(group=self.group)
def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.should_block:
dist.barrier(group=self.group)

File diff suppressed because one or more lines are too long

View File

@ -3,3 +3,4 @@ torch >= 1.8.1
datasets >= 1.8.0 datasets >= 1.8.0
sentencepiece != 0.1.92 sentencepiece != 0.1.92
protobuf protobuf
accelerate == 0.13.2

View File

@ -32,9 +32,9 @@ import datasets
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.utils import set_seed from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset from datasets import load_dataset
from packaging import version from packaging import version
from titans.utils import barrier_context
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
from utils import colo_memory_cap from utils import colo_memory_cap