ColossalAI/applications/Chat/coati/trainer/utils.py

from typing import Any

import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader


class CycledDataLoader:
    """
    Why do we need this class?
    In version 4da324cd60, ``prompts = next(iter(self.prompt_dataloader))`` was used to sample a batch of prompts/pretrain data.
    However, this may be inefficient, as it re-initializes the dataloader (and its workers) on every call.
    NOTE: ``next(iter(dataloader))`` is not equivalent to ``for batch in dataloader: break``; the two behave slightly differently.
    """

    def __init__(
        self,
        dataloader: DataLoader,
    ) -> None:
        self.dataloader = dataloader
        # number of batches drawn from the current pass over the dataloader
        self.count = 0
        self.dataloader_iter = None

    def next(self):
        # defer initialization until the first batch is requested
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)

        self.count += 1
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            # the current pass is exhausted: restart and return the first batch
            self.count = 0
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)
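
# Usage sketch (illustrative; names like `prompt_dataset` and `num_steps` are
# hypothetical, not part of this file). CycledDataLoader keeps one iterator
# alive across calls, so workers are not re-created for every batch:
#
#     prompt_dataloader = CycledDataLoader(DataLoader(prompt_dataset, batch_size=8))
#     for _ in range(num_steps):
#         prompts = prompt_dataloader.next()  # cycles indefinitely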

def is_rank_0() -> bool:
    # rank 0 also covers the non-distributed (single-process) case
    return not dist.is_initialized() or dist.get_rank() == 0
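
# Usage sketch (illustrative): guard once-per-job side effects such as logging
# or checkpoint saving behind is_rank_0(); `loss` here is hypothetical.
#
#     if is_rank_0():
#         print(f"loss: {loss.item():.4f}")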

def to_device(x: Any, device: torch.device) -> Any:
    """Recursively move every tensor in a (possibly nested) structure to `device`."""

    def _to(t: Any):
        if isinstance(t, torch.Tensor):
            return t.to(device)
        return t

    return tree_map(_to, x)
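
# Usage sketch (illustrative): since to_device relies on tree_map, tensors are
# moved even inside nested containers, while non-tensor values pass through.
#
#     batch = {"input_ids": torch.zeros(2, 4, dtype=torch.long), "meta": ["keep"]}
#     batch = to_device(batch, torch.device("cuda:0"))  # assumes a CUDA device is available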