|
|
|
@ -1,11 +1,36 @@
|
|
|
|
|
import math |
|
|
|
|
from typing import Optional |
|
|
|
|
from typing import Iterable, Optional |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import torch.distributed as dist |
|
|
|
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataPrefetcher:
    """
    Prefetch batches from a loader under a dedicated CUDA stream.

    The prefetcher always holds one batch in ``self.data``: it is filled when
    the object is constructed and refilled on every call to ``next``.
    """

    def __init__(self, loader: Iterable):
        """Wrap ``loader`` in an iterator, create the side stream and fetch the first batch."""
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Prime the buffer so the first call to ``next`` already has a batch.
        self.preload()

    def next(self):
        """
        Hand out the buffered batch and start fetching the following one.

        Returns:
            The batch stored by the previous ``preload`` call, or ``None``
            once the underlying loader is exhausted.
        """
        # Wait for the work queued on the prefetch stream before using its result.
        self.stream.synchronize()
        ready = self.data
        self.preload()
        return ready

    def preload(self):
        """Pull the next batch into ``self.data``; store ``None`` when the loader has no more items."""
        try:
            # The loader is advanced with the prefetch stream set as the current stream.
            with torch.cuda.stream(self.stream):
                fetched = next(self.loader)
        except StopIteration:
            fetched = None
        self.data = fetched
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def flatten(input_):
    """
    Flatten a collection of dense tensors into a single tensor.

    Thin wrapper around ``torch._utils._flatten_dense_tensors``.

    Args:
        input_: The dense tensors to flatten, passed through unchanged.

    Returns:
        The flattened tensor produced by ``_flatten_dense_tensors``.
    """
    flat_tensor = _flatten_dense_tensors(input_)
    return flat_tensor
|
|
|
|
|
|
|
|
@ -190,6 +215,7 @@ def calculate_global_norm_from_list(norm_list):
|
|
|
|
|
total_norm += norm**2.0 |
|
|
|
|
return math.sqrt(total_norm) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sync_tensor(flat_tensor, tensor_list): |
|
|
|
|
""" |
|
|
|
|
Synchronize the flattened tensor and unflattened tensor list. When |
|
|
|
|