ColossalAI/applications/Chat/coati/trainer/utils.py

20 lines
393 B
Python
Raw Normal View History

from typing import Any
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
2023-03-28 12:25:36 +00:00
def is_rank_0() -> bool:
    """Report whether the current process should act as rank 0.

    Returns:
        ``True`` if ``torch.distributed`` has no initialized process group
        (single-process run), or if this process's rank from
        ``dist.get_rank()`` is 0; ``False`` otherwise.
    """
    # Without an initialized process group there is only one process,
    # so it is treated as rank 0 (and get_rank() is never called).
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0
def to_device(x: Any, device: torch.device) -> Any:
    """Move every tensor leaf of ``x`` to ``device``.

    ``x`` may be a single value or a nested container that ``tree_map``
    can traverse. Each leaf that is a ``torch.Tensor`` is replaced by
    ``leaf.to(device)``. Every other leaf is returned unchanged.

    Args:
        x: Value or nested container whose tensors should be moved.
        device: Target device passed to ``Tensor.to``.

    Returns:
        A structure mirroring ``x`` (as rebuilt by ``tree_map``) with
        tensor leaves moved to ``device``.
    """

    def _move_leaf(leaf: Any) -> Any:
        # Only tensors know how to change device; pass everything else through.
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move_leaf, x)