ColossalAI/colossalai/zero/low_level/bookkeeping/tensor_bucket.py

from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

class TensorBucket:
    """Accumulates tensor shards up to `size` elements so they can be flattened
    and all-gathered in a single collective call."""

    def __init__(self, size):
        self._max_size = size
        self._current_size = 0
        self._bucket = []
        self._write_back_pairs = {}

    @property
    def max_size(self):
        return self._max_size

    @property
    def current_size(self):
        return self._current_size

    def is_full_or_oversized(self):
        return self._current_size >= self._max_size

    def is_empty(self):
        return len(self._bucket) == 0
    def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None):
        tensor_size = tensor.numel()
        if not allow_oversize and self.will_exceed_max_size(tensor_size):
            msg = f"The param bucket max size {self._max_size} is exceeded by tensor (size {tensor_size})"
            raise RuntimeError(msg)

        self._bucket.append(tensor)
        self._current_size += tensor_size
        # Gathered results are written back to `write_back_tensor`; default to the tensor itself.
        write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor
        self._write_back_pairs[tensor] = write_back_tensor

    def will_exceed_max_size(self, tensor_size):
        expected_size = self._current_size + tensor_size
        return expected_size > self._max_size
    def get_bucket(self):
        return self._bucket

    def empty(self):
        self._bucket = []
        self._current_size = 0
        self._write_back_pairs = {}

    def flatten(self):
        return _flatten_dense_tensors(self._bucket)

    def unflatten(self, flat_tensor):
        return _unflatten_dense_tensors(flat_tensor, self._bucket)

    def unflatten_and_copy(self, flat_tensor):
        unflattened_tensor_list = self.unflatten(flat_tensor)
        for old, new in zip(self._bucket, unflattened_tensor_list):
            old.copy_(new)
    def all_gather(self, group=None, fp8_communication: bool = False):
        flat = self.flatten()
        buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
        if fp8_communication:
            all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
        else:
            dist.all_gather_into_tensor(buffer, flat, group=group)
        # Split the gathered buffer back into one chunk per rank and unflatten each
        # chunk into the per-tensor shapes of this bucket.
        unflat_buffers = [self.unflatten(chunk) for chunk in buffer.chunk(dist.get_world_size(group))]
        # Transpose the list of lists so each inner list holds one tensor's shards from every rank.
        unflat_buffers = list(map(list, zip(*unflat_buffers)))
        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
            write_back_tensor = self._write_back_pairs[tensor]
            write_back_tensor.data.copy_(
                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
            )
        self.empty()
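

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how a TensorBucket might be driven: each rank adds its
# local shards with the full-size tensors as write-back targets, then a single
# all_gather() call reassembles the full tensors on every rank. It assumes
# torch.distributed is already initialized (e.g. launched with torchrun) and
# that every rank contributes equally sized shards; the names `full_p*` and
# `shard_p*` are invented for this sketch.
if __name__ == "__main__":
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Full-size destination tensors and this rank's shard of each parameter.
    full_p1 = torch.zeros(4 * world_size)
    full_p2 = torch.zeros(2 * world_size)
    shard_p1 = torch.full((4,), float(rank))
    shard_p2 = torch.full((2,), float(rank))

    bucket = TensorBucket(size=1024)
    bucket.add_to_bucket(shard_p1, write_back_tensor=full_p1)
    bucket.add_to_bucket(shard_p2, write_back_tensor=full_p2)

    # One fused all-gather fills full_p1 / full_p2 and then empties the bucket.
    bucket.all_gather()
    assert bucket.is_empty()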