ColossalAI/tests/test_checkpoint_io/utils.py

import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator

import torch.distributed as dist


@contextmanager
def shared_tempdir() -> Iterator[str]:
    """
    A temporary directory that is shared across all processes.
    """
    # Only rank 0 actually creates a temporary directory; the other ranks
    # enter a nullcontext and receive the path via the broadcast below.
    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
    with ctx_fn() as tempdir:
        try:
            obj = [tempdir]
            dist.broadcast_object_list(obj, src=0)
            tempdir = obj[0]  # use the same directory on all ranks
            yield tempdir
        finally:
            # Keep rank 0 from cleaning up the directory before every rank
            # has finished using it.
            dist.barrier()
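
For context, a minimal usage sketch (not part of this file): it assumes torch.distributed has already been initialized on every participating rank, and the function and file names below are illustrative only. Rank 0 writes a checkpoint into the shared directory, and all ranks load it from the same path.

import os

import torch


def example_checkpoint_roundtrip():
    # Hypothetical example: requires an initialized process group
    # (e.g. via dist.init_process_group) on every rank.
    model = torch.nn.Linear(4, 4)
    with shared_tempdir() as tempdir:
        ckpt_path = os.path.join(tempdir, "model.pt")
        if dist.get_rank() == 0:
            torch.save(model.state_dict(), ckpt_path)
        dist.barrier()  # make sure the file exists before other ranks read it
        model.load_state_dict(torch.load(ckpt_path))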