mirror of https://github.com/hpcaitech/ColossalAI
22 lines · 609 B · Python
import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator

import torch.distributed as dist


@contextmanager
def shared_tempdir() -> Iterator[str]:
    """
    A temporary directory that is shared across all processes.
    """
    # Only rank 0 creates (and later cleans up) a real temporary directory;
    # every other rank enters a no-op context that yields None.
    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
    with ctx_fn() as tempdir:
        try:
            # Broadcast rank 0's path so that all ranks see the same directory.
            obj = [tempdir]
            dist.broadcast_object_list(obj, src=0)
            tempdir = obj[0]  # use the same directory on all ranks
            yield tempdir
        finally:
            # Keep every rank alive until all are done, so rank 0's context
            # manager does not delete the directory while others still use it.
            dist.barrier()
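
A minimal usage sketch, assuming torch.distributed is already initialized (e.g. the script was launched via torchrun); the function name and the checkpoint file name are hypothetical illustrations, not part of the mirrored file:

import os

import torch
import torch.distributed as dist


def roundtrip_checkpoint() -> None:
    # Assumes dist.init_process_group(...) has already been called.
    with shared_tempdir() as tempdir:
        # Every rank sees the same path; rank 0 owns creation and cleanup.
        ckpt_path = os.path.join(tempdir, "model.pt")  # hypothetical file name
        if dist.get_rank() == 0:
            torch.save({"step": 0}, ckpt_path)
        # Make sure the file exists before non-zero ranks try to read it.
        dist.barrier()
        state = torch.load(ckpt_path, map_location="cpu")
        assert state["step"] == 0

The explicit barrier after the save is needed because shared_tempdir only synchronizes on exit; within the block, ordering between writers and readers is the caller's responsibility.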