2023-05-19 11:42:31 +00:00
|
|
|
import tempfile
|
|
|
|
from contextlib import contextmanager, nullcontext
|
|
|
|
from typing import Iterator
|
|
|
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def shared_tempdir() -> Iterator[str]:
    """
    A temporary directory that is shared across all processes.

    Rank 0 creates (and on exit deletes) the directory; every other rank
    receives the same path via a broadcast, so all ranks yield an identical
    directory name. A barrier on exit prevents rank 0 from removing the
    directory while other ranks may still be using it.
    """
    # Only rank 0 owns a real TemporaryDirectory; the remaining ranks enter
    # a no-op context and get the path from the broadcast below.
    if dist.get_rank() == 0:
        ctx = tempfile.TemporaryDirectory()
    else:
        ctx = nullcontext()
    with ctx as tempdir:
        try:
            # Broadcast rank 0's directory name so every rank sees the same path.
            payload = [tempdir]
            dist.broadcast_object_list(payload, src=0)
            yield payload[0]
        finally:
            # Ensure all ranks are done before rank 0's cleanup deletes the dir.
            dist.barrier()
|