Browse Source

feat: add `DataPrefetcher`

pull/5817/head
Wenhao Chen 8 months ago committed by アマデウス
parent
commit
93aaa21d4a
  1. 28
      colossalai/zero/low_level/_utils.py

28
colossalai/zero/low_level/_utils.py

@ -1,11 +1,36 @@
import math
from typing import Optional
from typing import Iterable, Optional
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class DataPrefetcher:
    """
    Asynchronously prefetch data from the loader via a separate stream to improve the performance.

    The prefetcher always holds one item ahead of the consumer: the constructor
    fetches the first item, and each call to ``next`` hands out the buffered
    item and immediately fetches the following one.

    NOTE(review): advancing the loader only overlaps with compute if producing
    an item enqueues CUDA work (e.g. non-blocking host-to-device copies) while
    the prefetch stream is current -- confirm against the loaders used here.
    """

    def __init__(self, loader: Iterable):
        """Wrap ``loader`` and eagerly fetch its first item.

        Args:
            loader: Iterable to draw items from; ``iter()`` is called on it once.
        """
        self.loader = iter(loader)
        # Dedicated stream that is made current while the loader is advanced.
        self.stream = torch.cuda.Stream()
        self.preload()

    def next(self):
        """Return the buffered item and start fetching the one after it.

        Returns:
            The item stored by the previous ``preload`` call, or ``None`` once
            the loader has been exhausted.
        """
        # Block until everything queued on the prefetch stream has finished,
        # so the buffered item is safe to hand out.
        self.stream.synchronize()
        data = self.data
        self.preload()
        return data

    def preload(self):
        """Fetch the next item from the loader into ``self.data``.

        ``self.data`` is set to ``None`` when the loader has no more items.
        """
        with torch.cuda.stream(self.stream):
            # Only the iterator advance can signal exhaustion, so only that
            # call is guarded.
            try:
                self.data = next(self.loader)
            except StopIteration:
                self.data = None
def flatten(input_):
    """Flatten a collection of dense tensors into a single 1D tensor.

    Thin wrapper around ``torch._utils._flatten_dense_tensors``.

    Args:
        input_: Iterable of dense tensors to concatenate.
            NOTE(review): the torch helper assumes all tensors share one
            dense type -- confirm callers respect that.

    Returns:
        The tensor produced by ``_flatten_dense_tensors`` for ``input_``.
    """
    return _flatten_dense_tensors(input_)
@ -190,6 +215,7 @@ def calculate_global_norm_from_list(norm_list):
total_norm += norm**2.0
return math.sqrt(total_norm)
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When

Loading…
Cancel
Save