mirror of https://github.com/hpcaitech/ColossalAI
20 lines
393 B
Python
20 lines
393 B
Python
from typing import Any
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
def is_rank_0() -> bool:
    """Return whether this process should behave as the rank-0 process.

    A process is treated as rank 0 when any of the following holds:
      * ``torch.distributed`` is not available in this PyTorch build,
      * no default process group has been initialized, or
      * its rank in the default process group is 0.

    Returns:
        bool: ``True`` if the caller is (or should act as) rank 0.
    """
    # PyTorch builds compiled without distributed support may not expose
    # ``is_initialized``/``get_rank``; such a process is necessarily
    # single-process, so it is rank 0.
    if not dist.is_available():
        return True
    return not dist.is_initialized() or dist.get_rank() == 0
|
|
|
|
|
|
def to_device(x: Any, device: torch.device) -> Any:
    """Return a copy of ``x`` with every tensor leaf moved to ``device``.

    ``x`` may be a single object or a nested container that
    ``torch.utils._pytree.tree_map`` can traverse. Each leaf that is a
    ``torch.Tensor`` is replaced by ``leaf.to(device)``. Every other leaf is
    passed through unchanged.

    Args:
        x: Object or nested container to process.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The structure rebuilt by ``tree_map`` with tensor leaves relocated.
    """

    def _move_leaf(leaf: Any) -> Any:
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move_leaf, x)
|