from typing import Any
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader


class CycledDataLoader:
    """Endlessly cycle over a ``DataLoader`` without re-creating it per batch.

    Why do we need this class?
    In version 4da324cd60, ``prompts = next(iter(self.prompt_dataloader))``
    was used to sample a batch of prompts/pretrain data. That pattern
    re-initializes the dataloader (including its workers) on every call,
    which may be inefficient. This wrapper keeps one live iterator and
    transparently restarts it once the underlying dataloader is exhausted.

    NOTE: ``next(iter(dataloader))`` is not equivalent to
    ``for batch in dataloader: break`` — the two cause slightly different
    behavior.
    """

    def __init__(
        self,
        dataloader: DataLoader,
    ) -> None:
        self.dataloader = dataloader
        # Number of next() calls since the iterator was last (re)started.
        self.count = 0
        # Built lazily on first use — see next().
        self.dataloader_iter = None

    def next(self):
        """Return the next batch, restarting a fresh pass on exhaustion."""
        # Deferred initialization: only create the iterator when needed.
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)

        self.count += 1
        try:
            batch = next(self.dataloader_iter)
        except StopIteration:
            # Current pass is exhausted: reset and begin a new one.
            self.count = 0
            self.dataloader_iter = iter(self.dataloader)
            batch = next(self.dataloader_iter)
        return batch


def is_rank_0() -> bool:
    """Return True on the global rank-0 process (or when torch.distributed
    is not initialized at all, i.e. single-process runs)."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


def to_device(x: Any, device: torch.device) -> Any:
    """Recursively move every tensor leaf of a pytree ``x`` to ``device``.

    Container structure (dicts, lists, tuples, ...) is preserved by
    ``tree_map``; non-tensor leaves are returned unchanged.
    """

    def _move(leaf: Any):
        # Only tensors carry device placement; leave other leaves as-is.
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move, x)