diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py
index de08ecf3d..f3faa1d55 100644
--- a/colossalai/zero/low_level/_utils.py
+++ b/colossalai/zero/low_level/_utils.py
@@ -1,11 +1,36 @@
 import math
-from typing import Optional
+from typing import Iterable, Optional
 
 import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 
+class DataPrefetcher:
+    """
+    Asynchronously prefetch data from the loader via a separate stream to improve the performance.
+    """
+
+    def __init__(self, loader: Iterable):
+        self.loader = iter(loader)
+        self.stream = torch.cuda.Stream()
+        self.preload()
+
+    def next(self):
+        self.stream.synchronize()
+        data = self.data
+        self.preload()
+        return data
+
+    def preload(self):
+        try:
+            with torch.cuda.stream(self.stream):
+                self.data = next(self.loader)
+        except StopIteration:
+            self.data = None
+            return
+
+
 def flatten(input_):
     return _flatten_dense_tensors(input_)
 
@@ -190,6 +215,7 @@ def calculate_global_norm_from_list(norm_list):
         total_norm += norm**2.0
     return math.sqrt(total_norm)
 
+
 def sync_tensor(flat_tensor, tensor_list):
     """
     Synchronize the flattened tensor and unflattened tensor list. When