import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator

import torch.distributed as dist


@contextmanager
def shared_tempdir() -> Iterator[str]:
    """
    A temporary directory that is shared across all processes.

    When a ``torch.distributed`` process group is initialized, rank 0 creates
    the directory and broadcasts its path, so every rank yields the same
    string. Only the path is broadcast: the other ranks must be able to reach
    that path themselves for the directory to be usable there.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized, this falls back to a plain local temporary directory, so the
    same calling code also works in single-process runs.

    Yields:
        The path of the temporary directory. The directory is removed when
        the context exits (by rank 0 in the distributed case).
    """
    if not (dist.is_available() and dist.is_initialized()):
        # No process group: there is nobody to broadcast to or synchronize
        # with, and ``dist.get_rank()`` would raise. Use a local directory.
        with tempfile.TemporaryDirectory() as tempdir:
            yield tempdir
        return

    # Only rank 0 owns a real directory; the other ranks enter a no-op
    # context (yielding ``None``) and receive the path via the broadcast.
    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
    with ctx_fn() as tempdir:
        try:
            obj = [tempdir]
            dist.broadcast_object_list(obj, src=0)
            tempdir = obj[0]    # use the same directory on all ranks
            yield tempdir
        finally:
            # The barrier runs before the ``with`` block exits, so rank 0 does
            # not delete the directory until every rank has reached this point.
            dist.barrier()