mirror of https://github.com/hpcaitech/ColossalAI
feat: add `DataPrefetcher`
parent a1ab2d374e
commit 93aaa21d4a
@@ -1,11 +1,36 @@
 import math
-from typing import Optional
+from typing import Iterable, Optional

 import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


+class DataPrefetcher:
+    """
+    Asynchronously prefetch data from the loader via a separate stream to improve the performance.
+    """
+
+    def __init__(self, loader: Iterable):
+        self.loader = iter(loader)
+        self.stream = torch.cuda.Stream()
+        self.preload()
+
+    def next(self):
+        self.stream.synchronize()
+        data = self.data
+        self.preload()
+        return data
+
+    def preload(self):
+        try:
+            with torch.cuda.stream(self.stream):
+                self.data = next(self.loader)
+        except StopIteration:
+            self.data = None
+            return
+
+
 def flatten(input_):
     return _flatten_dense_tensors(input_)

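For context, a minimal usage sketch of the new `DataPrefetcher` in a training loop. The toy dataset, model, optimizer, and import of the class are assumptions for illustration, not part of this commit; a CUDA device is required because the prefetcher creates a CUDA stream.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Assumes DataPrefetcher from the diff above is importable or defined in scope;
# the exact module path depends on where this utils file lives in the tree.

# Toy data and model, purely for illustration.
dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=32, pin_memory=True)
model = nn.Linear(16, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

prefetcher = DataPrefetcher(loader)
batch = prefetcher.next()
while batch is not None:  # preload() sets data to None once the loader is exhausted
    inputs, labels = (t.cuda(non_blocking=True) for t in batch)
    loss = criterion(model(inputs), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    batch = prefetcher.next()  # the following batch was already requested on the side stream
```

Note that `next()` first synchronizes the prefetch stream, so it blocks only until the batch fetched in the background is ready, hands it back, and immediately kicks off fetching the next one.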
@@ -190,6 +215,7 @@ def calculate_global_norm_from_list(norm_list):
         total_norm += norm**2.0
     return math.sqrt(total_norm)


 def sync_tensor(flat_tensor, tensor_list):
     """
     Synchronize the flattened tensor and unflattened tensor list. When
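The diff view above is truncated, but the `flatten` helper and `sync_tensor` shown in context build on `torch._utils._flatten_dense_tensors` / `_unflatten_dense_tensors`, which the file imports. Below is a small illustrative sketch of that round-trip; the toy tensors and the rebinding loop are assumptions for demonstration, not the actual body of `sync_tensor`.

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Illustrative only: two small tensors standing in for parameter/gradient shards.
tensors = [torch.randn(3), torch.randn(2, 2)]

flat = _flatten_dense_tensors(tensors)  # one contiguous 1-D buffer (a copy, not a view)
flat.mul_(2.0)                          # an in-place update applied to the flat buffer

# Slice the flat buffer back into tensors shaped like the originals and rebind,
# so the original list reflects the update made on the flat buffer.
for t, synced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
    t.data = synced.data
```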