mirror of https://github.com/hpcaitech/ColossalAI
feat: add `DataPrefetcher`
parent
a1ab2d374e
commit
93aaa21d4a
|
@ -1,11 +1,36 @@
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||||
|
|
||||||
|
|
||||||
|
class DataPrefetcher:
    """Overlap data loading with computation by prefetching on a side CUDA stream.

    One batch is always kept "in flight": it is fetched from the wrapped
    loader inside a dedicated :class:`torch.cuda.Stream`, so host-to-device
    work issued by the loader can proceed while the default stream computes.
    ``next()`` blocks until the in-flight batch is ready, hands it out, and
    immediately starts fetching the following one. Once the underlying
    loader is exhausted, ``next()`` returns ``None``.
    """

    def __init__(self, loader: Iterable):
        # Drive the loader through a single iterator so each item is consumed once.
        self.loader = iter(loader)
        # Dedicated stream on which every fetch is issued.
        self.stream = torch.cuda.Stream()
        # Warm-up: put the first batch in flight right away.
        self.preload()

    def next(self):
        """Return the prefetched batch (or ``None`` when exhausted) and start the next fetch."""
        # Block until the fetch issued on the side stream has completed.
        self.stream.synchronize()
        prefetched = self.data
        # Kick off the next fetch before handing the current batch to the caller.
        self.preload()
        return prefetched

    def preload(self):
        """Fetch the next batch on the side stream; store ``None`` on exhaustion."""
        try:
            # All work the loader does here is enqueued on self.stream.
            with torch.cuda.stream(self.stream):
                self.data = next(self.loader)
        except StopIteration:
            # Exhausted: next() will surface this as None.
            self.data = None
|
||||||
|
|
||||||
|
|
||||||
def flatten(input_):
    """Flatten a sequence of dense tensors into one contiguous 1-D buffer.

    Thin wrapper around ``torch._utils._flatten_dense_tensors``; the inverse
    operation is performed with ``_unflatten_dense_tensors``.
    """
    flat_buffer = _flatten_dense_tensors(input_)
    return flat_buffer
|
||||||
|
|
||||||
|
@ -190,6 +215,7 @@ def calculate_global_norm_from_list(norm_list):
|
||||||
total_norm += norm**2.0
|
total_norm += norm**2.0
|
||||||
return math.sqrt(total_norm)
|
return math.sqrt(total_norm)
|
||||||
|
|
||||||
|
|
||||||
def sync_tensor(flat_tensor, tensor_list):
|
def sync_tensor(flat_tensor, tensor_list):
|
||||||
"""
|
"""
|
||||||
Synchronize the flattened tensor and unflattened tensor list. When
|
Synchronize the flattened tensor and unflattened tensor list. When
|
||||||
|
|
Loading…
Reference in New Issue